From 48abd7d85f3b7a06fe8a457de2353047cff3d951 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 1 Jan 2026 21:14:26 +0000
Subject: [PATCH] simplify, clarify and slightly tune model initialization.
 should be very slightly better possibly, but certainly a lot clearer

---
 nanochat/gpt.py       | 55 ++++++++++++++++++++++++++-----------------
 scripts/base_train.py |  5 ++--
 2 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 69899ee..e6027a9 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -146,9 +146,9 @@ class GPT(nn.Module):
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
         self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
-        # To support meta device initialization, we init the rotary embeddings here, but it's fake
+        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
-        # so let's just over-compute them, but assert fail if we ever reach that amount.
+        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
         # In the future we can dynamically grow the cache, for now it's fine.
         self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
         head_dim = config.n_embd // config.n_head
@@ -157,35 +157,46 @@ class GPT(nn.Module):
         self.register_buffer("sin", sin, persistent=False)

     def init_weights(self):
-        self.apply(self._init_weights)
-        # zero out classifier weights
-        torch.nn.init.zeros_(self.lm_head.weight)
-        # zero out c_proj weights in all blocks
+        """
+        Initialize the full model in this one function for maximum clarity.
+
+        wte (embedding): normal, std=1.0
+        lm_head: normal, std=0.001
+        for each block:
+            attn.c_q: uniform, std=1/sqrt(n_embd)
+            attn.c_k: uniform, std=1/sqrt(n_embd)
+            attn.c_v: uniform, std=1/sqrt(n_embd)
+            attn.c_proj: zeros
+            mlp.c_fc: uniform, std=1/sqrt(n_embd)
+            mlp.c_proj: zeros
+        """
+
+        # Embedding and unembedding
+        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
+        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+
+        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
+        n_embd = self.config.n_embd
+        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
         for block in self.transformer.h:
+            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
+            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
+            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
+            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
+            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
             torch.nn.init.zeros_(block.mlp.c_proj.weight)
-            torch.nn.init.zeros_(block.attn.c_proj.weight)
-        # init the rotary embeddings
+
+        # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+
+        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)

-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            # https://arxiv.org/pdf/2310.17813
-            fan_out = module.weight.size(0)
-            fan_in = module.weight.size(1)
-            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
-
-    # TODO: bump base theta more, e.g. 100K is more common more recently
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+        # TODO: bump base theta more? e.g. 100K is more common more recently
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
diff --git a/scripts/base_train.py b/scripts/base_train.py
index afa3b7a..4f66eb0 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -112,10 +112,11 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
 # Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
+    # All tensors are created as meta tensors (they have shape/dtype but no data)
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device=device)
-model.init_weights()
+model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
+model.init_weights() # All tensors get initialized

 # If we are resuming, overwrite the model parameters with those of the checkpoint
 base_dir = get_base_dir()
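
A quick standalone sanity check of the sqrt(3) factor used in init_weights() above (this sketch is not part of the patch): Uniform(-a, a) has variance a^2/3, so choosing a = sqrt(3)/sqrt(n_embd) reproduces the std of Normal(0, 1/sqrt(n_embd)) while bounding every weight at about 1.73 standard deviations. The n_embd value below is an arbitrary example.

# Standalone sanity check (not part of the patch): verify that
# Uniform(-sqrt(3)*std, sqrt(3)*std) matches the std of Normal(0, std).
import torch

n_embd = 768                      # hypothetical model width
std = n_embd ** -0.5              # target standard deviation, 1/sqrt(n_embd)
s = 3 ** 0.5 * std                # uniform bound, as in init_weights()

w_uniform = torch.empty(4 * n_embd, n_embd)
torch.nn.init.uniform_(w_uniform, -s, s)

w_normal = torch.empty(4 * n_embd, n_embd)
torch.nn.init.normal_(w_normal, mean=0.0, std=std)

# Both stds should come out near 0.036 for n_embd=768; the uniform version has
# no samples beyond ~1.73 std, which is the "avoid outliers" motivation.
print(f"target std:  {std:.4f}")
print(f"uniform std: {w_uniform.std().item():.4f}  max |w|: {w_uniform.abs().max().item():.4f}")
print(f"normal std:  {w_normal.std().item():.4f}  max |w|: {w_normal.abs().max().item():.4f}")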
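
The base_train.py hunk relies on PyTorch's meta-device flow: build the model as shape-only meta tensors, materialize real storage with to_empty(), then run the explicit init. Below is a minimal sketch of that flow under stated assumptions; the MLP class is a hypothetical stand-in for illustration, not the nanochat GPT class.

# Minimal sketch of the meta-device initialization flow used in base_train.py,
# with a hypothetical stand-in module instead of the real GPT class.
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

    def init_weights(self):
        # same scheme as the patch: uniform with bound sqrt(3)/sqrt(n_embd), zero projection
        s = 3 ** 0.5 * self.c_fc.in_features ** -0.5
        torch.nn.init.uniform_(self.c_fc.weight, -s, s)
        torch.nn.init.zeros_(self.c_proj.weight)

with torch.device("meta"):
    model = MLP(768)              # meta tensors: shape/dtype only, no storage, no init cost

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=device)     # allocate real (uninitialized) storage on the target device
model.init_weights()              # now fill that storage with the intended values

print(model.c_fc.weight.device, model.c_fc.weight.std().item())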