mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-18 10:53:13 +00:00
working, 35 mfu, have to optimize
This commit is contained in:
parent
48804bff3a
commit
2f8734ee0f
|
|
@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
|
|||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
from nanochat.moe import MoE
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -33,6 +34,8 @@ class GPTConfig:
|
|||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
num_experts: int = 8 # MoE: number of expert MLPs
|
||||
top_k: int = 2 # MoE: number of active experts per token
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
|
|
@ -118,28 +121,15 @@ class CausalSelfAttention(nn.Module):
|
|||
return y
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
x = F.relu(x).square()
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
self.moe = MoE(config)
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
x = x + self.moe(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -197,8 +187,9 @@ class GPT(nn.Module):
|
|||
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_proj: zeros
|
||||
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_proj: zeros
|
||||
moe.router.gate: uniform, std=1/sqrt(n_embd)
|
||||
moe.experts.w_ups: uniform, std=1/sqrt(n_embd)
|
||||
moe.experts.w_downs: zeros
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
|
|
@ -213,8 +204,15 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
# MoE: router gate and expert up-projections get uniform, down-projections get zero
|
||||
torch.nn.init.uniform_(block.moe.router.gate.weight, -s, s)
|
||||
for w_up in block.moe.experts.w_ups:
|
||||
torch.nn.init.uniform_(w_up, -s, s)
|
||||
for w_down in block.moe.experts.w_downs:
|
||||
torch.nn.init.zeros_(w_down)
|
||||
# MoE load balancing buffers: zero them explicitly, because to_empty() from the meta device leaves them uninitialized
|
||||
block.moe.router.expert_bias.zero_()
|
||||
block.moe.router.tokens_per_expert_counter.zero_()
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
|
|
@ -306,6 +304,11 @@ class GPT(nn.Module):
|
|||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
# MoE: only top_k/num_experts fraction of expert params active per token
|
||||
expert_hidden = 4 * self.config.n_embd // self.config.top_k
|
||||
expert_params_per_layer = self.config.num_experts * 2 * self.config.n_embd * expert_hidden
|
||||
inactive_per_layer = expert_params_per_layer * (self.config.num_experts - self.config.top_k) // self.config.num_experts
|
||||
nparams_exclude += inactive_per_layer * self.config.n_layer
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -385,6 +388,31 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def update_moe_balancing(self, coeff=1e-3):
    """Run the load-balancing bias update on every block's MoE router.

    Invokes ``router.update_expert_bias(coeff)`` once per block in
    ``self.transformer.h``, in layer order. Call before optimizer.step().

    Args:
        coeff: value forwarded unchanged to each router's update_expert_bias.
    """
    routers = (layer.moe.router for layer in self.transformer.h)
    for router in routers:
        router.update_expert_bias(coeff)
|
||||
|
||||
def get_moe_stats(self):
    """Collect MoE routing statistics for logging.

    Call BEFORE update_moe_balancing (which resets counters).

    Reads ``tokens_per_expert_counter`` and ``expert_bias`` from the router of
    every block in ``self.transformer.h`` and returns a dict of Python floats:

    - ``moe/load_imbalance``: coefficient of variation (std/mean) of the
      per-expert token counts, computed per layer and averaged over layers.
    - ``moe/expert_bias_std``: std over all router biases of all layers.
    - ``moe/expert_bias_max``: largest absolute router bias.

    Returns an empty dict when the model has no blocks.
    """
    routers = [block.moe.router for block in self.transformer.h]
    if not routers:
        # torch.stack raises on an empty list; there is nothing to report.
        return {}
    counts = torch.stack([r.tokens_per_expert_counter for r in routers]).float()  # (n_layer, num_experts)
    biases = torch.stack([r.expert_bias for r in routers]).float()  # (n_layer, num_experts)
    # Load imbalance: coefficient of variation (std/mean) per layer, averaged.
    # The mean is clamped so layers that saw no tokens do not divide by zero.
    counts_mean = counts.mean(dim=-1).clamp(min=1)
    if counts.size(-1) > 1:
        counts_std = counts.std(dim=-1)
    else:
        # Sample std of a single value is NaN; one expert is perfectly balanced.
        counts_std = torch.zeros_like(counts_mean)
    load_imbalance = (counts_std / counts_mean).mean()
    # Same NaN guard for the degenerate single-bias case.
    bias_std = biases.std() if biases.numel() > 1 else biases.new_zeros(())
    bias_max = biases.abs().max()
    # Pull all three scalars to the host in one transfer instead of three
    # separate .item() calls (each .item() is a CPU-GPU sync point).
    load_imbalance, bias_std, bias_max = torch.stack([load_imbalance, bias_std, bias_max]).tolist()
    return {
        "moe/load_imbalance": load_imbalance,
        "moe/expert_bias_std": bias_std,
        "moe/expert_bias_max": bias_max,
    }
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
|
|
|
|||
|
|
@ -237,6 +237,10 @@ def disable_fp8(model):
|
|||
# -----------------------------------------------------------------------------
|
||||
# Compile the model
|
||||
|
||||
# MoE uses torch._grouped_mm with cumulative offsets — dynamo needs this to
|
||||
# trace through scalar tensor operations that arise from cumsum/histc in routing
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change there)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
|
|
@ -506,6 +510,8 @@ while True:
|
|||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
moe_stats = orig_model.get_moe_stats() if step % 100 == 0 else {}
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
|
|
@ -547,6 +553,7 @@ while True:
|
|||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
log_data.update(moe_stats)
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
|
|
|
|||
|
|
@ -305,6 +305,7 @@ for step in range(num_steps):
|
|||
lrm = get_lr_multiplier(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
|
|
|
|||
|
|
@ -118,6 +118,8 @@ for name, fallback, source in [
|
|||
print0(f"Using {name}={arg_val}")
|
||||
|
||||
orig_model = model
|
||||
# MoE uses torch._grouped_mm — dynamo needs this for scalar tensor tracing
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@ -430,6 +432,7 @@ while True:
|
|||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
model.update_moe_balancing()
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user