Add MOE debug interval and logging for gradient statistics

- Introduced `MOE_DEBUG_INTERVAL` parameter in `runmps.sh` to control debug logging frequency during training.
- Enhanced `base_train.py` to log gradients of routed and shared weights at specified intervals, aiding in monitoring model performance.
- Updated `gpt.py` to adjust router bias calculations, improving load balancing among experts.
- Added unit tests in `test_moe.py` to validate the behavior of the MoE implementation and ensure correctness of gradient calculations.
This commit is contained in:
William Thurston 2025-11-13 16:22:20 -08:00
parent d7c62c2cfe
commit 76227f70d3
4 changed files with 216 additions and 8 deletions

View File

@ -137,6 +137,7 @@ MOE_EXPERT_FFN_MULT=${MOE_EXPERT_FFN_MULT:--1}
MOE_DENSE_LAYERS=${MOE_DENSE_LAYERS:--1}
MOE_GRANULARITY_TARGET=${MOE_GRANULARITY_TARGET:-12}
MOE_ACTIVATION_DEN=${MOE_ACTIVATION_DEN:-32}
MOE_DEBUG_INTERVAL=${MOE_DEBUG_INTERVAL:-0}
EVAL_SEQUENCES=10000
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
@ -276,7 +277,7 @@ python -m scripts.tok_eval
fi
fi
python -m scripts.base_train \
MOE_DEBUG_INTERVAL=$MOE_DEBUG_INTERVAL python -m scripts.base_train \
--depth=$BASE_DEPTH \
--max_seq_len=$SEQ_LEN \
--device_batch_size=$DEVICE_BATCH \

View File

@ -151,10 +151,7 @@ class MoEFeedForward(nn.Module):
self.register_parameter("shared_w1", None)
self.register_parameter("shared_w2", None)
self.register_buffer("router_bias", torch.zeros(self.num_routed_experts, dtype=torch.float32))
self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32))
self.balance_strength = 0.001
self.balance_decay = 0.99
self.bias_clamp = 0.2
self._last_stats = None
self.last_entropy = None
self.last_load = None
@ -203,10 +200,10 @@ class MoEFeedForward(nn.Module):
tokens = max(1, assignments.numel())
load = load / tokens
target = load.new_full((self.num_routed_experts,), 1.0 / self.num_routed_experts)
self.ema_load.mul_(self.balance_decay).add_((1 - self.balance_decay) * load)
balance = target - self.ema_load
self.router_bias.add_(self.balance_strength * balance)
self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp)
balance = target - load
if self.balance_strength != 0:
direction = torch.sign(balance)
self.router_bias.add_(self.balance_strength * direction)
entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
load_cpu = load.detach().cpu()
entropy_cpu = entropy.detach().cpu()

View File

@ -310,6 +310,7 @@ tokens_per_step = total_batch_size
total_tokens_seen = 0
total_sequences_seen = 0
last_val_bpb = None
moe_debug_interval = int(os.environ.get("MOE_DEBUG_INTERVAL", "0") or 0)
def save_base_checkpoint(step_idx):
optimizer_state = [opt.state_dict() for opt in optimizers]
@ -414,6 +415,32 @@ for step in range(num_iterations + 1):
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
if (
moe_debug_interval > 0
and master_process
and step % moe_debug_interval == 0
):
from nanochat.gpt import MoEFeedForward
grad_lines = []
for layer_idx, block in enumerate(orig_model.transformer.h):
mlp = getattr(block, "mlp", None)
if isinstance(mlp, MoEFeedForward):
routed_grad = (
mlp.routed_w2.grad.detach().abs().mean().item()
if mlp.routed_w2.grad is not None else 0.0
)
shared_grad = (
mlp.shared_w2.grad.detach().abs().mean().item()
if mlp.shared_w2 is not None and mlp.shared_w2.grad is not None else 0.0
)
grad_lines.append(
f"layer{layer_idx}: |grad routed_w2|={routed_grad:.3e} |grad shared_w2|={shared_grad:.3e}"
)
if grad_lines:
print0(f"[MOE-GRAD] step {step:05d}")
for line in grad_lines:
print0(f" {line}")
# gradient clipping (TODO possibly experiment with)
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
@ -468,6 +495,31 @@ for step in range(num_iterations + 1):
log_payload.update(orig_model.get_moe_stats())
wandb_run.log(log_payload)
if (
moe_debug_interval > 0
and master_process
and step % moe_debug_interval == 0
):
from nanochat.gpt import MoEFeedForward
debug_lines = []
for layer_idx, block in enumerate(orig_model.transformer.h):
mlp = getattr(block, "mlp", None)
if isinstance(mlp, MoEFeedForward):
routed_norm = mlp.routed_w2.detach().abs().mean().item()
shared_norm = (
mlp.shared_w2.detach().abs().mean().item()
if mlp.shared_w2 is not None else 0.0
)
bias_norm = mlp.router_bias.detach().abs().mean().item()
debug_lines.append(
f"layer{layer_idx}: |routed_w2|={routed_norm:.3e} |shared_w2|={shared_norm:.3e} |router_bias|={bias_norm:.3e}"
)
if debug_lines:
print0(f"[MOE-DEBUG] step {step:05d}")
for line in debug_lines:
print0(f" {line}")
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")

