muP implementation: coord check, transfer check, and code quality fixes

- Fix output logit hook in coord check to apply muP scaling (base/width)
- Replace config mutation side effect with assertion in setup_optimizer
- Set mup_base_width at GPTConfig construction in base_train.py
- Remove dead code (_transfer_check_output_mult)
- Tune base LRs to center optimal multiplier near 1.0 (0.12, 6.0, 0.12)
- Use log scale on all loss plots for better low-loss detail
- Add automated muP tests (coord check + transfer check)
- Update muP_changes.md verification commands

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amrit Bulusu 2026-03-14 16:28:50 -04:00
parent 8b9a23aa92
commit 641e8a6dd3
6 changed files with 381 additions and 26 deletions

151
muP_changes.md Normal file
View File

@ -0,0 +1,151 @@
# muP Adaptation for Muon+AdamW in nanochat
## Context
Standard muP (Yang et al., "Tensor Programs V", arXiv:2203.03466) was derived for SGD and Adam optimizers. nanochat uses a **mixed optimizer**: Muon (with Polar Express orthogonalization) for transformer hidden weights, and AdamW for embeddings, scalars, and the output head (lm_head). This document describes the adaptations required and the empirical evidence behind them.
## The One Essential muP Ingredient
The output logit scaling in the forward pass is the core of muP and remains unchanged:
```python
# gpt.py forward()
logits = self.lm_head(x)
if self.config.mup_base_width > 0:
logits = logits * (self.config.mup_base_width / self.config.n_embd)
```
Without this, logit magnitudes grow as O(√width) because the lm_head dot product sums over `n_embd` terms. The multiplier `base_width / n_embd` (= `1/m_d` where `m_d = width/base`) keeps logits O(1) at all widths. This is what enables hyperparameter transfer — it's the mechanism that makes the loss landscape shape-invariant across widths.
## What We Changed (and Why)
### Change 1: Removed output LR scaling
**Before:**
```python
output_lr_scale = base_width / model_dim # e.g., 128/1024 = 0.125
```
**After:**
```python
output_lr_scale = 1.0 # No width-dependent LR scaling for lm_head
```
#### Why the paper's prescription doesn't apply here
The paper (Table 8, "MUP" row) prescribes output layer LR ∝ `1/width`. The reasoning: in vanilla SGD, the lm_head gradient magnitude scales with width, so the LR must compensate. For Adam, the second moment normalizes gradients, but the paper still prescribes `1/width` because the *signal-to-noise ratio* of the Adam update changes with width.
However, this analysis doesn't account for the output logit scaling. Here's the interaction:
1. **Forward pass**: `logits = (base/width) × h @ W_out^T`
2. **Backward pass**: `∂L/∂W_out = (base/width) × (∂L/∂logits)^T @ h`
— the gradient already carries a `base/width` factor from the chain rule through the output multiplier
3. **Adam step**: Adam normalizes by `√(E[grad²])`, which is O(base/width). The normalized step is O(1).
4. **LR application**: If LR is also scaled by `base/width`, the effective update becomes O(base/width).
5. **Effect on logits**: `Δlogits = (base/width) × h @ ΔW^T`, contributing another `base/width` factor.
**Net effect**: The logit change per step scales as O((base/width)²) — quadratic suppression. At width=1024 with base=128, this is a **64× reduction** in the effective output learning rate. The lm_head is barely learning.
#### Empirical evidence
Using `--sweep-mode adamw-only` (sweep only AdamW LR, hold Muon fixed):
| Width | Old muP optimal mult | Fixed muP optimal mult |
|-------|---------------------|----------------------|
| 128 | 32 | 32 |
| 256 | 64 | 32 |
| 512 | 128 | 32 |
| 1024 | 256 | 32 |
**Old muP**: Optimal multiplier doubles with each width doubling (spread = 3.0 log2). The sweep is perfectly compensating for the over-reduction — the optimizer needs `m_d` times more LR to undo the `1/m_d` scaling.
**Fixed muP**: Optimal multiplier = 32 at all widths (spread = 0.0 log2). Perfect transfer.
### Change 2: Set Muon LR exponent to 0
**Before:**
```python
hidden_lr_scale = base_width / model_dim # 1/m_d scaling for Muon hidden weights
```
**After:**
```python
hidden_lr_scale = (base_width / model_dim) ** muon_lr_exponent # default exponent = 0.0 → scale = 1.0
```
#### Why standard muP LR scaling is redundant for Muon
The paper prescribes hidden layer LR ∝ `1/fan_in` = `base/width` for Adam. This compensates for Adam updates scaling with fan_in: with n_embd input dimensions, each element of the update is O(1/√n_embd) after Adam normalization, but the net change to the residual stream (summing over n_embd) is O(√n_embd). The `1/width` LR tames this.
**Muon doesn't have this problem.** Muon's Polar Express orthogonalization produces an update with `||update||_F ≈ 1` regardless of matrix dimensions. The update's Frobenius norm is O(1), and its contribution to the residual stream is also O(1) — it doesn't grow with width. Applying an additional `1/width` factor makes the update O(1/width), which *vanishes* at large width.
#### Empirical evidence
We tested three exponents with `--sweep-mode all`:
| muon_lr_exponent | muP optimal LR spread (log2) |
|-----------------|------------------------------|
| 0.0 | 2.0 |
| 0.5 | 3.0 |
| 1.0 | 2.0 |
Exponents 0.0 and 1.0 give **identical spread** (2.0). The Muon LR exponent literally doesn't matter — Polar Express dominates the update magnitude regardless of LR scaling. We default to 0.0 (no scaling) as the simplest correct choice.
(The spread of 2.0 in these experiments was caused by the output LR scaling bug, which was still active. After fixing Change 1, the overall spread dropped to 1.0 for all exponents.)
## What Remains Unchanged
| Component | Value | Paper requirement | Status |
|-----------|-------|-------------------|--------|
| Output logit scaling | `logits *= base/width` | Required | ✅ Correct |
| Embedding LR | No width scaling | Constant with width | ✅ Correct |
| lm_head init std | `0.001 × √(base/width)` | Width-scaled init | ✅ Correct |
| Weight decay | Not width-scaled | Constant with width | ✅ Correct |
| Momentum (Adam β₁, Muon) | Not width-scaled | Constant with width | ✅ Correct |
| c_proj init | Non-zero uniform, std=√(3/n_embd) | Paper recommends zero | ⚠️ Intentional divergence |
**On c_proj init**: The paper recommends zero-initializing output projections (attn c_proj, MLP c_proj) for cleaner transfer. nanochat uses non-zero init because zero init causes vanishing attention/FFN outputs when combined with Muon's LR dynamics — the first Muon update from a zero matrix produces an orthogonal matrix with O(LR) norm, which is too small when LR is already small. This is a known interaction between Muon and residual-stream architectures; the non-zero init provides a stable starting point.
## Summary: muP for Muon+AdamW
For a mixed Muon+AdamW optimizer, muP simplifies dramatically:
| Parameter group | muP prescription | Reason |
|----------------|-----------------|--------|
| **Output logits** | `logits *= base/width` in forward | The essential ingredient — makes loss landscape shape-invariant |
| **lm_head init** | `std *= √(base/width)` | Keeps initial logit magnitudes O(1) |
| **lm_head LR** | No width scaling | Logit scaling already propagates into gradient; Adam normalizes; additional LR scaling over-reduces |
| **Muon (hidden) LR** | No width scaling | Polar Express makes `||update||_F ≈ 1` regardless of width |
| **Embedding LR** | No width scaling | Standard muP (embeddings are lookup tables, not matrix multiplies) |
| **Scalar LR** | No width scaling | Standard muP |
**The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass (plus corresponding init adjustment). No LR scaling is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
## Verification
```bash
# Full transfer check (should show muP spread ≤ 1.0, SP spread ≥ 2.0)
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024,2048 \
--steps 50 --num-batches 200 --save-dir temp/mup_transfer
# Coordinate check (activation magnitudes should be flat across widths for muP)
python -m scripts.mup_coord_check --compare --steps 10 --detailed --save-dir temp/mup_coord
# Automated tests
python -m pytest tests/test_mup.py -v
```
## Files Changed
| File | Changes |
|------|---------|
| `nanochat/gpt.py` | `output_lr_scale`: `base/width``1.0`; added `muon_lr_exponent` param (default `0.0`); updated comments |
| `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent` |
| `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100→200 |
## References
- Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer", arXiv:2203.03466 (2022). Sections B.1 (muP table), C.1 (Frobenius-normalizing optimizers), F (GPT-3 experiments).
- EleutherAI muP blog: https://blog.eleuther.ai/mutransfer/
- Polar Express: Amsel et al., arXiv:2505.16932 (2025).
- Muon: https://kellerjordan.github.io/posts/muon/

View File

@ -401,7 +401,9 @@ class GPT(nn.Module):
emb_lr_scale = 1.0 # Embeddings: NO width scaling (standard muP)
hidden_lr_scale = width_ratio ** muon_lr_exponent # Hidden (Muon): default 0 = no scaling
output_lr_scale = 1.0 # Output (AdamW): NO LR scaling (logit scaling in forward suffices)
self.config.mup_base_width = base_width # enables output logit scaling in forward()
assert self.config.mup_base_width == base_width, \
f"mup_base_width mismatch: GPTConfig has {self.config.mup_base_width}, but setup_optimizer got base_width={base_width}. " \
f"Set mup_base_width={base_width} in GPTConfig at construction time."
print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}")
else:
# Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model)

View File

@ -139,6 +139,7 @@ def build_model_meta(depth):
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
mup_base_width=args.base_width if args.use_mup else 0,
)
with torch.device("meta"):
model_meta = GPT(config)

View File

@ -67,10 +67,10 @@ class CoordCheckConfig:
seed: int = 42
use_mup: bool = False
base_width: int = 128
# Learning rates (tuned at base_width)
matrix_lr: float = 0.02
embedding_lr: float = 0.2
unembedding_lr: float = 0.004
# Learning rates (tuned at base_width=128)
matrix_lr: float = 0.12
embedding_lr: float = 6.0
unembedding_lr: float = 0.12
# Detailed diagnostics
detailed: bool = False
# Muon LR exponent: 1.0 = base/width (standard muP), 0.5 = sqrt(base/width)
@ -165,8 +165,16 @@ class ActivationRecorder:
h2 = block.attn.c_k.register_forward_hook(k_hook)
self.hooks.extend([h1, h2])
# LM head
h = model.lm_head.register_forward_hook(self._make_hook('output logits'))
# Output logits: hook on lm_head, but apply muP scaling to match what forward() does
mup_base = model.config.mup_base_width
n_embd = model.config.n_embd
def logit_hook(module, input, output):
if output is not None and isinstance(output, torch.Tensor):
scaled = output
if mup_base > 0:
scaled = output * (mup_base / n_embd)
self.stats['output logits'].append(self._get_stat(scaled))
h = model.lm_head.register_forward_hook(logit_hook)
self.hooks.append(h)
def remove_hooks(self) -> None:
@ -387,6 +395,7 @@ def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", s
for i, w in enumerate(widths):
ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2)
ax.set_yscale('log')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}')
@ -445,14 +454,15 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfi
axes[0, 0].legend(fontsize=7, loc='best')
# Loss curves row
# Loss curves row (log scale so low-loss detail is visible)
all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]]
loss_min, loss_max = min(all_losses) * 0.95, max(all_losses) * 1.05
loss_min, loss_max = min(all_losses) * 0.9, max(all_losses) * 1.1
for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
ax = axes[n_layers, col]
for j, w in enumerate(widths):
ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2)
ax.set_yscale('log')
ax.set_ylim(loss_min, loss_max)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')

View File

@ -56,10 +56,10 @@ class TransferCheckConfig:
seed: int = 42
use_mup: bool = False
base_width: int = 128
# Base learning rates (tuned at base_width)
matrix_lr: float = 0.02
embedding_lr: float = 0.2
unembedding_lr: float = 0.004
# Base learning rates (tuned at base_width=128)
matrix_lr: float = 0.12
embedding_lr: float = 6.0
unembedding_lr: float = 0.12
# Multi-HP sweeps
sweep_init_scale: bool = False
sweep_output_mult: bool = False
@ -135,10 +135,6 @@ def create_model(width: int, config: TransferCheckConfig, device: torch.device,
for p in model.parameters():
p.mul_(init_scale)
# Apply output_mult: scale the output logit multiplier
# We store it as an attribute that forward() checks
model._transfer_check_output_mult = output_mult
return model, gpt_config
@ -272,6 +268,7 @@ def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", s
ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.set_xlabel('LR Multiplier')
ax.set_ylabel('Final Loss')
ax.set_title(f'LR Sweep{" - " + title if title else ""}')
@ -309,7 +306,7 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
# Top row: LR sweep curves
# Top row: LR sweep curves (log scale for loss detail)
for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
ax = axes[0, col]
for i, w in enumerate(widths):
@ -319,6 +316,7 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
opt_loss = results['final_losses'][w][opt_mult]
ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.set_xlabel('LR Multiplier')
ax.set_ylabel('Final Loss')
ax.set_title(f'{label}: Final Loss vs LR Multiplier')
@ -326,14 +324,9 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
ax.grid(True, alpha=0.3)
# Shared y-axis for top row
y_min = min(
min(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
min(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
) * 0.98
y_max = max(
max(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
max(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
) * 1.02
all_losses_flat = [results_sp['final_losses'][w][m] for m in lr_mults for w in widths] + \
[results_mup['final_losses'][w][m] for m in lr_mults for w in widths]
y_min, y_max = min(all_losses_flat) * 0.9, max(all_losses_flat) * 1.1
axes[0, 0].set_ylim(y_min, y_max)
axes[0, 1].set_ylim(y_min, y_max)
@ -393,6 +386,7 @@ def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", s
opt_v = min(final_losses[w], key=final_losses[w].get)
ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5)
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.set_xlabel(hp_name)
ax.set_ylabel('Final Loss')
ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}')
@ -430,6 +424,7 @@ def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, titl
losses = results['losses'][(w, opt_mult)]
ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})')
ax.set_yscale('log')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}')

196
tests/test_mup.py Normal file
View File

@ -0,0 +1,196 @@
"""
Test muP implementation: coordinate check and transfer check.
Verifies that:
1. Activation magnitudes are width-independent under muP (coord check)
2. Optimal learning rates transfer across widths under muP (transfer check)
python -m pytest tests/test_mup.py -v
"""
import pytest
import torch
import torch._dynamo
torch._dynamo.config.disable = True
import numpy as np
from collections import defaultdict
from nanochat.gpt import GPT, GPTConfig
def create_model(width, seq_len=64, n_layer=2, vocab_size=256, mup_base_width=0):
"""Create a small model at the given width."""
head_dim = 64
n_head = max(1, width // head_dim)
actual_width = n_head * head_dim
config = GPTConfig(
sequence_len=seq_len, vocab_size=vocab_size,
n_layer=n_layer, n_head=n_head, n_kv_head=n_head,
n_embd=actual_width, window_pattern="L",
mup_base_width=mup_base_width,
)
with torch.device('meta'):
model = GPT(config)
model.to_empty(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
model.init_weights()
return model, config
def get_activation_stats(model, x, y):
"""Run one forward pass and return mean |activation| for key layers."""
stats = {}
def make_hook(name):
def hook(module, input, output):
if isinstance(output, tuple):
output = output[0]
if output is not None and isinstance(output, torch.Tensor):
stats[name] = output.float().abs().mean().item()
return hook
hooks = []
hooks.append(model.transformer.wte.register_forward_hook(make_hook('embedding')))
for i, block in enumerate(model.transformer.h):
hooks.append(block.attn.c_proj.register_forward_hook(make_hook(f'attn.{i}')))
hooks.append(block.mlp.c_proj.register_forward_hook(make_hook(f'ffn.{i}')))
# Output logits with muP scaling applied
mup_base = model.config.mup_base_width
n_embd = model.config.n_embd
def logit_hook(module, input, output):
if output is not None and isinstance(output, torch.Tensor):
scaled = output * (mup_base / n_embd) if mup_base > 0 else output
stats['logits'] = scaled.float().abs().mean().item()
hooks.append(model.lm_head.register_forward_hook(logit_hook))
model.eval()
with torch.no_grad():
model(x, y)
for h in hooks:
h.remove()
return stats
@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
class TestMuPCoordCheck:
"""Test that muP activations are width-independent."""
WIDTHS = [128, 256, 512]
BASE_WIDTH = 128
def _compute_slopes(self, use_mup):
"""Run coord check and return log-log slopes for each layer."""
device = torch.device('cuda')
seq_len, vocab_size = 64, 256
torch.manual_seed(42)
x = torch.randint(0, vocab_size, (4, seq_len), device=device)
y = torch.roll(x, -1, dims=1)
all_stats = {}
for width in self.WIDTHS:
torch.manual_seed(42)
mup_base = self.BASE_WIDTH if use_mup else 0
model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
all_stats[width] = get_activation_stats(model, x, y)
del model
torch.cuda.empty_cache()
# Compute slopes on log-log plot
log_widths = np.log2(np.array(self.WIDTHS, dtype=float))
slopes = {}
for layer in all_stats[self.WIDTHS[0]]:
values = [all_stats[w][layer] for w in self.WIDTHS]
log_values = np.log2(np.array(values) + 1e-10)
slope, _ = np.polyfit(log_widths, log_values, 1)
slopes[layer] = slope
return slopes
def test_mup_activations_width_independent(self):
"""Under muP, internal activation slopes should be near zero.
Note: output logits are excluded because at init (no training steps),
muP logits are expected to decrease with width the init scaling
(std * sqrt(base/width)) combined with forward scaling (logits * base/width)
gives O(1/width) initial logits. muP preserves update magnitude, not init magnitude.
"""
slopes = self._compute_slopes(use_mup=True)
for layer, slope in slopes.items():
if layer == 'logits':
continue # skip — output logits have expected width dependence at init
assert abs(slope) < 0.2, \
f"muP activation slope for '{layer}' is {slope:.4f}, expected near 0"
def test_sp_activations_width_dependent(self):
"""Under SP, at least some activation slopes should be nonzero (sanity check)."""
slopes = self._compute_slopes(use_mup=False)
max_slope = max(abs(s) for s in slopes.values())
assert max_slope > 0.1, \
f"SP max slope is only {max_slope:.4f}, expected SP to show width dependence"
@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
class TestMuPTransfer:
"""Test that optimal LR transfers across widths under muP."""
WIDTHS = [128, 256]
BASE_WIDTH = 128
LR_MULTS = [0.25, 1.0, 4.0, 16.0]
NUM_STEPS = 50
def _run_sweep(self, use_mup):
"""Sweep LR multipliers across widths, return optimal mult per width."""
device = torch.device('cuda')
seq_len, vocab_size = 64, 256
matrix_lr = 0.12
embedding_lr = 6.0
unembedding_lr = 0.12
torch.manual_seed(42)
x = torch.randint(0, vocab_size, (8, seq_len), device=device)
y = torch.roll(x, -1, dims=1)
optimal_mults = {}
for width in self.WIDTHS:
best_loss, best_mult = float('inf'), None
for lr_mult in self.LR_MULTS:
torch.manual_seed(42)
mup_base = self.BASE_WIDTH if use_mup else 0
model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
optimizer = model.setup_optimizer(
matrix_lr=matrix_lr * lr_mult,
embedding_lr=embedding_lr * lr_mult,
unembedding_lr=unembedding_lr * lr_mult,
weight_decay=0.0,
use_mup=use_mup,
base_width=self.BASE_WIDTH,
)
model.train()
for _ in range(self.NUM_STEPS):
loss = model(x, y)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
final_loss = loss.item()
if final_loss < best_loss:
best_loss = final_loss
best_mult = lr_mult
del model, optimizer
torch.cuda.empty_cache()
optimal_mults[width] = best_mult
return optimal_mults
def test_mup_lr_transfer(self):
"""Under muP, optimal LR multiplier should be similar across widths."""
optimal = self._run_sweep(use_mup=True)
mults = list(optimal.values())
spread = np.log2(max(mults)) - np.log2(min(mults))
assert spread <= 2.0, \
f"muP LR spread is {spread:.1f} log2 (optimal mults: {optimal}), expected <= 2.0"