Simplify, clarify, and slightly tune model initialization. Possibly very slightly better, but certainly a lot clearer.

Andrej Karpathy 2026-01-01 21:14:26 +00:00
parent 10231dfb40
commit 48abd7d85f
2 changed files with 36 additions and 24 deletions


@@ -146,9 +146,9 @@ class GPT(nn.Module):
             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
         })
         self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
-        # To support meta device initialization, we init the rotary embeddings here, but it's fake
+        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
-        # so let's just over-compute them, but assert fail if we ever reach that amount.
+        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
         # In the future we can dynamically grow the cache, for now it's fine.
         self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
         head_dim = config.n_embd // config.n_head
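Note: the "over-compute by 10X, but assert fail" comment above describes building the RoPE cos/sin cache for 10x the training sequence length and bounds-checking it at use time. A generic, self-contained sketch of that pattern (illustrative only; the repo's _precompute_rotary_embeddings may differ in its exact shapes and dtype handling):

import torch

def precompute_rotary(seq_len, head_dim, base=10000.0):
    # Standard RoPE frequency table: one angle per (position, frequency) pair
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim/2)
    return angles.cos(), angles.sin()

sequence_len, head_dim = 1024, 64                           # hypothetical config values
cos, sin = precompute_rotary(sequence_len * 10, head_dim)   # 10X over-compute, still small in memory
T = 4096                                                    # actual sequence length seen at runtime
assert T <= cos.size(0), "sequence exceeds the precomputed rotary cache"
cos_t, sin_t = cos[:T], sin[:T]                             # slice the cache instead of recomputing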
@@ -157,35 +157,46 @@ class GPT(nn.Module):
         self.register_buffer("sin", sin, persistent=False)
 
     def init_weights(self):
-        self.apply(self._init_weights)
-        # zero out classifier weights
-        torch.nn.init.zeros_(self.lm_head.weight)
-        # zero out c_proj weights in all blocks
+        """
+        Initialize the full model in this one function for maximum clarity.
+        wte (embedding): normal, std=1.0
+        lm_head: normal, std=0.001
+        for each block:
+            attn.c_q: uniform, std=1/sqrt(n_embd)
+            attn.c_k: uniform, std=1/sqrt(n_embd)
+            attn.c_v: uniform, std=1/sqrt(n_embd)
+            attn.c_proj: zeros
+            mlp.c_fc: uniform, std=1/sqrt(n_embd)
+            mlp.c_proj: zeros
+        """
+        # Embedding and unembedding
+        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
+        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
+        n_embd = self.config.n_embd
+        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
         for block in self.transformer.h:
+            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
+            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
+            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
+            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
+            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
+            torch.nn.init.zeros_(block.mlp.c_proj.weight)
-            torch.nn.init.zeros_(block.attn.c_proj.weight)
-        # init the rotary embeddings
+        # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
 
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            # https://arxiv.org/pdf/2310.17813
-            fan_out = module.weight.size(0)
-            fan_in = module.weight.size(1)
-            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
-    # TODO: bump base theta more, e.g. 100K is more common more recently
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+        # TODO: bump base theta more? e.g. 100K is more common more recently
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
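Note on the sqrt(3) bound in the new init_weights above: a uniform distribution U(-a, a) has variance a^2/3, so choosing a = sqrt(3)/sqrt(n_embd) reproduces the std of 1/sqrt(n_embd) used by the normal init it replaces, just with bounded support (no outliers). A quick standalone check (illustrative only, not part of the commit):

import torch

n_embd = 768                                  # example width, not tied to the repo's config
s = 3**0.5 * n_embd**-0.5                     # same bound as in init_weights above
w_uniform = torch.empty(4096, n_embd)
w_normal = torch.empty(4096, n_embd)
torch.nn.init.uniform_(w_uniform, -s, s)
torch.nn.init.normal_(w_normal, mean=0.0, std=n_embd**-0.5)
print(w_uniform.std().item(), w_normal.std().item())   # both ~0.036 (= 1/sqrt(768))
print(bool(w_uniform.abs().max() <= s))                 # True: uniform values are hard-bounded, no outliers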


@@ -112,10 +112,11 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
 # Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
+    # All tensors are created as meta tensors (they have shape/dtype but no data)
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
-model.to_empty(device=device)
-model.init_weights()
+model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
+model.init_weights() # All tensors get initialized
 # If we are resuming, overwrite the model parameters with those of the checkpoint
 base_dir = get_base_dir()
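Note on the meta-device flow: constructing the model under torch.device("meta") records shapes and dtypes without allocating storage, to_empty() then allocates real but uninitialized memory on the target device, and init_weights() writes the values exactly once. A minimal sketch of the same pattern using a hypothetical TinyModel (not repo code):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Toy stand-in for GPT; only the initialization pattern matters here
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16, bias=False)
    def init_weights(self):
        torch.nn.init.zeros_(self.fc.weight)

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device("meta"):
    model = TinyModel()           # parameters are meta tensors: shape/dtype only, no data
model.to_empty(device=device)     # allocate real (uninitialized) storage on the target device
model.init_weights()              # now fill in the actual initial values
print(model.fc.weight.device, model.fc.weight.sum().item())  # real device, 0.0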