perf: architecture + optimizer optimizations — 94.6 min to GPT-2 (4.3% speedup)

Two rounds of WeCo-guided D12 optimization, validated on D24.
Key changes: smaller sliding windows (seq/8), VE every 3rd layer,
RoPE 200K, smear removed, exponential residual decay, optimizer
buffer pre-allocation. Mean CORE=0.2591 across 3 D24 runs.
This commit is contained in:
Yan Meng 2026-03-19 23:15:26 +00:00
parent 5019accc5b
commit 741e54f360
3 changed files with 54 additions and 48 deletions

View File

@ -47,12 +47,27 @@ class Linear(nn.Linear):
Replaces autocast: master weights stay fp32 for optimizer precision,
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
def forward(self, x):
    """Linear matmul with the master weight cast to the activation dtype.

    Master weights stay fp32 for optimizer precision; the cast to the
    activation dtype (typically bf16) is skipped entirely when the dtypes
    already match, avoiding a redundant `.to()` per call.
    """
    # NOTE: removed the unconditional-cast line that preceded this fast path —
    # it returned immediately and made the cached-cast branch unreachable.
    w = self.weight
    if w.dtype != x.dtype:
        w = w.to(dtype=x.dtype)
    return F.linear(x, w)
class EmbeddingLinear(nn.Module):
    """Minimal bias-free linear layer for the lm_head.

    Performs no dtype casting in forward (unlike the casting Linear used for
    the transformer matrices), so the weight is used exactly as stored.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__()
        assert not bias  # bias is not supported by this layer
        self.in_features = in_features
        self.out_features = out_features
        weight_shape = (out_features, in_features)
        self.weight = nn.Parameter(torch.empty(weight_shape, device=device, dtype=dtype))

    def forward(self, x):
        # Plain matmul against the stored weight — no per-call casts.
        return F.linear(x, self.weight)
def has_ve(layer_idx, n_layer):
    """Return True if this GPT layer should have a Value Embedding.

    VEs go on every 3rd layer, phase-aligned so the last layer is always
    included: layers whose index is congruent to (n_layer - 1) mod 3.
    (The earlier alternating-layer rule and its unreachable duplicate
    return were removed; the every-3rd scheme is the current behavior.)
    """
    return layer_idx % 3 == (n_layer - 1) % 3
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
@ -172,19 +187,16 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
self.smear_gate = Linear(24, 1, bias=False)
self.smear_lambda = nn.Parameter(torch.zeros(1))
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
# Value embeddings (ResFormer-style): alternating layers, last layer always included
# Value embeddings (ResFormer-style): every 3rd layer, last layer always included
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
@ -224,19 +236,27 @@ class GPT(nn.Module):
for block in self.transformer.h:
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008) # small nonzero init
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
# Per-layer resid init: exponential decay, stronger at early layers
import math
n_layer = self.config.n_layer
resid_start, resid_end = 1.18, 1.06
resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
for i in range(n_layer):
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
# Decaying x0 init: earlier layers get more input embedding blending
self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
# x0 init: first-half only, linearly decaying, zero for deep layers
half_depth = max(1, n_layer // 2)
for i in range(n_layer):
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
if i < half_depth:
frac = i / max(half_depth - 1, 1)
self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
else:
self.x0_lambdas.data[i] = 0.0
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
@ -257,10 +277,11 @@ class GPT(nn.Module):
# because GradScaler cannot unscale fp16 gradients.
if COMPUTE_DTYPE != torch.float16:
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
self.lm_head.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values():
ve.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None):
# TODO: tune base theta further? e.g. ~100K has been the more common choice in recent models
# autodetect the device from model embeddings
if device is None:
@ -292,7 +313,7 @@ class GPT(nn.Module):
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes
long_window = config.sequence_len
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
short_window = max(256, -(-long_window // 8 // 128) * 128) # ceil to FA3 tile size (2048 -> 256)
char_to_window = {
"L": (long_window, 0),
"S": (short_window, 0),
@ -326,7 +347,7 @@ class GPT(nn.Module):
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
self.backout_lambda.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -354,7 +375,7 @@ class GPT(nn.Module):
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
@ -377,8 +398,8 @@ class GPT(nn.Module):
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
backout_params = [self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(backout_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -392,7 +413,7 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=backout_params, lr=0.15, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
@ -424,25 +445,6 @@ class GPT(nn.Module):
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
x = norm(x)
# Smear: mix previous token's embedding into current position (cheap bigram info)
if kv_cache is None:
# Training / naive generate: full sequence available, use fast slice
assert T > 1, "Training forward pass should have T > 1"
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
else:
# KV cache inference: read prev embedding from cache, store current for next step
x_pre_smear = kv_cache.prev_embedding
kv_cache.prev_embedding = x[:, -1:, :]
if T > 1:
# Prefill: apply smear to positions 1+, same as training
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
elif x_pre_smear is not None:
# Decode: single token, use cached prev embedding
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
x = x + gate * x_pre_smear
# Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual
n_layer = self.config.n_layer

View File

@ -253,9 +253,12 @@ class MuonAdamW(torch.optim.Optimizer):
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Stack grads and params (NOTE: this assumes all params have the same shape)
stacked_grads = torch.stack([p.grad for p in params])
stacked_params = torch.stack(params)
# Stack grads and params using pre-allocated buffers (NOTE: this assumes all params have the same shape)
stacked_grads = torch.empty(num_params, *shape, dtype=dtype, device=device)
stacked_params = torch.empty(num_params, *shape, dtype=dtype, device=device)
for i, param in enumerate(params):
stacked_grads[i].copy_(param.grad)
stacked_params[i].copy_(param)
# Fill all the 0-D tensors with current values
self._muon_momentum_t.fill_(group["momentum"])
@ -278,7 +281,8 @@ class MuonAdamW(torch.optim.Optimizer):
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
for i, param in enumerate(params):
param.copy_(stacked_params[i])
@torch.no_grad()
def step(self):

View File

@ -65,8 +65,8 @@ parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious w
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
# Exactly one definition per flag: duplicate add_argument calls for the same
# option string raise argparse.ArgumentError at parser-construction time.
parser.add_argument("--warmdown-ratio", type=float, default=0.58, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.10, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
@ -367,7 +367,7 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.92 during LR warmdown)
def get_muon_momentum(it):
warmdown_iters = round(args.warmdown_ratio * num_iterations)
warmdown_start = num_iterations - warmdown_iters
@ -376,7 +376,7 @@ def get_muon_momentum(it):
return (1 - frac) * 0.85 + frac * 0.97
elif it >= warmdown_start:
progress = (it - warmdown_start) / warmdown_iters
return 0.97 * (1 - progress) + 0.90 * progress
return 0.97 * (1 - progress) + 0.92 * progress
else:
return 0.97