diff --git a/nanochat/adamw.py b/nanochat/adamw.py index db591de..2cbf1d4 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -27,6 +27,7 @@ class DistAdamW(torch.optim.Optimizer): for group in self.param_groups: params: list[Tensor] = group["params"] for base_i in range(len(params)): + assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}" grad = params[base_i].grad rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size]) diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 6be9820..4136802 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -26,6 +26,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() def document_batches(): parquet_paths = list_parquet_files() + assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?" parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None diff --git a/nanochat/gpt.py b/nanochat/gpt.py index ba465d7..501cc4a 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -135,14 +135,19 @@ class Block(nn.Module): class GPT(nn.Module): - def __init__(self, config): + def __init__(self, config, pad_vocab_size_to=64): super().__init__() self.config = config + # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see: + # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings + padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to + if padded_vocab_size != config.vocab_size: + print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}") self.transformer = nn.ModuleDict({ - "wte": nn.Embedding(config.vocab_size, config.n_embd), + "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) # To support meta device initialization, we init the rotary embeddings here, but it's fake # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them, but assert fail if we ever reach that amount. @@ -220,8 +225,7 @@ class GPT(nn.Module): # Create the AdamW optimizer for the embedding and lm_head # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 - if rank == 0: - print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") + print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), @@ -260,7 +264,9 @@ class GPT(nn.Module): # Forward the lm_head (compute logits) softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] - logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory + logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory + # slice to remove padding + logits = logits[..., :self.config.vocab_size] logits = logits.float() # switch to fp32 for logit softcap and loss computation logits = softcap * torch.tanh(logits / softcap) # squash the logits