optimizations

This commit is contained in:
Andrej Karpathy 2026-02-19 01:48:02 +00:00
parent 9e854ab78b
commit a5e51a93ae
3 changed files with 89 additions and 51 deletions

View File

@ -34,8 +34,9 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
num_experts: int = 8 # MoE: number of expert MLPs
top_k: int = 2 # MoE: number of active experts per token
num_experts: int = 8 # MoE: number of routed expert MLPs
top_k: int = 2 # MoE: number of active routed experts per token
num_shared_experts: int = 1 # MoE: number of shared (always-active) experts
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
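As a concrete illustration of how the pattern string tiles across layers, here is a small sketch; the helper itself is an illustrative assumption, not code from this diff.

# Hypothetical helper: tile a window pattern string across n_layer layers.
# "L" = full context, "S" = half context; the final layer is forced to "L".
def window_sizes(pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    chars = [pattern[i % len(pattern)] for i in range(n_layer)]
    chars[-1] = "L"  # final layer always gets full context
    return [sequence_len if c == "L" else sequence_len // 2 for c in chars]

# window_sizes("SSL", 6, 2048) -> [1024, 1024, 2048, 1024, 1024, 2048]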
@ -188,8 +189,10 @@ class GPT(nn.Module):
attn.c_v: uniform, std=1/sqrt(n_embd)
attn.c_proj: zeros
moe.router.gate: uniform, std=1/sqrt(n_embd)
moe.experts.w_ups: uniform, std=1/sqrt(n_embd)
moe.experts.w_downs: zeros
moe.experts.w_up: uniform, std=1/sqrt(n_embd)
moe.experts.w_down: zeros
moe.shared_expert.w_up: uniform, std=1/sqrt(n_embd)
moe.shared_expert.w_down: zeros
"""
# Embedding and unembedding
@ -206,10 +209,11 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
# MoE: router gate and expert up-projections get uniform, down-projections get zero
torch.nn.init.uniform_(block.moe.router.gate.weight, -s, s)
for w_up in block.moe.experts.w_ups:
torch.nn.init.uniform_(w_up, -s, s)
for w_down in block.moe.experts.w_downs:
torch.nn.init.zeros_(w_down)
torch.nn.init.uniform_(block.moe.experts.w_up, -s, s)
torch.nn.init.zeros_(block.moe.experts.w_down)
if block.moe.shared_expert is not None:
torch.nn.init.uniform_(block.moe.shared_expert.w_up.weight, -s, s)
torch.nn.init.zeros_(block.moe.shared_expert.w_down.weight)
# MoE load balancing buffers (zero after to_empty from meta device)
block.moe.router.expert_bias.zero_()
block.moe.router.tokens_per_expert_counter.zero_()
@ -304,10 +308,11 @@ class GPT(nn.Module):
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
# MoE: only top_k/num_experts fraction of expert params active per token
expert_hidden = 4 * self.config.n_embd // self.config.top_k
expert_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
inactive_per_layer = expert_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
# MoE: only top_k/num_experts fraction of routed expert params active per token
# Shared expert is always active so its params stay in the active count
expert_hidden = self.transformer.h[0].moe.expert_hidden_dim
routed_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
inactive_per_layer = routed_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
nparams_exclude += inactive_per_layer * self.config.n_layer
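To make this accounting concrete, here is the arithmetic with the config defaults shown above (n_embd=768, num_experts=8, top_k=2, num_shared_experts=1); the sketch only restates the formulas in this hunk.

# Worked example using the default config values.
n_embd, num_experts, top_k, num_shared = 768, 8, 2, 1
expert_hidden = round(4 * n_embd / (top_k + num_shared) / 128) * 128                  # 1024
routed_params_per_layer = num_experts * 2 * n_embd * expert_hidden                    # 12,582,912
inactive_per_layer = routed_params_per_layer * (num_experts - top_k) // num_experts   # 9,437,184 (6/8 inactive)
# The shared expert (2 * n_embd * expert_hidden = 1,572,864 params) is always active,
# so it is never added to nparams_exclude.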
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window

View File

@ -5,14 +5,17 @@ Replaces the standard dense MLP in each transformer block. Each token picks its
top-K experts via a learned sigmoid router, so total parameters scale with
num_experts but per-token FLOPs remain constant (iso-FLOP with the dense MLP).
Expert hidden dim = 4 * dim / top_k ensures FLOPs match:
Dense: 4 * dim * (4*dim) = 16*dim²
MoE: top_k * 4 * dim * (4*dim/top_k) = 16*dim²
Expert hidden dim = 4 * dim / (top_k + num_shared), rounded to the nearest multiple of 128,
ensures approximately iso-FLOP with the dense MLP:
Dense: 2 * dim * (4*dim) = 8*dim²
MoE per token: (top_k + num_shared) * 2 * dim * H ≈ 8*dim²
Expert weights are stored as separate 2D parameters (not stacked 3D) so they
integrate natively with the Muon optimizer, which expects 2D matrices. At forward
time, params are stacked into 3D and fed to torch._grouped_mm for efficient
batched computation (single kernel per projection instead of a Python for-loop).
Expert weights are 3D tensors of shape (num_experts, hidden, dim). Muon's Polar
Express orthogonalization operates on the last two dims, so the expert dimension
acts as a batch dim and each expert is independently orthogonalized.
At forward time, torch._grouped_mm dispatches tokens to experts via cumulative
offsets: a single kernel per projection instead of a Python for-loop.
"""
import torch
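As a self-contained sketch of the cumulative-offset dispatch that _grouped_mm replaces (it mirrors the for-loop fallback below; the names and the assumption that tokens arrive pre-sorted by expert are illustrative, not this file's exact code):

import torch

def run_experts_reference(w_up, w_down, x, num_tokens_per_expert):
    # w_up: (E, H, D), w_down: (E, D, H), x: (T*K, D) with tokens grouped by expert
    offsets = torch.cumsum(num_tokens_per_expert, dim=0)
    outputs, start = [], 0
    for e, end in enumerate(offsets.tolist()):
        xe = x[start:end]                            # tokens routed to expert e
        he = torch.relu(xe @ w_up[e].T).square()     # up-projection, then ReLU²
        outputs.append(he @ w_down[e].T)             # down-projection
        start = end
    return torch.cat(outputs, dim=0)                 # grouped_mm fuses each projection into one kernel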
@ -46,7 +49,7 @@ class TopKRouter(nn.Module):
scores = torch.sigmoid(scores.float()) # values in (0, 1)
# Bias affects expert SELECTION but not gating weights (DeepSeekV3)
biased_scores = scores + self.expert_bias
_, selected_experts = torch.topk(biased_scores, k=self.top_k, dim=-1)
_, selected_experts = torch.topk(biased_scores, k=self.top_k, dim=-1, sorted=False)
top_scores = scores.gather(dim=-1, index=selected_experts)
num_tokens_per_expert = torch.histc(
selected_experts.float().view(-1),
@ -120,24 +123,36 @@ def _run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert):
return torch.cat(outputs, dim=0)
class SharedExpert(nn.Module):
"""Dense MLP shared expert — processes ALL tokens (no routing).
Same architecture as each routed expert (up → ReLU² → down) but uses
standard nn.Linear layers (2D weights, regular matmul) since there's
no need for the grouped_mm dispatch machinery.
"""
def __init__(self, dim, expert_hidden_dim):
super().__init__()
self.w_up = nn.Linear(dim, expert_hidden_dim, bias=False)
self.w_down = nn.Linear(expert_hidden_dim, dim, bias=False)
def forward(self, x):
h = F.relu(self.w_up(x)).square()
return self.w_down(h)
class ExpertGroup(nn.Module):
"""
N independent expert MLPs, each with separate 2D weight matrices.
Separate 2D params (not stacked 3D) for native Muon optimizer compatibility.
At forward time, params are stacked and dispatched via torch._grouped_mm.
N independent expert MLPs stored as 3D weight tensors.
Shape (num_experts, hidden, dim): Muon's Polar Express operates on the
last two dims, so each expert matrix is independently orthogonalized.
"""
def __init__(self, dim, expert_hidden_dim, num_experts):
super().__init__()
self.num_experts = num_experts
self.w_ups = nn.ParameterList([
nn.Parameter(torch.empty(expert_hidden_dim, dim))
for _ in range(num_experts)
])
self.w_downs = nn.ParameterList([
nn.Parameter(torch.empty(dim, expert_hidden_dim))
for _ in range(num_experts)
])
self.w_up = nn.Parameter(torch.empty(num_experts, expert_hidden_dim, dim))
self.w_down = nn.Parameter(torch.empty(num_experts, dim, expert_hidden_dim))
def forward(self, x, num_tokens_per_expert):
"""
@ -147,27 +162,24 @@ class ExpertGroup(nn.Module):
Returns:
output: (T*K, dim)
"""
# Stack separate 2D params into 3D for grouped_mm
# (autograd handles gradient propagation back to individual params)
w_up = torch.stack(list(self.w_ups)) # (num_experts, expert_hidden_dim, dim)
w_down = torch.stack(list(self.w_downs)) # (num_experts, dim, expert_hidden_dim)
if x.is_cuda:
return _run_experts_grouped_mm(w_up, w_down, x, num_tokens_per_expert)
return _run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert)
return _run_experts_grouped_mm(self.w_up, self.w_down, x, num_tokens_per_expert)
return _run_experts_for_loop(self.w_up, self.w_down, x, num_tokens_per_expert)
class MoE(nn.Module):
"""
Mixture of Experts layer: iso-FLOP replacement for the dense MLP.
Mixture of Experts layer: approximately iso-FLOP replacement for the dense MLP.
For each token:
1. Router scores all experts via sigmoid(gate(x))
2. Top-K experts are selected
3. Token is dispatched to those experts (weighted by routing score)
4. Expert outputs are summed back together
1. Shared expert processes all tokens via standard dense matmul
2. Router scores all routed experts via sigmoid(gate(x))
3. Top-K routed experts are selected
4. Token is dispatched to those experts (weighted by routing score)
5. Routed + shared expert outputs are summed together
Total params are ~num_experts/top_k times larger than dense MLP,
but per-token FLOPs are identical.
Total active experts per token = top_k + num_shared_experts.
Expert hidden dim is sized so total active FLOPs ≈ dense MLP FLOPs.
"""
def __init__(self, config):
@ -175,11 +187,16 @@ class MoE(nn.Module):
dim = config.n_embd
num_experts = config.num_experts
top_k = config.top_k
num_shared = config.num_shared_experts
self.top_k = top_k
# Iso-FLOP sizing: expert_hidden = 4*dim/top_k so MoE FLOPs = dense MLP FLOPs
expert_hidden_dim = 4 * dim // top_k
# Iso-FLOP sizing: total active experts per token = top_k + num_shared
# Round to the nearest multiple of 128 for tensor core alignment
active_experts = top_k + num_shared
expert_hidden_dim = round(4 * dim / active_experts / 128) * 128
self.expert_hidden_dim = expert_hidden_dim
self.router = TopKRouter(dim, num_experts, top_k)
self.experts = ExpertGroup(dim, expert_hidden_dim, num_experts)
self.shared_expert = SharedExpert(dim, expert_hidden_dim) if num_shared > 0 else None
def forward(self, x):
"""
@ -202,10 +219,14 @@ class MoE(nn.Module):
# Step 3: Pre-multiply by routing scores (score_before_experts strategy)
routed_input = (routed_input.float() * scores_sorted.unsqueeze(-1)).to(x.dtype)
# Step 4: Run experts on their assigned token blocks
# Step 4: Shared expert — runs on ALL tokens via standard dense matmul
# Launched before routed experts so compute can overlap (no data dependency)
shared_output = self.shared_expert(x_flat) if self.shared_expert is not None else None
# Step 5: Run routed experts on their assigned token blocks
routed_output = self.experts(routed_input, num_tokens_per_expert)
# Step 5: Scatter outputs back to original positions and sum over top-K
# Step 6: Scatter outputs back to original positions and sum over top-K
combined = torch.zeros(
bs * slen * self.top_k, dim,
dtype=routed_output.dtype, device=routed_output.device,
@ -213,4 +234,8 @@ class MoE(nn.Module):
combined[token_indices_sorted] = routed_output
output = combined.view(bs * slen, self.top_k, dim).sum(dim=1) # (T, dim)
# Step 7: Add shared expert output
if shared_output is not None:
output = output + shared_output
return output.view(bs, slen, dim)

View File

@ -246,9 +246,13 @@ class MuonAdamW(torch.optim.Optimizer):
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
# Second momentum buffer is factored, either per-row or per-column.
# Uses *shape[:-1] / *shape[:-2] to preserve leading dims (e.g. expert dim for 3D MoE params).
if "second_momentum_buffer" not in state:
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
if shape[-2] >= shape[-1]:
state_shape = (num_params, *shape[:-1], 1)
else:
state_shape = (num_params, *shape[:-2], 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
@ -463,8 +467,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
# Second momentum buffer: preserve leading dims for 3D MoE params
if "second_momentum_buffer" not in state:
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
if shape[-2] >= shape[-1]:
state_shape = (chunk_size, *shape[:-1], 1)
else:
state_shape = (chunk_size, *shape[:-2], 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
red_dim = -1 if shape[-2] >= shape[-1] else -2
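To make the shape logic concrete, a small sketch of the factored second-momentum shapes for 2D versus 3D parameters (the example parameter shapes are assumed for illustration):

def second_momentum_shape(chunk_size, shape):
    # Factor along the smaller trailing dim; leading dims (e.g. the expert dim) are preserved.
    if shape[-2] >= shape[-1]:
        return (chunk_size, *shape[:-1], 1)          # per-row factor
    return (chunk_size, *shape[:-2], 1, shape[-1])   # per-column factor

# second_momentum_shape(1, (768, 3072))    -> (1, 1, 3072)
# second_momentum_shape(1, (3072, 768))    -> (1, 3072, 1)
# second_momentum_shape(1, (8, 1024, 768)) -> (1, 8, 1024, 1)   # expert dim kept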