158
tests/test_moe.py Normal file
View File

@ -0,0 +1,158 @@
import torch
import torch.nn.functional as F
from nanochat.gpt import GPTConfig, MoEFeedForward
def _make_moe(num_experts: int, top_k: int, shared: int = 0) -> MoEFeedForward:
    """Build a tiny deterministic MoE layer for testing.

    Router weights are zeroed so routing is driven entirely by
    ``router_bias`` (a strictly decreasing ramp, expert 0 most preferred),
    and every expert matrix is a scaled identity so outputs are easy to
    reason about by hand.
    """
    cfg = GPTConfig(
        sequence_len=8,
        vocab_size=16,
        n_layer=1,
        n_head=1,
        n_kv_head=1,
        n_embd=4,
        moe_num_experts=num_experts,
        moe_num_shared_experts=shared,
        moe_experts_per_token=top_k,
        moe_expert_ffn_mult=1.0,
        dense_layers_before_moe=0,
    )
    layer = MoEFeedForward(cfg)
    layer.eval()
    with torch.no_grad():
        # With zeroed router weights, logits come solely from router_bias.
        layer.router.weight.zero_()
        layer.router_bias.copy_(torch.linspace(float(num_experts), 1.0, steps=num_experts))
        identity = torch.eye(layer.model_dim)
        for e in range(num_experts):
            layer.routed_w1[e].copy_(identity)
            layer.routed_w2[e].copy_(identity * (e + 1))
        for e in range(shared):
            layer.shared_w1[e].copy_(identity)
            layer.shared_w2[e].copy_(identity * (e + 1))
    return layer
def _reference_forward(moe: MoEFeedForward, x: torch.Tensor) -> torch.Tensor:
    """Dense per-token reference forward pass for an MoE layer.

    Computes routed expert outputs with explicit Python loops — top-k
    softmax scores renormalized to sum to 1 per token, squared-ReLU
    activation — then adds the mean of the shared-expert outputs when
    shared experts are present.
    """
    batch, seq, dim = x.shape
    tokens = x.view(batch * seq, dim)
    logits = moe.router(tokens) + moe.router_bias
    probs = torch.softmax(logits, dim=-1)
    scores, indices = torch.topk(probs, moe.top_k, dim=-1)
    # Renormalize so each token's selected mixture weights sum to 1.
    scores = scores / torch.clamp(scores.sum(dim=-1, keepdim=True), min=1e-9)
    out = torch.zeros_like(tokens)
    for t in range(tokens.size(0)):
        for k in range(moe.top_k):
            e = indices[t, k].item()
            w = scores[t, k]
            h = torch.matmul(moe.routed_w1[e], tokens[t])
            h = F.relu(h).square()  # squared-ReLU activation
            out[t] = out[t] + w * torch.matmul(moe.routed_w2[e], h)
    if moe.num_shared_experts > 0:
        acc = torch.zeros_like(tokens)
        for e in range(moe.num_shared_experts):
            h = torch.matmul(moe.shared_w1[e], tokens.t()).t()
            h = F.relu(h).square()
            acc = acc + torch.matmul(moe.shared_w2[e], h.t()).t()
        # Shared experts contribute their average.
        out = out + acc / moe.num_shared_experts
    return out.view(batch, seq, dim)
def test_moe_topk_changes_output():
    """Routing through all experts vs. a single expert must differ.

    Both layers share identical router and expert weights, so any output
    difference can only come from the number of experts mixed per token.
    """
    torch.manual_seed(0)
    all_experts = _make_moe(num_experts=4, top_k=4)
    one_expert = _make_moe(num_experts=4, top_k=1)
    with torch.no_grad():
        one_expert.router.weight.copy_(all_experts.router.weight)
        one_expert.router_bias.copy_(all_experts.router_bias)
        one_expert.routed_w1.copy_(all_experts.routed_w1)
        one_expert.routed_w2.copy_(all_experts.routed_w2)
    x = torch.ones(2, 3, all_experts.model_dim)
    out_all = all_experts(x)
    out_one = one_expert(x)
    assert not torch.allclose(out_all, out_one)
