mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
Merge 5422d3a132 into 83dccc20ae
This commit is contained in:
commit
719eaa7f37
|
|
@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
|
|||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
from nanochat.moe import MoE
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -33,6 +34,9 @@ class GPTConfig:
|
|||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
num_experts: int = 8 # MoE: number of routed expert MLPs
|
||||
top_k: int = 2 # MoE: number of active routed experts per token
|
||||
num_shared_experts: int = 1 # MoE: number of shared (always-active) experts
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
|
|
@ -118,28 +122,15 @@ class CausalSelfAttention(nn.Module):
|
|||
return y
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Dense feed-forward block: 4x up-projection, ReLU^2 nonlinearity, down-projection.

    Both linear layers are bias-free, matching the rest of the model.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)

    def forward(self, x):
        # up-project -> ReLU^2 activation -> down-project
        h = F.relu(self.c_fc(x)).square()
        return self.c_proj(h)
|
||||
|
||||
|
||||
class Block(nn.Module):
    """One transformer block: attention, dense MLP, and MoE sublayers.

    Pre-norm residual stream: every sublayer reads norm(x) and its output is
    added back onto x.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
        self.moe = MoE(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        x = x + self.mlp(norm(x))
        x = x + self.moe(norm(x))
        return x
|
||||
|
||||
|
||||
|
|
@ -197,8 +188,11 @@ class GPT(nn.Module):
|
|||
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_proj: zeros
|
||||
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_proj: zeros
|
||||
moe.router.gate: uniform, std=1/sqrt(n_embd)
|
||||
moe.experts.w_up: uniform, std=1/sqrt(n_embd)
|
||||
moe.experts.w_down: zeros
|
||||
moe.shared_expert.w_up: uniform, std=1/sqrt(n_embd)
|
||||
moe.shared_expert.w_down: zeros
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
|
|
@ -213,8 +207,16 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
# MoE: router gate and expert up-projections get uniform, down-projections get zero
|
||||
torch.nn.init.uniform_(block.moe.router.gate.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.moe.experts.w_up, -s, s)
|
||||
torch.nn.init.zeros_(block.moe.experts.w_down)
|
||||
if block.moe.shared_expert is not None:
|
||||
torch.nn.init.uniform_(block.moe.shared_expert.w_up.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.moe.shared_expert.w_down.weight)
|
||||
# MoE load balancing buffers (zero after to_empty from meta device)
|
||||
block.moe.router.expert_bias.zero_()
|
||||
block.moe.router.tokens_per_expert_counter.zero_()
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
|
|
@ -306,6 +308,12 @@ class GPT(nn.Module):
|
|||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
# MoE: only top_k/num_experts fraction of routed expert params active per token
|
||||
# Shared expert is always active so its params stay in the active count
|
||||
expert_hidden = self.transformer.h[0].moe.expert_hidden_dim
|
||||
routed_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
|
||||
inactive_per_layer = routed_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
|
||||
nparams_exclude += inactive_per_layer * self.config.n_layer
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -327,6 +335,10 @@ class GPT(nn.Module):
|
|||
|
||||
Returns a dict with counts for each parameter group, so downstream analysis
|
||||
can experiment with which combination gives the cleanest scaling laws.
|
||||
|
||||
For MoE, 'active_*' fields count only the parameters active per token
|
||||
(top_k out of num_experts routed experts, plus shared experts).
|
||||
Following DeepSeek convention of reporting both total and active params.
|
||||
"""
|
||||
# Count each group separately (mirrors the grouping in setup_optimizers)
|
||||
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
||||
|
|
@ -336,13 +348,24 @@ class GPT(nn.Module):
|
|||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
# MoE: only top_k/num_experts fraction of routed expert params active per token
|
||||
# Shared expert is always active so its params stay in the active count
|
||||
expert_hidden = self.transformer.h[0].moe.expert_hidden_dim
|
||||
routed_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
|
||||
inactive_per_layer = routed_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
|
||||
moe_inactive = inactive_per_layer * self.config.n_layer
|
||||
active_transformer_matrices = transformer_matrices - moe_inactive
|
||||
active_total = total - moe_inactive
|
||||
return {
|
||||
'wte': wte,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'active_transformer_matrices': active_transformer_matrices,
|
||||
'scalars': scalars,
|
||||
'moe_inactive': moe_inactive,
|
||||
'total': total,
|
||||
'active_total': active_total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
|
|
@ -385,6 +408,31 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def update_moe_balancing(self, coeff=1e-3):
|
||||
"""Update expert routing bias for load balancing. Call before optimizer.step()."""
|
||||
for block in self.transformer.h:
|
||||
block.moe.router.update_expert_bias(coeff)
|
||||
|
||||
def get_moe_stats(self):
|
||||
"""Collect MoE routing statistics for logging. Call BEFORE update_moe_balancing (which resets counters)."""
|
||||
all_counts = []
|
||||
all_biases = []
|
||||
for block in self.transformer.h:
|
||||
router = block.moe.router
|
||||
all_counts.append(router.tokens_per_expert_counter)
|
||||
all_biases.append(router.expert_bias)
|
||||
counts = torch.stack(all_counts).float() # (n_layer, num_experts)
|
||||
biases = torch.stack(all_biases).float() # (n_layer, num_experts)
|
||||
# Load imbalance: coefficient of variation (std/mean) per layer, averaged
|
||||
counts_mean = counts.mean(dim=-1).clamp(min=1)
|
||||
counts_std = counts.std(dim=-1)
|
||||
load_imbalance = (counts_std / counts_mean).mean().item()
|
||||
return {
|
||||
"moe/load_imbalance": load_imbalance,
|
||||
"moe/expert_bias_std": biases.std().item(),
|
||||
"moe/expert_bias_max": biases.abs().max().item(),
|
||||
}
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
|
|
|
|||
241
nanochat/moe.py
Normal file
241
nanochat/moe.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
"""
|
||||
Mixture of Experts (MoE) layer for nanochat.
|
||||
|
||||
Replaces the standard dense MLP in each transformer block. Each token picks its
|
||||
top-K experts via a learned sigmoid router, so total parameters scale with
|
||||
num_experts but per-token FLOPs remain constant (iso-FLOP with the dense MLP).
|
||||
|
||||
Expert hidden dim = 4 * dim / (top_k + num_shared), rounded to the nearest multiple of 128, ensures
|
||||
approximately iso-FLOP with the dense MLP:
|
||||
Dense: 2 * dim * (4*dim) = 8*dim²
|
||||
MoE per token: (top_k + num_shared) * 2 * dim * H ≈ 8*dim²
|
||||
|
||||
Expert weights are 3D tensors of shape (num_experts, hidden, dim). Muon's Polar
|
||||
Express orthogonalization operates on the last two dims, so the expert dimension
|
||||
acts as a batch dim and each expert is independently orthogonalized.
|
||||
|
||||
At forward time, torch._grouped_mm dispatches tokens to experts via cumulative
|
||||
offsets — a single kernel per projection instead of a Python for-loop.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TopKRouter(nn.Module):
    """Sigmoid-gated top-K router. Each token independently picks K experts."""

    def __init__(self, dim, num_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)
        # Buffers for auxiliary-loss-free load balancing (DeepSeekV3):
        # expert_bias steers selection; the counter tracks per-expert load.
        self.register_buffer('expert_bias', torch.zeros(num_experts))
        self.register_buffer('tokens_per_expert_counter', torch.zeros(num_experts))

    def forward(self, x):
        """
        Args:
            x: (T, dim) flattened token representations
        Returns:
            top_scores: (T, top_k) routing weights for selected experts
            selected_experts: (T, top_k) which experts each token chose
            num_tokens_per_expert: (num_experts,) how many tokens each expert received
        """
        gate_logits = self.gate(x)                  # (T, num_experts)
        gate_probs = gate_logits.float().sigmoid()  # values in (0, 1)
        # The bias influences which experts get SELECTED, but the gating
        # weights applied downstream exclude it (DeepSeekV3)
        _, selected_experts = torch.topk(gate_probs + self.expert_bias, k=self.top_k, dim=-1, sorted=False)
        top_scores = gate_probs.gather(dim=-1, index=selected_experts)
        num_tokens_per_expert = torch.histc(
            selected_experts.view(-1).float(),
            bins=self.num_experts, min=0, max=self.num_experts,
        )
        # Running load tally, consumed (and reset) by update_expert_bias()
        self.tokens_per_expert_counter += num_tokens_per_expert
        return top_scores, selected_experts, num_tokens_per_expert

    def update_expert_bias(self, coeff=1e-3):
        """Auxiliary-loss-free bias update (DeepSeekV3). Call before optimizer.step()."""
        counts = self.tokens_per_expert_counter
        if dist.is_initialized():
            # Aggregate per-rank counts so every rank applies the identical update
            dist.all_reduce(counts)
        if counts.sum() == 0:
            return  # nothing routed since last update
        # Push underloaded experts up and overloaded experts down by a fixed step
        delta = torch.sign(counts.mean() - counts)
        self.expert_bias += coeff * delta
        self.expert_bias -= self.expert_bias.mean()  # keep zero-mean to prevent drift
        self.tokens_per_expert_counter.zero_()
|
||||
|
||||
|
||||
def _run_experts_grouped_mm(w_up, w_down, x, num_tokens_per_expert):
    """Run all experts via grouped matmul — single kernel per projection.

    torch._grouped_mm handles variable tokens-per-expert internally via
    cumulative offsets, so no Python for-loop or .tolist() device sync needed.
    All tensor shapes are static (the dynamic token distribution is encoded
    in the offsets, not in tensor dimensions).

    NOTE(review): torch._grouped_mm is a private API; the caller (ExpertGroup)
    only takes this path when x.is_cuda, so it is presumably CUDA-only —
    confirm against the installed torch version.

    Args:
        w_up: (num_experts, expert_hidden_dim, dim) - stacked up-projections
        w_down: (num_experts, dim, expert_hidden_dim) - stacked down-projections
        x: (total_tokens, dim) - tokens sorted by expert assignment
        num_tokens_per_expert: (num_experts,) - count per expert
    Returns:
        output: (total_tokens, dim)
    """
    # Exclusive-end group boundaries: offsets[i] is the end row of expert i's slice.
    # grouped_mm requires int32 offsets.
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    # Cast everything to bf16 upfront (weights are fp32 for Muon, need bf16 for grouped_mm)
    # Weights are transposed on the last two dims so each group computes x @ W^T.
    x_bf16 = x.bfloat16()
    w_up_bf16 = w_up.bfloat16().transpose(-2, -1)
    w_down_bf16 = w_down.bfloat16().transpose(-2, -1)
    # Up-project all experts at once: (total_tokens, dim) → (total_tokens, expert_hidden_dim)
    h = torch._grouped_mm(x_bf16, w_up_bf16, offs=offsets)
    h = F.relu(h).square()  # ReLU² activation
    # Down-project all experts at once: (total_tokens, expert_hidden_dim) → (total_tokens, dim)
    # NOTE(review): h should already be bf16 here, so .bfloat16() looks like a
    # no-op safeguard — confirm before removing.
    out = torch._grouped_mm(h.bfloat16(), w_down_bf16, offs=offsets)
    # Cast back to the caller's dtype (routed input may be fp32/bf16 upstream)
    return out.type_as(x)
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def _run_experts_for_loop(w_up, w_down, x, num_tokens_per_expert):
|
||||
"""Fallback for-loop implementation for CPU/MPS where grouped_mm isn't available.
|
||||
|
||||
Decorated with @torch.compiler.disable because .tolist() causes a device-host
|
||||
sync that torch.compile can't handle. Only used on non-CUDA devices.
|
||||
"""
|
||||
token_counts = num_tokens_per_expert.tolist()
|
||||
chunks = torch.split(x, [int(c) for c in token_counts], dim=0)
|
||||
outputs = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
# No empty-chunk skip: matmul with (0, dim) tensors is valid and produces
|
||||
# zero gradients (vs None), which the optimizer needs for stacking.
|
||||
h = chunk @ w_up[i].T
|
||||
h = F.relu(h).square()
|
||||
h = h @ w_down[i].T
|
||||
outputs.append(h)
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
class SharedExpert(nn.Module):
    """Always-active dense expert — sees every token, no routing involved.

    Computes the same function as a routed expert (up → ReLU² → down) but
    with ordinary 2D nn.Linear weights and a regular matmul, since the
    grouped_mm dispatch machinery is unnecessary here.
    """

    def __init__(self, dim, expert_hidden_dim):
        super().__init__()
        self.w_up = nn.Linear(dim, expert_hidden_dim, bias=False)
        self.w_down = nn.Linear(expert_hidden_dim, dim, bias=False)

    def forward(self, x):
        hidden = self.w_up(x)
        return self.w_down(F.relu(hidden).square())
|
||||
|
||||
|
||||
class ExpertGroup(nn.Module):
    """
    N independent expert MLPs held as stacked 3D weight tensors.

    Weights are shaped (num_experts, hidden, dim); Muon's Polar Express
    orthogonalizes over the last two dims, so the expert axis acts as a batch
    dimension and each expert matrix is orthogonalized independently.
    """

    def __init__(self, dim, expert_hidden_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.w_up = nn.Parameter(torch.empty(num_experts, expert_hidden_dim, dim))
        self.w_down = nn.Parameter(torch.empty(num_experts, dim, expert_hidden_dim))

    def forward(self, x, num_tokens_per_expert):
        """
        Args:
            x: (T*K, dim) tokens sorted by expert assignment
            num_tokens_per_expert: (num_experts,) count per expert
        Returns:
            output: (T*K, dim)
        """
        # grouped_mm is only taken on CUDA; every other device uses the loop.
        run = _run_experts_grouped_mm if x.is_cuda else _run_experts_for_loop
        return run(self.w_up, self.w_down, x, num_tokens_per_expert)
|
||||
|
||||
|
||||
class MoE(nn.Module):
    """
    Mixture of Experts layer — approximately iso-FLOP replacement for the dense MLP.

    For each token:
    1. Shared expert processes all tokens via standard dense matmul
    2. Router scores all routed experts via sigmoid(gate(x))
    3. Top-K routed experts are selected
    4. Token is dispatched to those experts (weighted by routing score)
    5. Routed + shared expert outputs are summed together

    Total active experts per token = top_k + num_shared_experts.
    Expert hidden dim is sized so total active FLOPs ≈ dense MLP FLOPs.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.n_embd
        num_experts = config.num_experts
        top_k = config.top_k
        num_shared = config.num_shared_experts
        self.top_k = top_k
        # Iso-FLOP sizing: total active experts per token = top_k + num_shared
        # Round to nearest 128 for tensor core alignment
        # NOTE(review): for very small dim this rounds to 0 (e.g. dim=16,
        # 3 active experts) — assumes dim is large enough in practice; confirm.
        active_experts = top_k + num_shared
        expert_hidden_dim = round(4 * dim / active_experts / 128) * 128
        self.expert_hidden_dim = expert_hidden_dim
        self.router = TopKRouter(dim, num_experts, top_k)
        self.experts = ExpertGroup(dim, expert_hidden_dim, num_experts)
        # Shared expert is optional: num_shared_experts == 0 disables it entirely
        self.shared_expert = SharedExpert(dim, expert_hidden_dim) if num_shared > 0 else None

    def forward(self, x):
        """
        Args: x: (bs, slen, dim)
        Returns: output: (bs, slen, dim) — same shape, drop-in MLP replacement
        """
        bs, slen, dim = x.shape
        x_flat = x.view(-1, dim)  # (T, dim) where T = bs * slen

        # Step 1: Route — each token picks its top-K experts
        top_scores, selected_experts, num_tokens_per_expert = self.router(x_flat)

        # Step 2: Sort tokens by expert assignment for contiguous expert processing
        # argsort groups all assignments to expert 0 first, then expert 1, etc.
        # stable=True keeps same-expert assignments in original token order.
        token_indices_sorted = torch.argsort(selected_experts.view(-1), stable=True)
        scores_sorted = top_scores.view(-1)[token_indices_sorted]  # (T*K,)
        # Row-major flatten of the (T, top_k) assignment grid means flat index
        # i corresponds to token i // top_k — integer division recovers the token.
        token_ids = token_indices_sorted // self.top_k  # map back to original token
        routed_input = x_flat[token_ids]  # (T*K, dim) — tokens duplicated per assignment

        # Step 3: Pre-multiply by routing scores (score_before_experts strategy)
        # Done in fp32 for precision, then cast back to the activation dtype.
        routed_input = (routed_input.float() * scores_sorted.unsqueeze(-1)).to(x.dtype)

        # Step 4: Shared expert — runs on ALL tokens via standard dense matmul
        # Launched before routed experts so compute can overlap (no data dependency)
        shared_output = self.shared_expert(x_flat) if self.shared_expert is not None else None

        # Step 5: Run routed experts on their assigned token blocks
        routed_output = self.experts(routed_input, num_tokens_per_expert)

        # Step 6: Scatter outputs back to original positions and sum over top-K
        # token_indices_sorted is a permutation of [0, T*K), so the scatter
        # fills every row of `combined` exactly once.
        combined = torch.zeros(
            bs * slen * self.top_k, dim,
            dtype=routed_output.dtype, device=routed_output.device,
        )
        combined[token_indices_sorted] = routed_output
        output = combined.view(bs * slen, self.top_k, dim).sum(dim=1)  # (T, dim)

        # Step 7: Add shared expert output
        if shared_output is not None:
            output = output + shared_output

        return output.view(bs, slen, dim)
|
||||
|
|
@ -246,9 +246,13 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
# Second momentum buffer is factored, either per-row or per-column.
|
||||
# Uses *shape[:-1] / *shape[:-2] to preserve leading dims (e.g. expert dim for 3D MoE params).
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||
if shape[-2] >= shape[-1]:
|
||||
state_shape = (num_params, *shape[:-1], 1)
|
||||
else:
|
||||
state_shape = (num_params, *shape[:-2], 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
|
@ -463,8 +467,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
# Second momentum buffer: preserve leading dims for 3D MoE params
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
|
||||
if shape[-2] >= shape[-1]:
|
||||
state_shape = (chunk_size, *shape[:-1], 1)
|
||||
else:
|
||||
state_shape = (chunk_size, *shape[:-2], 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
|
|
|
|||
|
|
@ -237,6 +237,10 @@ def disable_fp8(model):
|
|||
# -----------------------------------------------------------------------------
|
||||
# Compile the model
|
||||
|
||||
# MoE uses torch._grouped_mm with cumulative offsets — dynamo needs this to
|
||||
# trace through scalar tensor operations that arise from cumsum/histc in routing
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
|
|
@ -257,8 +261,9 @@ print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
|||
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
|
||||
def get_scaling_params(m):
|
||||
# As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
|
||||
# For MoE, use active params (only top_k routed experts + shared, not all experts)
|
||||
params_counts = m.num_scaling_params()
|
||||
scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
|
||||
scaling_params = params_counts['active_transformer_matrices'] + params_counts['lm_head']
|
||||
return scaling_params
|
||||
num_scaling_params = get_scaling_params(model)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train
|
||||
|
|
@ -506,6 +511,8 @@ while True:
|
|||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
moe_stats = orig_model.get_moe_stats() if step % 100 == 0 else {}
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
|
|
@ -547,6 +554,7 @@ while True:
|
|||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
log_data.update(moe_stats)
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
|
|
|
|||
|
|
@ -305,6 +305,7 @@ for step in range(num_steps):
|
|||
lrm = get_lr_multiplier(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
|
|
|
|||
|
|
@ -118,6 +118,8 @@ for name, fallback, source in [
|
|||
print0(f"Using {name}={arg_val}")
|
||||
|
||||
orig_model = model
|
||||
# MoE uses torch._grouped_mm — dynamo needs this for scalar tensor tracing
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@ -442,6 +444,7 @@ while True:
|
|||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user