feat: pad vocab size to 64 for DDP optimizers and efficiency

This commit is contained in:
Matěj Kripner 2025-12-09 12:38:18 +01:00
parent d5759400f9
commit f1bf69d562
3 changed files with 14 additions and 6 deletions

View File

@ -27,6 +27,7 @@ class DistAdamW(torch.optim.Optimizer):
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])

View File

@ -26,6 +26,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None

View File

@ -135,14 +135,19 @@ class Block(nn.Module):
class GPT(nn.Module):
def __init__(self, config):
def __init__(self, config, pad_vocab_size_to=64):
super().__init__()
self.config = config
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
@ -220,8 +225,7 @@ class GPT(nn.Module):
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
@ -260,7 +264,9 @@ class GPT(nn.Module):
# Forward the lm_head (compute logits)
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
# slice to remove padding
logits = logits[..., :self.config.vocab_size]
logits = logits.float() # switch to fp32 for logit softcap and loss computation
logits = softcap * torch.tanh(logits / softcap) # squash the logits