working, 35 mfu, have to optimize

Andrej Karpathy 2026-02-19 01:15:47 +00:00
parent 48804bff3a
commit 2f8734ee0f
4 changed files with 58 additions and 19 deletions

View File

@@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
from nanochat.moe import MoE
@dataclass
class GPTConfig:
@@ -33,6 +34,8 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
num_experts: int = 8 # MoE: number of expert MLPs
top_k: int = 2 # MoE: number of active experts per token
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
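As an aside, the tiling convention above is easy to pin down with a tiny sketch (illustrative only; the actual tiling logic lives elsewhere in gpt.py and is not part of this hunk):

def tile_window_pattern(pattern: str, n_layer: int) -> list[str]:
    # Tile the pattern string across layers, e.g. "SL" over 6 layers -> S L S L S L
    kinds = [pattern[i % len(pattern)] for i in range(n_layer)]
    kinds[-1] = "L"  # final layer is always full context
    return kinds

# tile_window_pattern("SSL", 7) -> ['S', 'S', 'L', 'S', 'S', 'L', 'L'] (last layer forced to L)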
@@ -118,28 +121,15 @@ class CausalSelfAttention(nn.Module):
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
self.moe = MoE(config)
def forward(self, x, ve, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
x = x + self.moe(norm(x))
return x
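The MoE module itself is defined in nanochat/moe.py and is not part of this diff. As a mental model only, token-level top-k routing looks roughly like the hypothetical sketch below; the real module uses grouped matmuls rather than a Python loop, and the exact gating details here are an assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    # Hypothetical reference implementation, not the nanochat.moe.MoE code.
    def __init__(self, n_embd, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        hidden = 4 * n_embd // top_k  # shrink experts so active params roughly match the dense MLP
        self.gate = nn.Linear(n_embd, num_experts, bias=False)
        self.w_ups = nn.ModuleList(nn.Linear(n_embd, hidden, bias=False) for _ in range(num_experts))
        self.w_downs = nn.ModuleList(nn.Linear(hidden, n_embd, bias=False) for _ in range(num_experts))

    def forward(self, x):  # x: (B, T, C)
        weights, idx = self.gate(x).topk(self.top_k, dim=-1)  # each token picks top_k experts
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for e in range(len(self.w_ups)):
            b, t, k = (idx == e).nonzero(as_tuple=True)  # tokens routed to expert e
            if b.numel() > 0:
                h = F.relu(self.w_ups[e](x[b, t])).square()  # same relu^2 activation as the old MLP
                out[b, t] += weights[b, t, k, None] * self.w_downs[e](h)
        return out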
@@ -197,8 +187,9 @@ class GPT(nn.Module):
attn.c_k: uniform, std=1/sqrt(n_embd)
attn.c_v: uniform, std=1/sqrt(n_embd)
attn.c_proj: zeros
mlp.c_fc: uniform, std=1/sqrt(n_embd)
mlp.c_proj: zeros
moe.router.gate: uniform, std=1/sqrt(n_embd)
moe.experts.w_ups: uniform, std=1/sqrt(n_embd)
moe.experts.w_downs: zeros
"""
# Embedding and unembedding
@@ -213,8 +204,15 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# MoE: router gate and expert up-projections get uniform, down-projections get zero
torch.nn.init.uniform_(block.moe.router.gate.weight, -s, s)
for w_up in block.moe.experts.w_ups:
torch.nn.init.uniform_(w_up, -s, s)
for w_down in block.moe.experts.w_downs:
torch.nn.init.zeros_(w_down)
# MoE load balancing buffers (must be zeroed explicitly: to_empty() from the meta device leaves buffers uninitialized)
block.moe.router.expert_bias.zero_()
block.moe.router.tokens_per_expert_counter.zero_()
# Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
@@ -306,6 +304,11 @@ class GPT(nn.Module):
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
# MoE: only top_k/num_experts fraction of expert params active per token
expert_hidden = 4 * self.config.n_embd // self.config.top_k
expert_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
inactive_per_layer = expert_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
nparams_exclude += inactive_per_layer * self.config.n_layer
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
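With the config defaults above (n_embd=768, num_experts=8, top_k=2), the MoE exclusion a few lines up works out to:

expert_hidden = 4 * 768 // 2                    # = 1536
expert_params_per_layer = 8 * 2 * 768 * 1536    # = 18,874,368 (w_up + w_down across all 8 experts)
inactive_per_layer = 18_874_368 * (8 - 2) // 8  # = 14,155,776 params idle for any given token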
@@ -385,6 +388,31 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizer
def update_moe_balancing(self, coeff=1e-3):
"""Update expert routing bias for load balancing. Call before optimizer.step()."""
for block in self.transformer.h:
block.moe.router.update_expert_bias(coeff)
def get_moe_stats(self):
"""Collect MoE routing statistics for logging. Call BEFORE update_moe_balancing (which resets counters)."""
all_counts = []
all_biases = []
for block in self.transformer.h:
router = block.moe.router
all_counts.append(router.tokens_per_expert_counter)
all_biases.append(router.expert_bias)
counts = torch.stack(all_counts).float() # (n_layer, num_experts)
biases = torch.stack(all_biases).float() # (n_layer, num_experts)
# Load imbalance: coefficient of variation (std/mean) of token counts per layer, averaged over layers
counts_mean = counts.mean(dim=-1).clamp(min=1)
counts_std = counts.std(dim=-1)
load_imbalance = (counts_std / counts_mean).mean().item()
return {
"moe/load_imbalance": load_imbalance,
"moe/expert_bias_std": biases.std().item(),
"moe/expert_bias_max": biases.abs().max().item(),
}
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
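update_expert_bias itself lives in nanochat/moe.py and is not shown in this diff. The counter and bias buffers suggest auxiliary-loss-free balancing in the DeepSeek-V3 style, where overloaded experts get their routing bias nudged down and underloaded ones up. A plausible sketch, an assumption rather than the actual code:

def update_expert_bias(self, coeff=1e-3):
    counts = self.tokens_per_expert_counter.float()
    err = counts.mean() - counts                 # positive for underloaded experts
    self.expert_bias += coeff * torch.sign(err)  # push routing toward the mean load
    self.tokens_per_expert_counter.zero_()       # reset; read get_moe_stats before calling this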

View File

@@ -237,6 +237,10 @@ def disable_fp8(model):
# -----------------------------------------------------------------------------
# Compile the model
# MoE uses torch._grouped_mm with cumulative offsets — dynamo needs this to
# trace through scalar tensor operations that arise from cumsum/histc in routing
torch._dynamo.config.capture_scalar_outputs = True
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
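To unpack the comment above: grouped-matmul routing typically sorts tokens by expert and needs per-expert segment offsets, which are data-dependent scalar values that dynamo can only trace with capture_scalar_outputs enabled. An illustrative sketch, not the actual nanochat routing code:

import torch

def expert_offsets(expert_idx: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Histogram of tokens per expert, then cumulative offsets into the expert-sorted buffer.
    counts = torch.histc(expert_idx.float(), bins=num_experts, min=0, max=num_experts - 1)
    return counts.long().cumsum(0)  # these offsets are the scalar outputs dynamo must capture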
@@ -506,6 +510,8 @@ while True:
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
moe_stats = orig_model.get_moe_stats() if step % 100 == 0 else {}
model.update_moe_balancing()
optimizer.step()
model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
@@ -547,6 +553,7 @@ while True:
"train/mfu": mfu,
"train/epoch": epoch,
}
log_data.update(moe_stats)
wandb_run.log(log_data)
# state update

View File

@@ -305,6 +305,7 @@ for step in range(num_steps):
lrm = get_lr_multiplier(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
model.update_moe_balancing()
optimizer.step()
model.zero_grad(set_to_none=True)
wandb_run.log({

View File

@@ -118,6 +118,8 @@ for name, fallback, source in [
print0(f"Using {name}={arg_val}")
orig_model = model
# MoE uses torch._grouped_mm — dynamo needs this for scalar tensor tracing
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
@@ -430,6 +432,7 @@ while True:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
model.update_moe_balancing()
optimizer.step()
model.zero_grad(set_to_none=True)
synchronize()