def test_router_bias_pushes_toward_uniform_load():
    """Overloading expert 0 must lower its bias and raise every other bias."""
    layer = _make_moe(num_experts=4, top_k=1)
    layer.train()
    # All four tokens assigned to expert 0 — maximally imbalanced load.
    chosen = torch.tensor([0, 0, 0, 0], dtype=torch.long)
    uniform = torch.full((chosen.shape[0], layer.num_routed_experts), 1.0 / layer.num_routed_experts)
    bias_before = layer.router_bias.clone()
    layer._balance_router(chosen, uniform)
    bias_after = layer.router_bias
    assert bias_after[0].item() < bias_before[0].item()
    assert torch.all(bias_after[1:] > bias_before[1:])
def test_shared_expert_matches_reference():
    """Module forward with shared experts must equal the loop reference."""
    torch.manual_seed(0)
    layer = _make_moe(num_experts=2, top_k=1, shared=2)
    inputs = torch.randn(2, 3, layer.model_dim)
    actual = layer(inputs)
    expected = _reference_forward(layer, inputs)
    assert torch.allclose(actual, expected, atol=1e-6, rtol=1e-5)
class LoopMoE(MoEFeedForward):
    """MoEFeedForward whose dispatch is an explicit per-token Python loop.

    Serves as a slow-but-obvious reference implementation to check the
    batched dispatch against, forward and backward.
    """

    def _dispatch_batched(self, flat_assignments, x_flat, topk_scores):
        n_tok, dim = x_flat.shape
        out = torch.zeros(n_tok, dim, device=x_flat.device, dtype=x_flat.dtype)
        chosen = flat_assignments.view(n_tok, self.top_k)
        for t in range(n_tok):
            for k in range(self.top_k):
                e = chosen[t, k].item()
                w = topk_scores[t, k]
                h = torch.matmul(self.routed_w1[e], x_flat[t])
                h = F.relu(h).square()  # squared-ReLU activation
                out[t] = out[t] + w * torch.matmul(self.routed_w2[e], h)
        return out
def test_moe_gradients_match_loop_reference():
    """Gradients of the batched dispatch must match the per-token loop.

    Builds two MoE layers with identical weights — one using the fast
    batched dispatch, one using LoopMoE's explicit loops — runs the same
    input through both, and checks that outputs, parameter gradients, and
    input gradients all agree.
    """
    torch.manual_seed(0)
    config_moe = GPTConfig(
        sequence_len=8,
        vocab_size=16,
        n_layer=1,
        n_head=1,
        n_kv_head=1,
        n_embd=4,
        moe_num_experts=4,
        moe_num_shared_experts=1,
        moe_experts_per_token=2,
        moe_expert_ffn_mult=1.5,
        dense_layers_before_moe=0,
    )
    moe_fast = MoEFeedForward(config_moe)
    moe_loop = LoopMoE(config_moe)
    moe_loop.load_state_dict(moe_fast.state_dict())
    x_fast = torch.randn(2, 3, config_moe.n_embd, requires_grad=True)
    x_loop = x_fast.detach().clone().requires_grad_(True)
    target = torch.randn_like(x_fast)
    out_fast = moe_fast(x_fast)
    out_loop = moe_loop(x_loop)
    assert torch.allclose(out_fast, out_loop, atol=1e-6, rtol=1e-5)
    loss_fast = (out_fast * target).sum()
    loss_loop = (out_loop * target).sum()
    moe_fast.zero_grad(set_to_none=True)
    moe_loop.zero_grad(set_to_none=True)
    x_fast.grad = None
    x_loop.grad = None
    # Each loss has its own independent graph, so retain_graph is not needed
    # (the previous retain_graph=True kept the fast graph alive pointlessly).
    loss_fast.backward()
    loss_loop.backward()
    for (name_fast, p_fast), (name_loop, p_loop) in zip(moe_fast.named_parameters(), moe_loop.named_parameters()):
        # Guard against silent zip truncation or parameter-order mismatch.
        assert name_fast == name_loop, f"parameter order mismatch: {name_fast} vs {name_loop}"
        assert torch.allclose(p_fast.grad, p_loop.grad, atol=1e-6, rtol=1e-5), f"gradient mismatch for {name_fast}"
    assert torch.allclose(x_fast.grad, x_loop.grad, atol=1e-6, rtol=1e-5)