From 2f8734ee0f59993a6f836d0879720c4d745649bd Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 19 Feb 2026 01:15:47 +0000 Subject: [PATCH] working, 35 mfu, have to optimize --- nanochat/gpt.py | 66 ++++++++++++++++++++++++++++++------------- scripts/base_train.py | 7 +++++ scripts/chat_rl.py | 1 + scripts/chat_sft.py | 3 ++ 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..52c6751 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn +from nanochat.moe import MoE @dataclass class GPTConfig: @@ -33,6 +34,8 @@ class GPTConfig: n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 + num_experts: int = 8 # MoE: number of expert MLPs + top_k: int = 2 # MoE: number of active experts per token # Sliding window attention pattern string, tiled across layers. Final layer always L. 
class Block(nn.Module):
    """One transformer layer: self-attention followed by a mixture-of-experts layer.

    Both sub-layers are applied as residual branches whose input is `norm(x)`.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Registration order (attn, then moe) is kept so parameter ordering is unchanged.
        self.attn = CausalSelfAttention(config, layer_idx)
        self.moe = MoE(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        # Attention branch: the extra arguments are forwarded untouched.
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        # MoE branch operates on the re-normalized, attention-updated stream.
        moe_out = self.moe(norm(x))
        return x + moe_out
def update_moe_balancing(self, coeff=1e-3):
    """Update expert routing bias for load balancing. Call before optimizer.step().

    Forwards `coeff` unchanged to `update_expert_bias` on the router of every
    block in `self.transformer.h`.
    """
    for block in self.transformer.h:
        block.moe.router.update_expert_bias(coeff)


def get_moe_stats(self):
    """Collect MoE routing statistics for logging.

    Call BEFORE update_moe_balancing (which resets counters).

    Reads `tokens_per_expert_counter` and `expert_bias` from the router of
    every block in `self.transformer.h` and returns a dict of Python floats:

    - "moe/load_imbalance": coefficient of variation (std / mean) of the
      per-expert token counts, computed per layer and averaged over layers.
    - "moe/expert_bias_std": std of all expert biases across all layers.
    - "moe/expert_bias_max": largest absolute expert bias across all layers.

    Degenerate sizes (a single expert, or a single bias value overall) report a
    spread of 0.0 instead of the NaN the sample std would otherwise produce.
    """
    routers = [block.moe.router for block in self.transformer.h]
    counts = torch.stack([r.tokens_per_expert_counter for r in routers]).float()  # (n_layer, num_experts)
    biases = torch.stack([r.expert_bias for r in routers]).float()  # (n_layer, num_experts)

    # Load imbalance: coefficient of variation (std/mean) per layer, averaged.
    # The clamp keeps the ratio finite when no tokens have been counted yet.
    counts_mean = counts.mean(dim=-1).clamp(min=1)
    if counts.size(-1) > 1:
        counts_std = counts.std(dim=-1)
    else:
        # Sample std of one value is NaN; one expert is perfectly balanced.
        counts_std = torch.zeros_like(counts_mean)
    load_imbalance = (counts_std / counts_mean).mean()

    bias_std = biases.std() if biases.numel() > 1 else biases.new_zeros(())
    bias_max = biases.abs().max()

    # Single device->host transfer instead of one sync per statistic.
    load_imbalance_f, bias_std_f, bias_max_f = torch.stack(
        [load_imbalance, bias_std, bias_max]
    ).tolist()
    return {
        "moe/load_imbalance": load_imbalance_f,
        "moe/expert_bias_std": bias_std_f,
        "moe/expert_bias_max": bias_max_f,
    }
while True: "train/mfu": mfu, "train/epoch": epoch, } + log_data.update(moe_stats) wandb_run.log(log_data) # state update diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 20a1a0a..2295b31 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -305,6 +305,7 @@ for step in range(num_steps): lrm = get_lr_multiplier(step) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm + model.update_moe_balancing() optimizer.step() model.zero_grad(set_to_none=True) wandb_run.log({ diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed2..00c0b11 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -118,6 +118,8 @@ for name, fallback, source in [ print0(f"Using {name}={arg_val}") orig_model = model +# MoE uses torch._grouped_mm — dynamo needs this for scalar tensor tracing +torch._dynamo.config.capture_scalar_outputs = True model = torch.compile(model, dynamic=False) depth = model.config.n_layer num_flops_per_token = model.estimate_flops() @@ -430,6 +432,7 @@ while True: group["lr"] = group["initial_lr"] * lrm if group['kind'] == 'muon': group["momentum"] = muon_momentum + model.update_moe_balancing() optimizer.step() model.zero_grad(set_to_none=True) synchronize()