Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-20 18:34:14 +00:00
feat: pad vocab size to 64 for DDP optimizers and efficiency
This commit is contained in:
parent d5759400f9
commit f1bf69d562
@@ -27,6 +27,7 @@ class DistAdamW(torch.optim.Optimizer):
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
             for base_i in range(len(params)):
+                assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
                 grad = params[base_i].grad
                 rank_size = grad.shape[0] // world_size
                 grad_slice = torch.empty_like(grad[:rank_size])
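The new assert guards the sharding arithmetic a few lines below it: each rank owns a contiguous slice of a parameter's first dimension, so that dimension must divide evenly by the world size. A standalone sketch of that arithmetic (the world size and shapes here are illustrative, not taken from the repo):

    import torch

    # Illustrative only: mimic how a gradient's dim 0 is split into per-rank slices.
    world_size = 4                       # hypothetical number of ranks
    grad = torch.randn(50304, 768)       # hypothetical padded-vocab-sized gradient

    assert grad.shape[0] % world_size == 0, "each rank needs an equal-sized slice"
    rank_size = grad.shape[0] // world_size
    slices = [grad[r * rank_size:(r + 1) * rank_size] for r in range(world_size)]
    assert all(s.shape == (rank_size, 768) for s in slices)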
@@ -26,6 +26,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     def document_batches():
         parquet_paths = list_parquet_files()
+        assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
         parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
         resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
         resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
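For context, the surrounding loader reserves the last parquet shard for validation and resumes from a saved position; the added assert just fails early when no shards exist at all. A minimal sketch of that selection logic, using made-up shard names:

    # Made-up shard names; only the selection logic mirrors the loader above.
    parquet_paths = [f"shard_{i:04d}.parquet" for i in range(10)]
    assert len(parquet_paths) != 0, "no shards found"

    for split in ("train", "val"):
        selected = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
        print(split, len(selected))      # train -> 9 shards, val -> 1 shard

    resume_state_dict = None             # no checkpoint: start from the beginning
    resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0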
@@ -135,14 +135,19 @@ class Block(nn.Module):
 
 
 class GPT(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, pad_vocab_size_to=64):
         super().__init__()
         self.config = config
+        # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
+        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
+        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
+        if padded_vocab_size != config.vocab_size:
+            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
         self.transformer = nn.ModuleDict({
-            "wte": nn.Embedding(config.vocab_size, config.n_embd),
+            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
         # To support meta device initialization, we init the rotary embeddings here, but it's fake
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them, but assert fail if we ever reach that amount.
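The padding expression is the usual round-up-to-a-multiple trick: adding (multiple - 1) before the floor division pushes any remainder up to the next multiple. A worked example (the vocab sizes below are illustrative, not necessarily nanochat's tokenizer size):

    def pad_to_multiple(vocab_size: int, multiple: int = 64) -> int:
        # ceil(vocab_size / multiple) * multiple, using integer arithmetic
        return ((vocab_size + multiple - 1) // multiple) * multiple

    assert pad_to_multiple(50257) == 50304   # 50257 rounds up to the next multiple of 64
    assert pad_to_multiple(65536) == 65536   # already divisible, so unchanged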
@@ -220,8 +225,7 @@ class GPT(nn.Module):
         # Create the AdamW optimizer for the embedding and lm_head
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
-        if rank == 0:
-            print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
+        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
         adam_groups = [
             dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
             dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
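The scaling keeps the effective AdamW step size roughly constant as the model widens: the LRs were tuned at d_model = 768, and wider models have them shrunk by 1/sqrt(model_dim / 768). A small numeric illustration (the base LR and widths are made up):

    base_lr = 0.004                           # hypothetical LR tuned at d_model = 768
    for model_dim in (768, 1536, 3072):
        scale = (model_dim / 768) ** -0.5
        print(f"d_model={model_dim}: scale={scale:.4f}, lr={base_lr * scale:.6f}")
    # 768 -> 1.0000, 1536 -> 0.7071, 3072 -> 0.5000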
@@ -260,7 +264,9 @@ class GPT(nn.Module):
 
         # Forward the lm_head (compute logits)
         softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
-        logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory
+        logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
+        # slice to remove padding
+        logits = logits[..., :self.config.vocab_size]
         logits = logits.float() # switch to fp32 for logit softcap and loss computation
         logits = softcap * torch.tanh(logits / softcap) # squash the logits
 
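Because the lm_head now projects to the padded vocab, the extra columns are dropped before the softcap and the loss, so the padding slots can never be predicted. A self-contained sketch of that slice-then-softcap step (shapes are illustrative):

    import torch

    B, T, vocab_size, padded_vocab_size = 2, 8, 50257, 50304   # illustrative sizes
    softcap = 15

    logits = torch.randn(B, T, padded_vocab_size)   # stand-in for self.lm_head(x)
    logits = logits[..., :vocab_size]               # discard the padding columns
    logits = logits.float()
    logits = softcap * torch.tanh(logits / softcap) # smoothly capped to (-softcap, softcap)
    assert logits.shape == (B, T, vocab_size) and logits.abs().max() < softcap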