mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
perf: architecture + optimizer optimizations — 94.6 min to GPT-2 (4.3% speedup)
Two rounds of WeCo-guided D12 optimization, validated on D24. Key changes: smaller sliding windows (seq/8), VE every 3rd layer, RoPE 200K, smear removed, exponential residual decay, optimizer buffer pre-allocation. Mean CORE=0.2591 across 3 D24 runs.
This commit is contained in:
parent
5019accc5b
commit
741e54f360
|
|
@ -47,12 +47,27 @@ class Linear(nn.Linear):
|
|||
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.weight.to(dtype=x.dtype))
|
||||
w = self.weight
|
||||
if w.dtype != x.dtype:
|
||||
w = w.to(dtype=x.dtype)
|
||||
return F.linear(x, w)
|
||||
|
||||
|
||||
class EmbeddingLinear(nn.Module):
|
||||
"""Lightweight linear layer for lm_head without redundant dtype casting."""
|
||||
def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
assert not bias
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
|
||||
"""Returns True if GPT layer should have Value Embedding (every 3rd layer, last layer always included)."""
|
||||
return layer_idx % 3 == (n_layer - 1) % 3
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
|
|
@ -172,19 +187,16 @@ class GPT(nn.Module):
|
|||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
||||
self.smear_gate = Linear(24, 1, bias=False)
|
||||
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
||||
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
|
||||
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
# Value embeddings (ResFormer-style): every 3rd layer, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
|
|
@ -224,19 +236,27 @@ class GPT(nn.Module):
|
|||
for block in self.transformer.h:
|
||||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
|
||||
torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008) # small nonzero init
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
|
||||
# Per-layer resid init: exponential decay, stronger at early layers
|
||||
import math
|
||||
n_layer = self.config.n_layer
|
||||
resid_start, resid_end = 1.18, 1.06
|
||||
resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
|
||||
for i in range(n_layer):
|
||||
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
|
||||
# Decaying x0 init: earlier layers get more input embedding blending
|
||||
self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
|
||||
# x0 init: first-half only, linearly decaying, zero for deep layers
|
||||
half_depth = max(1, n_layer // 2)
|
||||
for i in range(n_layer):
|
||||
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
|
||||
if i < half_depth:
|
||||
frac = i / max(half_depth - 1, 1)
|
||||
self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
|
||||
else:
|
||||
self.x0_lambdas.data[i] = 0.0
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
|
|
@ -257,10 +277,11 @@ class GPT(nn.Module):
|
|||
# because GradScaler cannot unscale fp16 gradients.
|
||||
if COMPUTE_DTYPE != torch.float16:
|
||||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
self.lm_head.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
|
|
@ -292,7 +313,7 @@ class GPT(nn.Module):
|
|||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
short_window = max(256, -(-long_window // 8 // 128) * 128) # ceil to FA3 tile size (2048 -> 256)
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
|
|
@ -326,7 +347,7 @@ class GPT(nn.Module):
|
|||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
self.backout_lambda.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -354,7 +375,7 @@ class GPT(nn.Module):
|
|||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
|
|
@ -377,8 +398,8 @@ class GPT(nn.Module):
|
|||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
backout_params = [self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(backout_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -392,7 +413,7 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
|
|
@ -424,25 +445,6 @@ class GPT(nn.Module):
|
|||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
|
||||
# Smear: mix previous token's embedding into current position (cheap bigram info)
|
||||
if kv_cache is None:
|
||||
# Training / naive generate: full sequence available, use fast slice
|
||||
assert T > 1, "Training forward pass should have T > 1"
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
else:
|
||||
# KV cache inference: read prev embedding from cache, store current for next step
|
||||
x_pre_smear = kv_cache.prev_embedding
|
||||
kv_cache.prev_embedding = x[:, -1:, :]
|
||||
if T > 1:
|
||||
# Prefill: apply smear to positions 1+, same as training
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
elif x_pre_smear is not None:
|
||||
# Decode: single token, use cached prev embedding
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
||||
x = x + gate * x_pre_smear
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
n_layer = self.config.n_layer
|
||||
|
|
|
|||
|
|
@ -253,9 +253,12 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Stack grads and params (NOTE: this assumes all params have the same shape)
|
||||
stacked_grads = torch.stack([p.grad for p in params])
|
||||
stacked_params = torch.stack(params)
|
||||
# Stack grads and params using pre-allocated buffers (NOTE: this assumes all params have the same shape)
|
||||
stacked_grads = torch.empty(num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_params = torch.empty(num_params, *shape, dtype=dtype, device=device)
|
||||
for i, param in enumerate(params):
|
||||
stacked_grads[i].copy_(param.grad)
|
||||
stacked_params[i].copy_(param)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
|
|
@ -278,7 +281,8 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
)
|
||||
|
||||
# Copy back to original params
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
for i, param in enumerate(params):
|
||||
param.copy_(stacked_params[i])
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
|
|
|||
|
|
@ -65,8 +65,8 @@ parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious w
|
|||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.58, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.10, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
|
|
@ -367,7 +367,7 @@ def get_lr_multiplier(it):
|
|||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.92 during LR warmdown)
|
||||
def get_muon_momentum(it):
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
warmdown_start = num_iterations - warmdown_iters
|
||||
|
|
@ -376,7 +376,7 @@ def get_muon_momentum(it):
|
|||
return (1 - frac) * 0.85 + frac * 0.97
|
||||
elif it >= warmdown_start:
|
||||
progress = (it - warmdown_start) / warmdown_iters
|
||||
return 0.97 * (1 - progress) + 0.90 * progress
|
||||
return 0.97 * (1 - progress) + 0.92 * progress
|
||||
else:
|
||||
return 0.97
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user