From 76227f70d382397943c3a4373db87d58f8b49967 Mon Sep 17 00:00:00 2001 From: William Thurston Date: Thu, 13 Nov 2025 16:22:20 -0800 Subject: [PATCH] Add MOE debug interval and logging for gradient statistics - Introduced `MOE_DEBUG_INTERVAL` parameter in `runmps.sh` (default 0, which disables the logging) and passed it to `base_train.py` as an environment variable to control debug logging frequency during training. - Enhanced `base_train.py` to log, at the specified interval, the mean absolute gradients of the routed and shared `w2` weights, plus the mean absolute values of those weights and of the router bias for each MoE layer, aiding in monitoring model performance. - Updated `gpt.py` to replace the EMA-smoothed, clamped router bias update with a fixed-size step of `balance_strength` in the direction `sign(target - load)`, improving load balancing among experts. - Added unit tests in `test_moe.py` to validate the behavior of the MoE implementation (top-k routing, router bias balancing, shared experts) and ensure correctness of gradient calculations. --- dev/runmps.sh | 3 +- nanochat/gpt.py | 11 ++- scripts/base_train.py | 52 ++++++++++++++ tests/test_moe.py | 158 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 8 deletions(-) create mode 100644 tests/test_moe.py diff --git a/dev/runmps.sh b/dev/runmps.sh index 4027949..4831c71 100755 --- a/dev/runmps.sh +++ b/dev/runmps.sh @@ -137,6 +137,7 @@ MOE_EXPERT_FFN_MULT=${MOE_EXPERT_FFN_MULT:--1} MOE_DENSE_LAYERS=${MOE_DENSE_LAYERS:--1} MOE_GRANULARITY_TARGET=${MOE_GRANULARITY_TARGET:-12} MOE_ACTIVATION_DEN=${MOE_ACTIVATION_DEN:-32} +MOE_DEBUG_INTERVAL=${MOE_DEBUG_INTERVAL:-0} EVAL_SEQUENCES=10000 EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH)) EVAL_BATCH_MULT=4 # evaluate on 4 full batches @@ -276,7 +277,7 @@ python -m scripts.tok_eval fi fi - python -m scripts.base_train \ + MOE_DEBUG_INTERVAL=$MOE_DEBUG_INTERVAL python -m scripts.base_train \ --depth=$BASE_DEPTH \ --max_seq_len=$SEQ_LEN \ --device_batch_size=$DEVICE_BATCH \ diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 1f172b3..fbcdaac 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -151,10 +151,7 @@ class MoEFeedForward(nn.Module): self.register_parameter("shared_w1", None) self.register_parameter("shared_w2", None) self.register_buffer("router_bias", 
torch.zeros(self.num_routed_experts, dtype=torch.float32)) - self.register_buffer("ema_load", torch.zeros(self.num_routed_experts, dtype=torch.float32)) self.balance_strength = 0.001 - self.balance_decay = 0.99 - self.bias_clamp = 0.2 self._last_stats = None self.last_entropy = None self.last_load = None @@ -203,10 +200,10 @@ class MoEFeedForward(nn.Module): tokens = max(1, assignments.numel()) load = load / tokens target = load.new_full((self.num_routed_experts,), 1.0 / self.num_routed_experts) - self.ema_load.mul_(self.balance_decay).add_((1 - self.balance_decay) * load) - balance = target - self.ema_load - self.router_bias.add_(self.balance_strength * balance) - self.router_bias.clamp_(-self.bias_clamp, self.bias_clamp) + balance = target - load + if self.balance_strength != 0: + direction = torch.sign(balance) + self.router_bias.add_(self.balance_strength * direction) entropy = (-probs * torch.log(probs + 1e-9)).sum(dim=-1).mean() load_cpu = load.detach().cpu() entropy_cpu = entropy.detach().cpu() diff --git a/scripts/base_train.py b/scripts/base_train.py index 2f7ccef..eb2dd49 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -310,6 +310,7 @@ tokens_per_step = total_batch_size total_tokens_seen = 0 total_sequences_seen = 0 last_val_bpb = None +moe_debug_interval = int(os.environ.get("MOE_DEBUG_INTERVAL", "0") or 0) def save_base_checkpoint(step_idx): optimizer_state = [opt.state_dict() for opt in optimizers] @@ -414,6 +415,32 @@ for step in range(num_iterations + 1): loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + if ( + moe_debug_interval > 0 + and master_process + and step % moe_debug_interval == 0 + ): + from nanochat.gpt import MoEFeedForward + + grad_lines = [] + for layer_idx, block in enumerate(orig_model.transformer.h): + mlp = getattr(block, "mlp", None) + if isinstance(mlp, 
MoEFeedForward): + routed_grad = ( + mlp.routed_w2.grad.detach().abs().mean().item() + if mlp.routed_w2.grad is not None else 0.0 + ) + shared_grad = ( + mlp.shared_w2.grad.detach().abs().mean().item() + if mlp.shared_w2 is not None and mlp.shared_w2.grad is not None else 0.0 + ) + grad_lines.append( + f"layer{layer_idx}: |grad routed_w2|={routed_grad:.3e} |grad shared_w2|={shared_grad:.3e}" + ) + if grad_lines: + print0(f"[MOE-GRAD] step {step:05d}") + for line in grad_lines: + print0(f" {line}") # gradient clipping (TODO possibly expertiment with) if grad_clip > 0.0: torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip) @@ -468,6 +495,31 @@ for step in range(num_iterations + 1): log_payload.update(orig_model.get_moe_stats()) wandb_run.log(log_payload) + if ( + moe_debug_interval > 0 + and master_process + and step % moe_debug_interval == 0 + ): + from nanochat.gpt import MoEFeedForward + + debug_lines = [] + for layer_idx, block in enumerate(orig_model.transformer.h): + mlp = getattr(block, "mlp", None) + if isinstance(mlp, MoEFeedForward): + routed_norm = mlp.routed_w2.detach().abs().mean().item() + shared_norm = ( + mlp.shared_w2.detach().abs().mean().item() + if mlp.shared_w2 is not None else 0.0 + ) + bias_norm = mlp.router_bias.detach().abs().mean().item() + debug_lines.append( + f"layer{layer_idx}: |routed_w2|={routed_norm:.3e} |shared_w2|={shared_norm:.3e} |router_bias|={bias_norm:.3e}" + ) + if debug_lines: + print0(f"[MOE-DEBUG] step {step:05d}") + for line in debug_lines: + print0(f" {line}") + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") diff --git a/tests/test_moe.py b/tests/test_moe.py new file mode 100644 index 0000000..5402962 --- /dev/null +++ b/tests/test_moe.py @@ -0,0 +1,158 @@ +import torch +import torch.nn.functional as F + +from nanochat.gpt import GPTConfig, MoEFeedForward + + +def _make_moe(num_experts: int, top_k: 
int, shared: int = 0) -> MoEFeedForward: + config = GPTConfig( + sequence_len=8, + vocab_size=16, + n_layer=1, + n_head=1, + n_kv_head=1, + n_embd=4, + moe_num_experts=num_experts, + moe_num_shared_experts=shared, + moe_experts_per_token=top_k, + moe_expert_ffn_mult=1.0, + dense_layers_before_moe=0, + ) + moe = MoEFeedForward(config) + moe.eval() + with torch.no_grad(): + moe.router.weight.zero_() + bias = torch.linspace(float(num_experts), 1.0, steps=num_experts) + moe.router_bias.copy_(bias) + eye = torch.eye(moe.model_dim) + for idx in range(num_experts): + moe.routed_w1[idx].copy_(eye) + moe.routed_w2[idx].copy_(eye * (idx + 1)) + if shared: + for idx in range(shared): + moe.shared_w1[idx].copy_(eye) + moe.shared_w2[idx].copy_(eye * (idx + 1)) + return moe + + +def _reference_forward(moe: MoEFeedForward, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape + x_flat = x.view(B * T, C) + router_logits = moe.router(x_flat) + moe.router_bias + probs = torch.softmax(router_logits, dim=-1) + topk_scores, topk_idx = torch.topk(probs, moe.top_k, dim=-1) + topk_scores = topk_scores / torch.clamp(topk_scores.sum(dim=-1, keepdim=True), min=1e-9) + + routed = torch.zeros_like(x_flat) + for token in range(x_flat.size(0)): + for slot in range(moe.top_k): + expert = topk_idx[token, slot].item() + weight = topk_scores[token, slot] + hidden = torch.matmul(moe.routed_w1[expert], x_flat[token]) + hidden = F.relu(hidden).square() + routed[token] = routed[token] + weight * torch.matmul(moe.routed_w2[expert], hidden) + + if moe.num_shared_experts > 0: + shared = torch.zeros_like(x_flat) + for idx in range(moe.num_shared_experts): + hidden = torch.matmul(moe.shared_w1[idx], x_flat.t()).t() + hidden = F.relu(hidden).square() + shared = shared + torch.matmul(moe.shared_w2[idx], hidden.t()).t() + routed = routed + shared / moe.num_shared_experts + + return routed.view(B, T, C) + + +def test_moe_topk_changes_output(): + torch.manual_seed(0) + moe_full = _make_moe(num_experts=4, top_k=4) 
+ moe_single = _make_moe(num_experts=4, top_k=1) + with torch.no_grad(): + moe_single.router.weight.copy_(moe_full.router.weight) + moe_single.router_bias.copy_(moe_full.router_bias) + moe_single.routed_w1.copy_(moe_full.routed_w1) + moe_single.routed_w2.copy_(moe_full.routed_w2) + x = torch.ones(2, 3, moe_full.model_dim) + out_full = moe_full(x) + out_single = moe_single(x) + assert not torch.allclose(out_full, out_single) + + +def test_router_bias_pushes_toward_uniform_load(): + moe = _make_moe(num_experts=4, top_k=1) + moe.train() + assignments = torch.tensor([0, 0, 0, 0], dtype=torch.long) + probs = torch.full((assignments.shape[0], moe.num_routed_experts), 1.0 / moe.num_routed_experts) + before = moe.router_bias.clone() + moe._balance_router(assignments, probs) + after = moe.router_bias + assert after[0].item() < before[0].item() + assert torch.all(after[1:] > before[1:]) + + +def test_shared_expert_matches_reference(): + torch.manual_seed(0) + moe = _make_moe(num_experts=2, top_k=1, shared=2) + x = torch.randn(2, 3, moe.model_dim) + out = moe(x) + ref = _reference_forward(moe, x) + assert torch.allclose(out, ref, atol=1e-6, rtol=1e-5) + + +class LoopMoE(MoEFeedForward): + def _dispatch_batched(self, flat_assignments, x_flat, topk_scores): + num_tokens, model_dim = x_flat.shape + routed = torch.zeros(num_tokens, model_dim, device=x_flat.device, dtype=x_flat.dtype) + assignments = flat_assignments.view(num_tokens, self.top_k) + for token in range(num_tokens): + for slot in range(self.top_k): + expert = assignments[token, slot].item() + weight = topk_scores[token, slot] + hidden = torch.matmul(self.routed_w1[expert], x_flat[token]) + hidden = F.relu(hidden).square() + routed[token] = routed[token] + weight * torch.matmul(self.routed_w2[expert], hidden) + return routed + + +def test_moe_gradients_match_loop_reference(): + torch.manual_seed(0) + config_moe = GPTConfig( + sequence_len=8, + vocab_size=16, + n_layer=1, + n_head=1, + n_kv_head=1, + n_embd=4, + 
moe_num_experts=4, + moe_num_shared_experts=1, + moe_experts_per_token=2, + moe_expert_ffn_mult=1.5, + dense_layers_before_moe=0, + ) + moe_fast = MoEFeedForward(config_moe) + moe_loop = LoopMoE(config_moe) + moe_loop.load_state_dict(moe_fast.state_dict()) + + x_fast = torch.randn(2, 3, config_moe.n_embd, requires_grad=True) + x_loop = x_fast.detach().clone().requires_grad_(True) + target = torch.randn_like(x_fast) + + out_fast = moe_fast(x_fast) + out_loop = moe_loop(x_loop) + assert torch.allclose(out_fast, out_loop, atol=1e-6, rtol=1e-5) + + loss_fast = (out_fast * target).sum() + loss_loop = (out_loop * target).sum() + + moe_fast.zero_grad(set_to_none=True) + moe_loop.zero_grad(set_to_none=True) + x_fast.grad = None + x_loop.grad = None + + loss_fast.backward(retain_graph=True) + loss_loop.backward() + + for (name_fast, p_fast), (name_loop, p_loop) in zip(moe_fast.named_parameters(), moe_loop.named_parameters()): + assert torch.allclose(p_fast.grad, p_loop.grad, atol=1e-6, rtol=1e-5), f"gradient mismatch for {name_fast}" + + assert torch.allclose(x_fast.grad, x_loop.grad, atol=1e-6, rtol=1e-5)