diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index b7d2191..63f257f 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -20,33 +20,32 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
-def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    # Save the model state (parameters)
-    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
-    torch.save(model_data, model_path)
-    log0(f"Saved model file to: {model_path}")
-    # Save the optimizer state (useful for SFT or any other fine-tuning)
+def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+    if rank == 0:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # Save the model state parameters
+        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+        torch.save(model_data, model_path)
+        logger.info(f"Saved model parameters to: {model_path}")
+        # Save the metadata dict as json
+        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+        logger.info(f"Saved metadata to: {meta_path}")
+    # Note that optimizer state is sharded across ranks, so each rank must save its own.
     if optimizer_data is not None:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         torch.save(optimizer_data, optimizer_path)
-        log0(f"Saved optimizer file to: {optimizer_path}")
-    # Save the metadata dict as json
-    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "w", encoding="utf-8") as f:
-        json.dump(meta_data, f, indent=2)
-    log0(f"Saved metadata file to: {meta_path}")
+        logger.info(f"Saved optimizer state to: {optimizer_path}")
 
-
-def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
+def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
     model_data = torch.load(model_path, map_location=device)
     # Load the optimizer state if requested
     optimizer_data = None
     if load_optimizer:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         optimizer_data = torch.load(optimizer_path, map_location=device)
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
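For orientation, a minimal sketch of how the rank-aware API above is meant to be driven under DDP: every rank calls save_checkpoint, rank 0 writes the shared model and metadata files, and each rank writes (and later reloads) its own optimizer shard. The model, optimizer, and directory below are illustrative stand-ins, not part of this patch.

    import os
    import torch
    import torch.nn as nn
    from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint

    rank = int(os.environ.get("RANK", 0))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # stand-ins; in base_train.py these are the GPT model and its two optimizers
    model = nn.Linear(8, 8).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # all ranks call this: rank 0 writes model_000100.pt and meta_000100.json,
    # every rank writes optim_000100_rank{rank}.pt with its own optimizer shard
    save_checkpoint("checkpoints/demo", 100, model.state_dict(),
                    [optimizer.state_dict()], {"step": 100}, rank=rank)

    # on resume, each rank reloads the shared model plus its own optimizer shard
    model_data, optimizer_data, meta_data = load_checkpoint(
        "checkpoints/demo", 100, device, load_optimizer=True, rank=rank)
    model.load_state_dict(model_data)
    optimizer.load_state_dict(optimizer_data[0])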
diff --git a/nanochat/common.py b/nanochat/common.py
index d4a9828..8f36f94 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -148,6 +148,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
 
     # Reproducibility
+    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
+    # The only place where global rng might be used is nn.Module initialization of the model weights.
     torch.manual_seed(42)
    if device_type == "cuda":
         torch.cuda.manual_seed(42)
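The added comment is the operative point: the global seed mostly pins down nn.Module weight initialization, while anything else stochastic is expected to carry its own generator. A tiny illustration of that convention (the names here are illustrative):

    import torch

    torch.manual_seed(42)                       # global seed: covers nn.Module weight init
    layer = torch.nn.Linear(4, 4)               # deterministic weights via the global seed

    rng = torch.Generator().manual_seed(1337)   # explicit rng object for everything else
    noise = torch.randn(4, 4, generator=rng)    # independent of the global stream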
diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index 6c864d3..3271298 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -1,49 +1,87 @@
 from collections import deque
 
 import torch
+import pyarrow.parquet as pq
 
 from nanochat.common import get_dist_info
-from nanochat.dataset import parquets_iter_batched
+from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
-    """Stream pretraining text from parquet files, tokenize, yield training batches."""
+def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+    """
+    Stream pretraining text from parquet files, tokenize, yield training batches.
+
+    This implementation became a bit more complex because we wish to support approximate resume training.
+    Instead of turning this into a Class, we opt to return the state_dict with every batch,
+    and then the caller can pass in a state_dict to resume training from a desired point.
+    Note that this resumption is atm only *approximate* for simplicity.
+    We won't repeat the same documents but we might skip a few.
+    The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
+
+    Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
+    """
     assert split in ["train", "val"], "split must be 'train' or 'val'"
+
+    # infinite iterator over document batches (list of text strings)
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    def document_batches():
+        parquet_paths = list_parquet_files()
+        parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
+        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
+        while True: # iterate infinitely (multi-epoch)
+            while pq_idx < len(parquet_paths): # iterate over all parquet files
+                filepath = parquet_paths[pq_idx]
+                pf = pq.ParquetFile(filepath)
+                # Start from resume point if resuming on same file, otherwise from DDP rank
+                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
+                if resume_rg_idx is not None:
+                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
+                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
+                    rg_idx = base_idx * ddp_world_size + ddp_rank
+                    resume_rg_idx = None # set to None as we only want to do this a single time
+                else:
+                    rg_idx = ddp_rank
+                while rg_idx < pf.num_row_groups:
+                    rg = pf.read_row_group(rg_idx)
+                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
+                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
+                    for i in range(0, len(batch), tokenizer_batch_size):
+                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
+                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
+                pq_idx += 1 # advance to the next parquet file
+    batches = document_batches()
+
+    # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
     # get the tokenizer and the bos token
     tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
-
-    # infinite iterator over document batches
-    def document_batches():
-        while True:
-            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
-            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
-                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
-                for i in range(0, len(batch), tokenizer_batch_size):
-                    yield batch[i:i+tokenizer_batch_size]
-    batches = document_batches()
-
-    batch_index = 0
     while True:
         # Accumulate enough tokens for one iteration before yielding.
         while len(token_buffer) < needed_tokens:
-            doc_batch = next(batches)
+            doc_batch, (pq_idx, rg_idx) = next(batches)
             token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
             for tokens in token_lists:
                 token_buffer.extend(tokens)
-            batch_index += 1
         # Move tokens from the deque into the scratch buffer
         tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-        # CUDA supports memory pinning for faster transfers between CPU and GPU:
-        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
+        # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
+        use_cuda_optimizations = device == "cuda"
+        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
         # Create the inputs/targets as 1D tensors
-        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+        inputs_cpu = scratch[:-1]
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
+        yield inputs, targets, state_dict
+
+def tokenizing_distributed_data_loader(*args, **kwargs):
+    # helper function that only emits the inputs/targets and not the state_dict
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
         yield inputs, targets
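A sketch of how the per-batch state_dict is intended to round-trip for approximate resumption: keep the most recent state next to the batch, then hand it back via resume_state_dict when the loader is rebuilt. This assumes the pretraining parquet shards and tokenizer have already been prepared; batch shape, device, and the stopping point are illustrative.

    from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

    loader = tokenizing_distributed_data_loader_with_state(B=8, T=1024, split="train", device="cpu")
    last_state = None
    for i, (x, y, state) in enumerate(loader):
        last_state = state                # e.g. {"pq_idx": 3, "rg_idx": 57}
        if i == 999:                      # pretend the run is interrupted here
            break

    # later, in a fresh process: resume (approximately) where we left off;
    # documents are never repeated, but a few may be skipped
    resumed = tokenizing_distributed_data_loader_with_state(
        B=8, T=1024, split="train", device="cpu", resume_state_dict=last_state)
    x, y, state = next(resumed)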
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 594c709..c9ea6c9 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -20,10 +20,10 @@ import wandb
 import torch
 
 from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader
+from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
-from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
@@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 warmup_ratio = 0.0 # ratio of iterations for LR warmup
 warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
 final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
 eval_tokens = 20*524288 # number of tokens to evaluate val loss on
 core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
+save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
 # Output
 model_tag = "" # optionally override the model tag for the output checkpoint directory name
 # now allow CLI to override the settings via the configurator lol
@@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
+
 # -----------------------------------------------------------------------------
 # Initialize the Model
+
+# Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
 model.to_empty(device=device)
 model.init_weights()
-orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+
+# If we are resuming, overwrite the model parameters with those of the checkpoint
+base_dir = get_base_dir()
+output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
+checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+resuming = resume_from_step != -1
+if resuming:
+    print0(f"Resuming optimization from step {resume_from_step}")
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
+    model.load_state_dict(model_data, strict=True, assign=True)
+    del model_data # free up this memory after the copy
+
+orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change)
+model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
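The resume path above leans on load_state_dict(..., assign=True): the module is built on the meta device and only given uninitialized storage via to_empty(), so assign=True binds the checkpoint tensors directly instead of copying into freshly allocated ones. A self-contained toy version of that pattern, assuming a recent PyTorch (assign= was added in 2.1); the MLP class is illustrative.

    import torch
    import torch.nn as nn

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 16)

    checkpoint_state = MLP().state_dict()   # pretend this was torch.load-ed from model_000100.pt

    with torch.device("meta"):
        model = MLP()                       # parameters on the meta device, no real storage yet
    model.to_empty(device="cpu")            # allocate real (uninitialized) storage
    model.load_state_dict(checkpoint_state, strict=True, assign=True)  # bind checkpoint tensors directly
    print(model.fc.weight.sum())            # populated from the checkpoint, not from init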
@@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
 adamw_optimizer, muon_optimizer = optimizers
 
+if resuming:
+    for opt, dat in zip(optimizers, optimizer_data):
+        opt.load_state_dict(dat)
+    del optimizer_data # free up the memory
+
+# -----------------------------------------------------------------------------
 # Initialize the DataLoaders for train/val
-base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
+dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
+train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
 build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
-x, y = next(train_loader) # kick off load of the very first batch of data
+x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------
 # Set up hyperparameter schedulers
@@ -171,15 +194,25 @@ def get_muon_momentum(it):
     momentum = (1 - frac) * 0.85 + frac * 0.95
     return momentum
 
+# -----------------------------------------------------------------------------
+# Loop state (variables updated by the training loop)
+
+if not resuming:
+    step = 0
+    min_val_bpb = float("inf")
+    smooth_train_loss = 0 # EMA of training loss
+    total_training_time = 0 # total wall-clock time of training
+else:
+    step = meta_data["step"]
+    loop_state = meta_data["loop_state"]
+    min_val_bpb = loop_state["min_val_bpb"]
+    smooth_train_loss = loop_state["smooth_train_loss"]
+    total_training_time = loop_state["total_training_time"]
+
 # -----------------------------------------------------------------------------
 # Training loop
-min_val_bpb = float("inf")
-smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
-total_training_time = 0 # total wall-clock time of training
-# note that we run +1 steps only so that we can eval and save at the end
-for step in range(num_iterations + 1):
-    last_step = step == num_iterations
+while True:
+    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
     flops_so_far = num_flops_per_token * total_batch_size * step
 
     # once in a while: evaluate the val bpb (all ranks participate)
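For reference, the metadata JSON written by the save_checkpoint call in the next hunk is what the resume branches above read back; it has roughly this shape, with all values below being illustrative placeholders.

    meta_data = {
        "step": 4000,
        "val_bpb": 0.92,
        "model_config": {"sequence_len": 2048, "vocab_size": 65536, "n_layer": 20,
                         "n_head": 10, "n_kv_head": 10, "n_embd": 1280},
        "user_config": {"depth": 20, "save_every": 1000},
        "device_batch_size": 32,
        "max_seq_len": 2048,
        "dataloader_state_dict": {"pq_idx": 3, "rg_idx": 57},
        "loop_state": {"min_val_bpb": 0.93, "smooth_train_loss": 2.7, "total_training_time": 5400.0},
    }

    # base_train.py resumes, in effect, with:
    step = meta_data["step"]
    dataloader_resume_state_dict = meta_data["dataloader_state_dict"]
    min_val_bpb = meta_data["loop_state"]["min_val_bpb"]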
@@ -237,25 +270,31 @@
         print0(tokenizer.decode(sample[0]))
         model.train()
 
-    # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step:
-        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
-        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
+    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
         save_checkpoint(
             checkpoint_dir,
             step,
-            orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
-            {
+            orig_model.state_dict(), # model parameters
+            [opt.state_dict() for opt in optimizers], # optimizer states
+            { # metadata saved as json
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
                 "model_config": model_config_kwargs,
                 "user_config": user_config, # inputs to the training script
                 "device_batch_size": device_batch_size,
                 "max_seq_len": max_seq_len,
-            }
+                "dataloader_state_dict": dataloader_state_dict,
+                "loop_state": { # all loop state (other than step) so that we can resume training
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "total_training_time": total_training_time,
+                },
+            },
+            rank=ddp_rank,
         )
 
+    # termination conditions (TODO: possibly also add loss explosions etc.)
     if last_step:
         break
 
@@ -270,7 +309,7 @@ for step in range(num_iterations + 1):
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
-        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     # gradient clipping
     grad_clip_enabled = grad_clip > 0.0
     if grad_clip_enabled:
@@ -293,6 +332,7 @@ for step in range(num_iterations + 1):
 
     # -------------------------------------------------------------------------
     # logging
+    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
@@ -319,6 +359,9 @@ for step in range(num_iterations + 1):
         log_data["train/grad_norm"] = grad_norm
         wandb_run.log(log_data)
 
+    # state update
+    step += 1
+
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
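The new save trigger is easiest to sanity-check in isolation. A standalone restatement of the predicate with the behaviour it encodes (never at step 0, never at the step we just resumed from, periodically per save_every, and always at the last step); the should_save name is illustrative.

    def should_save(step, last_step, resume_from_step=-1, save_every=-1):
        # same condition as in base_train.py above
        return last_step or (step > 0 and step != resume_from_step
                             and save_every > 0 and step % save_every == 0)

    assert not should_save(0, False, save_every=1000)                            # skip the very first step
    assert should_save(1000, False, save_every=1000)                             # periodic checkpoint
    assert not should_save(3000, False, resume_from_step=3000, save_every=1000)  # don't re-save the resume step
    assert should_save(4000, False, resume_from_step=3000, save_every=1000)
    assert should_save(21400, True)                                              # always save at the last step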