mirror of https://github.com/karpathy/nanochat.git
commit a5e51a93ae (parent 9e854ab78b): optimizations
@@ -34,8 +34,9 @@ class GPTConfig:
     n_head: int = 6 # number of query heads
     n_kv_head: int = 6 # number of key/value heads (GQA)
     n_embd: int = 768
-    num_experts: int = 8 # MoE: number of expert MLPs
-    top_k: int = 2 # MoE: number of active experts per token
+    num_experts: int = 8 # MoE: number of routed expert MLPs
+    top_k: int = 2 # MoE: number of active routed experts per token
+    num_shared_experts: int = 1 # MoE: number of shared (always-active) experts
     # Sliding window attention pattern string, tiled across layers. Final layer always L.
     # Characters: L=long (full context), S=short (half context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
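For readers unfamiliar with the pattern convention above, a minimal sketch of the tiling (an illustration of the comment, not code from this commit):

```python
# Hypothetical helper: tile the pattern string across layers, final layer forced to "L".
def layer_kinds(pattern: str, n_layer: int) -> list[str]:
    kinds = [pattern[i % len(pattern)] for i in range(n_layer)]
    kinds[-1] = "L"  # final layer always gets full context
    return kinds

print(layer_kinds("SSL", 8))  # ['S', 'S', 'L', 'S', 'S', 'L', 'S', 'L']
```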
@@ -188,8 +189,10 @@ class GPT(nn.Module):
         attn.c_v: uniform, std=1/sqrt(n_embd)
         attn.c_proj: zeros
         moe.router.gate: uniform, std=1/sqrt(n_embd)
-        moe.experts.w_ups: uniform, std=1/sqrt(n_embd)
-        moe.experts.w_downs: zeros
+        moe.experts.w_up: uniform, std=1/sqrt(n_embd)
+        moe.experts.w_down: zeros
+        moe.shared_expert.w_up: uniform, std=1/sqrt(n_embd)
+        moe.shared_expert.w_down: zeros
         """

         # Embedding and unembedding
@@ -206,10 +209,11 @@ class GPT(nn.Module):
             torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
             # MoE: router gate and expert up-projections get uniform, down-projections get zero
             torch.nn.init.uniform_(block.moe.router.gate.weight, -s, s)
-            for w_up in block.moe.experts.w_ups:
-                torch.nn.init.uniform_(w_up, -s, s)
-            for w_down in block.moe.experts.w_downs:
-                torch.nn.init.zeros_(w_down)
+            torch.nn.init.uniform_(block.moe.experts.w_up, -s, s)
+            torch.nn.init.zeros_(block.moe.experts.w_down)
+            if block.moe.shared_expert is not None:
+                torch.nn.init.uniform_(block.moe.shared_expert.w_up.weight, -s, s)
+                torch.nn.init.zeros_(block.moe.shared_expert.w_down.weight)
             # MoE load balancing buffers (zero after to_empty from meta device)
             block.moe.router.expert_bias.zero_()
             block.moe.router.tokens_per_expert_counter.zero_()
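The init docstring above specifies "uniform, std=1/sqrt(n_embd)". For U(-s, s) the standard deviation is s/sqrt(3), so the bound must be s = sqrt(3)/sqrt(n_embd); a quick check (the formula for s is inferred from the docstring, not copied from the repo):

```python
import math
import torch

n_embd = 768
s = math.sqrt(3) / math.sqrt(n_embd)  # U(-s, s) then has std = 1/sqrt(n_embd)

w = torch.empty(1024, n_embd)
torch.nn.init.uniform_(w, -s, s)
print(w.std().item())  # ~0.0361, i.e. ~1/sqrt(768)
```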
@@ -304,10 +308,11 @@ class GPT(nn.Module):
         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
         nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                            self.resid_lambdas.numel() + self.x0_lambdas.numel())
-        # MoE: only top_k/num_experts fraction of expert params active per token
-        expert_hidden = 4 * self.config.n_embd // self.config.top_k
-        expert_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
-        inactive_per_layer = expert_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
+        # MoE: only top_k/num_experts fraction of routed expert params active per token
+        # Shared expert is always active so its params stay in the active count
+        expert_hidden = self.transformer.h[0].moe.expert_hidden_dim
+        routed_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
+        inactive_per_layer = routed_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
         nparams_exclude += inactive_per_layer * self.config.n_layer
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
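Worked numbers for the active-parameter accounting above, at hypothetical sizes (n_embd=768, 8 routed experts, top_k=2, one shared expert):

```python
n_embd, num_experts, top_k, num_shared = 768, 8, 2, 1
expert_hidden = round(4 * n_embd / (top_k + num_shared) / 128) * 128  # 1024

routed_params_per_layer = num_experts * 2 * n_embd * expert_hidden  # 12,582,912
inactive_per_layer = routed_params_per_layer * (num_experts - top_k) // num_experts
print(inactive_per_layer)  # 9,437,184: 6/8 of routed params excluded per layer
```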
@@ -5,14 +5,17 @@ Replaces the standard dense MLP in each transformer block. Each token picks its
 top-K experts via a learned sigmoid router, so total parameters scale with
 num_experts but per-token FLOPs remain constant (iso-FLOP with the dense MLP).

-Expert hidden dim = 4 * dim / top_k ensures FLOPs match:
-    Dense: 4 * dim * (4*dim) = 16*dim²
-    MoE: top_k * 4 * dim * (4*dim/top_k) = 16*dim²
+Expert hidden dim = 4 * dim / (top_k + num_shared), rounded to a multiple of 128,
+ensures approximately iso-FLOP with the dense MLP:
+    Dense: 2 * dim * (4*dim) = 8*dim²
+    MoE per token: (top_k + num_shared) * 2 * dim * H ≈ 8*dim²

-Expert weights are stored as separate 2D parameters (not stacked 3D) so they
-integrate natively with the Muon optimizer, which expects 2D matrices. At forward
-time, params are stacked into 3D and fed to torch._grouped_mm for efficient
-batched computation (single kernel per projection instead of a Python for-loop).
+Expert weights are 3D tensors of shape (num_experts, hidden, dim). Muon's Polar
+Express orthogonalization operates on the last two dims, so the expert dimension
+acts as a batch dim and each expert is independently orthogonalized.
+
+At forward time, torch._grouped_mm dispatches tokens to experts via cumulative
+offsets — a single kernel per projection instead of a Python for-loop.
 """

 import torch
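A quick check of the sizing formula in the docstring, at an assumed dim=768 (where the rounding happens to land exactly on the dense budget):

```python
dim, top_k, num_shared = 768, 2, 1
H = round(4 * dim / (top_k + num_shared) / 128) * 128  # 1024

dense_params = 2 * dim * (4 * dim)                  # 8*dim^2 = 4,718,592
active_params = (top_k + num_shared) * 2 * dim * H  # 4,718,592 at this dim
print(dense_params == active_params)  # True here; only approximate in general
```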
@@ -46,7 +49,7 @@ class TopKRouter(nn.Module):
         scores = torch.sigmoid(scores.float()) # values in (0, 1)
         # Bias affects expert SELECTION but not gating weights (DeepSeekV3)
         biased_scores = scores + self.expert_bias
-        _, selected_experts = torch.topk(biased_scores, k=self.top_k, dim=-1)
+        _, selected_experts = torch.topk(biased_scores, k=self.top_k, dim=-1, sorted=False)
         top_scores = scores.gather(dim=-1, index=selected_experts)
         num_tokens_per_expert = torch.histc(
             selected_experts.float().view(-1),
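sorted=False skips ordering the K selections; since the gating weights are gathered per index and nothing downstream depends on the order within the top-K set, the result is unchanged. A toy demonstration:

```python
import torch

scores = torch.tensor([[0.9, 0.1, 0.7, 0.3]])
_, idx_sorted = torch.topk(scores, k=2, dim=-1)                  # [[0, 2]]
_, idx_unsorted = torch.topk(scores, k=2, dim=-1, sorted=False)  # same set, any order

# Gathering by index recovers the same (expert, score) pairs either way:
print(scores.gather(-1, idx_sorted), scores.gather(-1, idx_unsorted))
```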
@@ -120,24 +123,36 @@ def _run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert):
     return torch.cat(outputs, dim=0)


+class SharedExpert(nn.Module):
+    """Dense MLP shared expert — processes ALL tokens (no routing).
+
+    Same architecture as each routed expert (up → ReLU² → down) but uses
+    standard nn.Linear layers (2D weights, regular matmul) since there's
+    no need for the grouped_mm dispatch machinery.
+    """
+
+    def __init__(self, dim, expert_hidden_dim):
+        super().__init__()
+        self.w_up = nn.Linear(dim, expert_hidden_dim, bias=False)
+        self.w_down = nn.Linear(expert_hidden_dim, dim, bias=False)
+
+    def forward(self, x):
+        h = F.relu(self.w_up(x)).square()
+        return self.w_down(h)
+
+
 class ExpertGroup(nn.Module):
     """
-    N independent expert MLPs, each with separate 2D weight matrices.
-    Separate 2D params (not stacked 3D) for native Muon optimizer compatibility.
-    At forward time, params are stacked and dispatched via torch._grouped_mm.
+    N independent expert MLPs stored as 3D weight tensors.
+    Shape (num_experts, hidden, dim) — Muon's Polar Express operates on the
+    last two dims, so each expert matrix is independently orthogonalized.
     """

     def __init__(self, dim, expert_hidden_dim, num_experts):
         super().__init__()
         self.num_experts = num_experts
-        self.w_ups = nn.ParameterList([
-            nn.Parameter(torch.empty(expert_hidden_dim, dim))
-            for _ in range(num_experts)
-        ])
-        self.w_downs = nn.ParameterList([
-            nn.Parameter(torch.empty(dim, expert_hidden_dim))
-            for _ in range(num_experts)
-        ])
+        self.w_up = nn.Parameter(torch.empty(num_experts, expert_hidden_dim, dim))
+        self.w_down = nn.Parameter(torch.empty(num_experts, dim, expert_hidden_dim))

     def forward(self, x, num_tokens_per_expert):
         """
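For reference, a sketch of what the for-loop fallback does with the 3D weights (an assumption matching _run_experts_for_loop's signature and the ReLU² MLP above, not necessarily the repo's exact body). Tokens arrive pre-sorted by expert, so splitting by num_tokens_per_expert yields each expert's contiguous block:

```python
import torch
import torch.nn.functional as F

def run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert):
    # w_up: (E, H, D), w_down: (E, D, H), x: (T*K, D) with rows grouped by expert
    sizes = [int(n) for n in num_tokens_per_expert]
    outputs = []
    for e, chunk in enumerate(torch.split(x, sizes)):
        h = F.relu(chunk @ w_up[e].T).square()  # up-projection + ReLU^2
        outputs.append(h @ w_down[e].T)         # down-projection
    return torch.cat(outputs, dim=0)
```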
@@ -147,27 +162,24 @@ class ExpertGroup(nn.Module):
         Returns:
             output: (T*K, dim)
         """
-        # Stack separate 2D params into 3D for grouped_mm
-        # (autograd handles gradient propagation back to individual params)
-        w_up = torch.stack(list(self.w_ups)) # (num_experts, expert_hidden_dim, dim)
-        w_down = torch.stack(list(self.w_downs)) # (num_experts, dim, expert_hidden_dim)
         if x.is_cuda:
-            return _run_experts_grouped_mm(w_up, w_down, x, num_tokens_per_expert)
-        return _run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert)
+            return _run_experts_grouped_mm(self.w_up, self.w_down, x, num_tokens_per_expert)
+        return _run_experts_for_loop(self.w_up, self.w_down, x, num_tokens_per_expert)


 class MoE(nn.Module):
     """
-    Mixture of Experts layer — iso-FLOP replacement for the dense MLP.
+    Mixture of Experts layer — approximately iso-FLOP replacement for the dense MLP.

     For each token:
-    1. Router scores all experts via sigmoid(gate(x))
-    2. Top-K experts are selected
-    3. Token is dispatched to those experts (weighted by routing score)
-    4. Expert outputs are summed back together
+    1. Shared expert processes all tokens via standard dense matmul
+    2. Router scores all routed experts via sigmoid(gate(x))
+    3. Top-K routed experts are selected
+    4. Token is dispatched to those experts (weighted by routing score)
+    5. Routed + shared expert outputs are summed together

-    Total params are ~num_experts/top_k times larger than dense MLP,
-    but per-token FLOPs are identical.
+    Total active experts per token = top_k + num_shared_experts.
+    Expert hidden dim is sized so total active FLOPs ≈ dense MLP FLOPs.
     """

     def __init__(self, config):
@@ -175,11 +187,16 @@ class MoE(nn.Module):
         dim = config.n_embd
         num_experts = config.num_experts
         top_k = config.top_k
+        num_shared = config.num_shared_experts
         self.top_k = top_k
-        # Iso-FLOP sizing: expert_hidden = 4*dim/top_k so MoE FLOPs = dense MLP FLOPs
-        expert_hidden_dim = 4 * dim // top_k
+        # Iso-FLOP sizing: total active experts per token = top_k + num_shared
+        # Round to the nearest multiple of 128 for tensor core alignment
+        active_experts = top_k + num_shared
+        expert_hidden_dim = round(4 * dim / active_experts / 128) * 128
+        self.expert_hidden_dim = expert_hidden_dim
         self.router = TopKRouter(dim, num_experts, top_k)
         self.experts = ExpertGroup(dim, expert_hidden_dim, num_experts)
+        self.shared_expert = SharedExpert(dim, expert_hidden_dim) if num_shared > 0 else None

     def forward(self, x):
         """
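A minimal usage sketch of the updated layer (assumes the remaining GPTConfig fields have workable defaults; shapes are illustrative, and the weights would still need proper initialization):

```python
import torch

cfg = GPTConfig(n_embd=768, num_experts=8, top_k=2, num_shared_experts=1)
moe = MoE(cfg)                      # expert_hidden_dim = round(4*768/3/128)*128 = 1024
x = torch.randn(2, 16, cfg.n_embd)  # (batch, seq, dim)
y = moe(x)                          # (2, 16, 768): routed top-2 plus shared expert
print(y.shape)
```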
@@ -202,10 +219,14 @@ class MoE(nn.Module):
         # Step 3: Pre-multiply by routing scores (score_before_experts strategy)
         routed_input = (routed_input.float() * scores_sorted.unsqueeze(-1)).to(x.dtype)

-        # Step 4: Run experts on their assigned token blocks
+        # Step 4: Shared expert — runs on ALL tokens via standard dense matmul
+        # Launched before routed experts so compute can overlap (no data dependency)
+        shared_output = self.shared_expert(x_flat) if self.shared_expert is not None else None
+
+        # Step 5: Run routed experts on their assigned token blocks
         routed_output = self.experts(routed_input, num_tokens_per_expert)

-        # Step 5: Scatter outputs back to original positions and sum over top-K
+        # Step 6: Scatter outputs back to original positions and sum over top-K
         combined = torch.zeros(
             bs * slen * self.top_k, dim,
             dtype=routed_output.dtype, device=routed_output.device,
@@ -213,4 +234,8 @@ class MoE(nn.Module):
         combined[token_indices_sorted] = routed_output
         output = combined.view(bs * slen, self.top_k, dim).sum(dim=1) # (T, dim)

+        # Step 7: Add shared expert output
+        if shared_output is not None:
+            output = output + shared_output
+
         return output.view(bs, slen, dim)
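A tiny worked example of the scatter-and-sum combine in Steps 6-7 (toy values; T=2 tokens, K=2):

```python
import torch

T, K, dim = 2, 2, 3
routed_output = torch.arange(T * K * dim, dtype=torch.float32).view(T * K, dim)
token_indices_sorted = torch.tensor([2, 0, 3, 1])  # expert-sorted row -> original slot

combined = torch.zeros(T * K, dim)
combined[token_indices_sorted] = routed_output  # undo the expert sort
output = combined.view(T, K, dim).sum(dim=1)    # sum each token's K expert outputs
print(output.shape)  # torch.Size([2, 3])
```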
@@ -246,9 +246,13 @@ class MuonAdamW(torch.optim.Optimizer):
                 state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
             momentum_buffer = state["momentum_buffer"]

-            # Second momentum buffer is factored, either per-row or per-column
+            # Second momentum buffer is factored, either per-row or per-column.
+            # Uses *shape[:-1] / *shape[:-2] to preserve leading dims (e.g. expert dim for 3D MoE params).
             if "second_momentum_buffer" not in state:
-                state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
+                if shape[-2] >= shape[-1]:
+                    state_shape = (num_params, *shape[:-1], 1)
+                else:
+                    state_shape = (num_params, *shape[:-2], 1, shape[-1])
                 state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
             second_momentum_buffer = state["second_momentum_buffer"]
             red_dim = -1 if shape[-2] >= shape[-1] else -2
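A shape walk-through under assumed sizes: for a 3D MoE parameter of shape (num_experts, hidden, dim) with hidden >= dim, the old one-liner collapsed the buffer down to the last two dims, while the new branches keep one factored scale per expert.

```python
n = 1                   # params in the group
shape = (8, 1024, 768)  # (num_experts, hidden, dim), shape[-2] >= shape[-1]

old_shape = (n, shape[-2], 1)    # (1, 1024, 1): expert dim lost
new_shape = (n, *shape[:-1], 1)  # (1, 8, 1024, 1): per-expert row factors
print(old_shape, new_shape)
```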
@@ -463,8 +467,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
                 state = self.state[p]
                 if "momentum_buffer" not in state:
                     state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
+                # Second momentum buffer: preserve leading dims for 3D MoE params
                 if "second_momentum_buffer" not in state:
-                    state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
+                    if shape[-2] >= shape[-1]:
+                        state_shape = (chunk_size, *shape[:-1], 1)
+                    else:
+                        state_shape = (chunk_size, *shape[:-2], 1, shape[-1])
                     state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
                 red_dim = -1 if shape[-2] >= shape[-1] else -2
