mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
muP implementation: coord check, transfer check, and code quality fixes
- Fix output logit hook in coord check to apply muP scaling (base/width)
- Replace config mutation side effect with assertion in setup_optimizer
- Set mup_base_width at GPTConfig construction in base_train.py
- Remove dead code (_transfer_check_output_mult)
- Tune base LRs to center optimal multiplier near 1.0 (0.12, 6.0, 0.12)
- Use log scale on all loss plots for better low-loss detail
- Add automated muP tests (coord check + transfer check)
- Update muP_changes.md verification commands

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in: parent 8b9a23aa92, commit 641e8a6dd3
muP_changes.md (new file, 151 lines)
@@ -0,0 +1,151 @@
# muP Adaptation for Muon+AdamW in nanochat

## Context

Standard muP (Yang et al., "Tensor Programs V", arXiv:2203.03466) was derived for the SGD and Adam optimizers. nanochat uses a **mixed optimizer**: Muon (with Polar Express orthogonalization) for the transformer's hidden weights, and AdamW for the embeddings, scalars, and output head (lm_head). This document describes the adaptations required and the empirical evidence behind them.

## The One Essential muP Ingredient

The output logit scaling in the forward pass is the core of muP and remains unchanged:
```python
# gpt.py forward()
logits = self.lm_head(x)
if self.config.mup_base_width > 0:
    logits = logits * (self.config.mup_base_width / self.config.n_embd)
```

Without this, logit magnitudes grow as O(√width), because the lm_head dot product sums over `n_embd` terms. The multiplier `base_width / n_embd` (= `1/m_d`, where `m_d = width/base`) keeps logits O(1) at all widths. This is what enables hyperparameter transfer: it is the mechanism that makes the loss landscape shape-invariant across widths.

## What We Changed (and Why)

### Change 1: Removed output LR scaling

**Before:**

```python
output_lr_scale = base_width / model_dim  # e.g., 128/1024 = 0.125
```

**After:**

```python
output_lr_scale = 1.0  # no width-dependent LR scaling for lm_head
```

#### Why the paper's prescription doesn't apply here

The paper (Table 8, "MUP" row) prescribes an output-layer LR ∝ `1/width`. The reasoning: in vanilla SGD, the lm_head gradient magnitude scales with width, so the LR must compensate. For Adam, the second moment normalizes gradients, but the paper still prescribes `1/width` because the *signal-to-noise ratio* of the Adam update changes with width.

However, this analysis doesn't account for the output logit scaling. Here's the interaction:

1. **Forward pass**: `logits = (base/width) × h @ W_out^T`
2. **Backward pass**: `∂L/∂W_out = (base/width) × (∂L/∂logits)^T @ h`; the gradient already carries a `base/width` factor from the chain rule through the output multiplier
3. **Adam step**: Adam normalizes by `√(E[grad²])`, which is O(base/width), so the normalized step is O(1)
4. **LR application**: if the LR is also scaled by `base/width`, the effective update becomes O(base/width)
5. **Effect on logits**: `Δlogits = (base/width) × h @ ΔW^T`, contributing another `base/width` factor

**Net effect**: the logit change per step scales as O((base/width)²), a quadratic suppression. At width=1024 with base=128, this is a **64× reduction** in the effective output learning rate. The lm_head is barely learning.

#### Empirical evidence

Using `--sweep-mode adamw-only` (sweep only the AdamW LR, hold Muon fixed):

| Width | Old muP optimal mult | Fixed muP optimal mult |
|-------|----------------------|------------------------|
| 128   | 32                   | 32                     |
| 256   | 64                   | 32                     |
| 512   | 128                  | 32                     |
| 1024  | 256                  | 32                     |

**Old muP**: the optimal multiplier doubles with each width doubling (spread = 3.0 log2). The sweep is perfectly compensating for the over-reduction: the optimizer needs `m_d` times more LR to undo the `1/m_d` scaling.

**Fixed muP**: optimal multiplier = 32 at all widths (spread = 0.0 log2). Perfect transfer.

### Change 2: Set Muon LR exponent to 0

**Before:**

```python
hidden_lr_scale = base_width / model_dim  # 1/m_d scaling for Muon hidden weights
```

**After:**

```python
hidden_lr_scale = (base_width / model_dim) ** muon_lr_exponent  # default exponent = 0.0 → scale = 1.0
```

#### Why standard muP LR scaling is redundant for Muon

The paper prescribes a hidden-layer LR ∝ `1/fan_in` = `base/width` for Adam. This compensates for Adam updates scaling with fan_in: after Adam's normalization the update entries are O(1), and because the update is correlated with the incoming activations, its contribution to the residual stream sums coherently over the `n_embd` input dimensions and grows with width. The `1/width` LR tames this.

**Muon doesn't have this problem.** Muon's Polar Express orthogonalization drives all singular values of the update toward 1, so the update has spectral norm ≈ 1 regardless of matrix dimensions. Its contribution to the residual stream is therefore O(1) and doesn't grow with width. Applying an additional `1/width` factor makes the update O(1/width), which *vanishes* at large width.

#### Empirical evidence

We tested three exponents with `--sweep-mode all`:

| muon_lr_exponent | muP optimal LR spread (log2) |
|------------------|------------------------------|
| 0.0              | 2.0                          |
| 0.5              | 3.0                          |
| 1.0              | 2.0                          |

Exponents 0.0 and 1.0 give **identical spread** (2.0). The Muon LR exponent simply doesn't matter: Polar Express dominates the update magnitude regardless of LR scaling. We default to 0.0 (no scaling) as the simplest correct choice.

(The spread of 2.0 in these experiments was caused by the output LR scaling bug, which was still active. After fixing Change 1, the overall spread dropped to 1.0 for all exponents.)

## What Remains Unchanged

| Component | Value | Paper requirement | Status |
|-----------|-------|-------------------|--------|
| Output logit scaling | `logits *= base/width` | Required | ✅ Correct |
| Embedding LR | No width scaling | Constant with width | ✅ Correct |
| lm_head init std | `0.001 × √(base/width)` | Width-scaled init | ✅ Correct |
| Weight decay | Not width-scaled | Constant with width | ✅ Correct |
| Momentum (Adam β₁, Muon) | Not width-scaled | Constant with width | ✅ Correct |
| c_proj init | Non-zero uniform, std = √(3/n_embd) | Paper recommends zero | ⚠️ Intentional divergence |

**On c_proj init**: the paper recommends zero-initializing output projections (attn c_proj, MLP c_proj) for cleaner transfer. nanochat uses a non-zero init because zero init causes vanishing attention/FFN outputs when combined with Muon's LR dynamics: the first Muon update from a zero matrix produces an orthogonal matrix with O(LR) norm, which is too small when the LR is already small. This is a known interaction between Muon and residual-stream architectures; the non-zero init provides a stable starting point.

## Summary: muP for Muon+AdamW

For a mixed Muon+AdamW optimizer, muP simplifies dramatically:

| Parameter group | muP prescription | Reason |
|-----------------|------------------|--------|
| **Output logits** | `logits *= base/width` in forward | The essential ingredient; makes the loss landscape shape-invariant |
| **lm_head init** | `std *= √(base/width)` | Keeps initial logit magnitudes O(1) |
| **lm_head LR** | No width scaling | Logit scaling already propagates into the gradient; Adam normalizes; additional LR scaling over-reduces |
| **Muon (hidden) LR** | No width scaling | Polar Express keeps the update's singular values ≈ 1 regardless of width |
| **Embedding LR** | No width scaling | Standard muP (embeddings are lookup tables, not matrix multiplies) |
| **Scalar LR** | No width scaling | Standard muP |

**The punchline**: with Muon+AdamW, muP reduces to scaling the output logits by `base/width` in the forward pass (plus the corresponding init adjustment). No LR scaling is needed anywhere: Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.

## Verification

```bash
# Full transfer check (should show muP spread ≤ 1.0, SP spread ≥ 2.0)
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024,2048 \
    --steps 50 --num-batches 200 --save-dir temp/mup_transfer

# Coordinate check (activation magnitudes should be flat across widths for muP)
python -m scripts.mup_coord_check --compare --steps 10 --detailed --save-dir temp/mup_coord

# Automated tests
python -m pytest tests/test_mup.py -v
```

## Files Changed

| File | Changes |
|------|---------|
| `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); updated comments |
| `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent` |
| `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100 → 200 |

## References

- Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer", arXiv:2203.03466 (2022). Sections B.1 (muP table), C.1 (Frobenius-normalizing optimizers), F (GPT-3 experiments).
- EleutherAI muP blog: https://blog.eleuther.ai/mutransfer/
- Polar Express: Amsel et al., arXiv:2505.16932 (2025).
- Muon: https://kellerjordan.github.io/posts/muon/
nanochat/gpt.py
@@ -401,7 +401,9 @@ class GPT(nn.Module):
             emb_lr_scale = 1.0  # Embeddings: NO width scaling (standard muP)
             hidden_lr_scale = width_ratio ** muon_lr_exponent  # Hidden (Muon): default 0 = no scaling
             output_lr_scale = 1.0  # Output (AdamW): NO LR scaling (logit scaling in forward suffices)
-            self.config.mup_base_width = base_width  # enables output logit scaling in forward()
+            assert self.config.mup_base_width == base_width, \
+                f"mup_base_width mismatch: GPTConfig has {self.config.mup_base_width}, but setup_optimizer got base_width={base_width}. " \
+                f"Set mup_base_width={base_width} in GPTConfig at construction time."
             print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}")
         else:
             # Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model)
base_train.py
@@ -139,6 +139,7 @@ def build_model_meta(depth):
         sequence_len=args.max_seq_len, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
         window_pattern=args.window_pattern,
+        mup_base_width=args.base_width if args.use_mup else 0,
     )
     with torch.device("meta"):
         model_meta = GPT(config)
scripts/mup_coord_check.py
@@ -67,10 +67,10 @@ class CoordCheckConfig:
     seed: int = 42
     use_mup: bool = False
     base_width: int = 128
-    # Learning rates (tuned at base_width)
-    matrix_lr: float = 0.02
-    embedding_lr: float = 0.2
-    unembedding_lr: float = 0.004
+    # Learning rates (tuned at base_width=128)
+    matrix_lr: float = 0.12
+    embedding_lr: float = 6.0
+    unembedding_lr: float = 0.12
     # Detailed diagnostics
     detailed: bool = False
     # Muon LR exponent: 1.0 = base/width (standard muP), 0.5 = sqrt(base/width)
@@ -165,8 +165,16 @@ class ActivationRecorder:
         h2 = block.attn.c_k.register_forward_hook(k_hook)
         self.hooks.extend([h1, h2])

-        # LM head
-        h = model.lm_head.register_forward_hook(self._make_hook('output logits'))
+        # Output logits: hook on lm_head, but apply muP scaling to match what forward() does
+        mup_base = model.config.mup_base_width
+        n_embd = model.config.n_embd
+        def logit_hook(module, input, output):
+            if output is not None and isinstance(output, torch.Tensor):
+                scaled = output
+                if mup_base > 0:
+                    scaled = output * (mup_base / n_embd)
+                self.stats['output logits'].append(self._get_stat(scaled))
+        h = model.lm_head.register_forward_hook(logit_hook)
         self.hooks.append(h)

     def remove_hooks(self) -> None:
@@ -387,6 +395,7 @@ def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", s
     for i, w in enumerate(widths):
         ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2)

+    ax.set_yscale('log')
     ax.set_xlabel('Step')
     ax.set_ylabel('Loss')
     ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}')
@@ -445,14 +454,15 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfi

     axes[0, 0].legend(fontsize=7, loc='best')

-    # Loss curves row
+    # Loss curves row (log scale so low-loss detail is visible)
     all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]]
-    loss_min, loss_max = min(all_losses) * 0.95, max(all_losses) * 1.05
+    loss_min, loss_max = min(all_losses) * 0.9, max(all_losses) * 1.1

     for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
         ax = axes[n_layers, col]
         for j, w in enumerate(widths):
             ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2)
+        ax.set_yscale('log')
         ax.set_ylim(loss_min, loss_max)
         ax.set_xlabel('Step')
         ax.set_ylabel('Loss')
scripts/mup_transfer_check.py
@@ -56,10 +56,10 @@ class TransferCheckConfig:
     seed: int = 42
     use_mup: bool = False
     base_width: int = 128
-    # Base learning rates (tuned at base_width)
-    matrix_lr: float = 0.02
-    embedding_lr: float = 0.2
-    unembedding_lr: float = 0.004
+    # Base learning rates (tuned at base_width=128)
+    matrix_lr: float = 0.12
+    embedding_lr: float = 6.0
+    unembedding_lr: float = 0.12
     # Multi-HP sweeps
     sweep_init_scale: bool = False
     sweep_output_mult: bool = False
@@ -135,10 +135,6 @@ def create_model(width: int, config: TransferCheckConfig, device: torch.device,
     for p in model.parameters():
         p.mul_(init_scale)

-    # Apply output_mult: scale the output logit multiplier
-    # We store it as an attribute that forward() checks
-    model._transfer_check_output_mult = output_mult
-
     return model, gpt_config

@@ -272,6 +268,7 @@ def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", s
         ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)

     ax.set_xscale('log', base=2)
+    ax.set_yscale('log')
     ax.set_xlabel('LR Multiplier')
     ax.set_ylabel('Final Loss')
     ax.set_title(f'LR Sweep{" - " + title if title else ""}')
@@ -309,7 +306,7 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
     fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
     colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

-    # Top row: LR sweep curves
+    # Top row: LR sweep curves (log scale for loss detail)
     for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
         ax = axes[0, col]
         for i, w in enumerate(widths):
@@ -319,6 +316,7 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
         opt_loss = results['final_losses'][w][opt_mult]
         ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
         ax.set_xscale('log', base=2)
+        ax.set_yscale('log')
         ax.set_xlabel('LR Multiplier')
         ax.set_ylabel('Final Loss')
         ax.set_title(f'{label}: Final Loss vs LR Multiplier')
@@ -326,14 +324,9 @@ def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckCo
         ax.grid(True, alpha=0.3)

     # Shared y-axis for top row
-    y_min = min(
-        min(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
-        min(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
-    ) * 0.98
-    y_max = max(
-        max(results_sp['final_losses'][w][m] for m in lr_mults for w in widths),
-        max(results_mup['final_losses'][w][m] for m in lr_mults for w in widths),
-    ) * 1.02
+    all_losses_flat = [results_sp['final_losses'][w][m] for m in lr_mults for w in widths] + \
+                      [results_mup['final_losses'][w][m] for m in lr_mults for w in widths]
+    y_min, y_max = min(all_losses_flat) * 0.9, max(all_losses_flat) * 1.1
     axes[0, 0].set_ylim(y_min, y_max)
     axes[0, 1].set_ylim(y_min, y_max)

@@ -393,6 +386,7 @@ def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", s
         opt_v = min(final_losses[w], key=final_losses[w].get)
         ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5)
     ax.set_xscale('log', base=2)
+    ax.set_yscale('log')
     ax.set_xlabel(hp_name)
     ax.set_ylabel('Final Loss')
     ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}')
@@ -430,6 +424,7 @@ def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, titl
         losses = results['losses'][(w, opt_mult)]
         ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})')

+    ax.set_yscale('log')
     ax.set_xlabel('Step')
     ax.set_ylabel('Loss')
     ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}')
tests/test_mup.py (new file, 196 lines)
@@ -0,0 +1,196 @@
"""
|
||||
Test muP implementation: coordinate check and transfer check.
|
||||
|
||||
Verifies that:
|
||||
1. Activation magnitudes are width-independent under muP (coord check)
|
||||
2. Optimal learning rates transfer across widths under muP (transfer check)
|
||||
|
||||
python -m pytest tests/test_mup.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
|
||||
|
||||
def create_model(width, seq_len=64, n_layer=2, vocab_size=256, mup_base_width=0):
|
||||
"""Create a small model at the given width."""
|
||||
head_dim = 64
|
||||
n_head = max(1, width // head_dim)
|
||||
actual_width = n_head * head_dim
|
||||
config = GPTConfig(
|
||||
sequence_len=seq_len, vocab_size=vocab_size,
|
||||
n_layer=n_layer, n_head=n_head, n_kv_head=n_head,
|
||||
n_embd=actual_width, window_pattern="L",
|
||||
mup_base_width=mup_base_width,
|
||||
)
|
||||
with torch.device('meta'):
|
||||
model = GPT(config)
|
||||
model.to_empty(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
model.init_weights()
|
||||
return model, config
|
||||
|
||||
|
||||
def get_activation_stats(model, x, y):
|
||||
"""Run one forward pass and return mean |activation| for key layers."""
|
||||
stats = {}
|
||||
|
||||
def make_hook(name):
|
||||
def hook(module, input, output):
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
if output is not None and isinstance(output, torch.Tensor):
|
||||
stats[name] = output.float().abs().mean().item()
|
||||
return hook
|
||||
|
||||
hooks = []
|
||||
hooks.append(model.transformer.wte.register_forward_hook(make_hook('embedding')))
|
||||
for i, block in enumerate(model.transformer.h):
|
||||
hooks.append(block.attn.c_proj.register_forward_hook(make_hook(f'attn.{i}')))
|
||||
hooks.append(block.mlp.c_proj.register_forward_hook(make_hook(f'ffn.{i}')))
|
||||
|
||||
# Output logits with muP scaling applied
|
||||
mup_base = model.config.mup_base_width
|
||||
n_embd = model.config.n_embd
|
||||
def logit_hook(module, input, output):
|
||||
if output is not None and isinstance(output, torch.Tensor):
|
||||
scaled = output * (mup_base / n_embd) if mup_base > 0 else output
|
||||
stats['logits'] = scaled.float().abs().mean().item()
|
||||
hooks.append(model.lm_head.register_forward_hook(logit_hook))
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
model(x, y)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
return stats
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
|
||||
class TestMuPCoordCheck:
|
||||
"""Test that muP activations are width-independent."""
|
||||
|
||||
WIDTHS = [128, 256, 512]
|
||||
BASE_WIDTH = 128
|
||||
|
||||
def _compute_slopes(self, use_mup):
|
||||
"""Run coord check and return log-log slopes for each layer."""
|
||||
device = torch.device('cuda')
|
||||
seq_len, vocab_size = 64, 256
|
||||
torch.manual_seed(42)
|
||||
x = torch.randint(0, vocab_size, (4, seq_len), device=device)
|
||||
y = torch.roll(x, -1, dims=1)
|
||||
|
||||
all_stats = {}
|
||||
for width in self.WIDTHS:
|
||||
torch.manual_seed(42)
|
||||
mup_base = self.BASE_WIDTH if use_mup else 0
|
||||
model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
|
||||
all_stats[width] = get_activation_stats(model, x, y)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Compute slopes on log-log plot
|
||||
log_widths = np.log2(np.array(self.WIDTHS, dtype=float))
|
||||
slopes = {}
|
||||
for layer in all_stats[self.WIDTHS[0]]:
|
||||
values = [all_stats[w][layer] for w in self.WIDTHS]
|
||||
log_values = np.log2(np.array(values) + 1e-10)
|
||||
slope, _ = np.polyfit(log_widths, log_values, 1)
|
||||
slopes[layer] = slope
|
||||
|
||||
return slopes
|
||||
|
||||
def test_mup_activations_width_independent(self):
|
||||
"""Under muP, internal activation slopes should be near zero.
|
||||
|
||||
Note: output logits are excluded because at init (no training steps),
|
||||
muP logits are expected to decrease with width — the init scaling
|
||||
(std * sqrt(base/width)) combined with forward scaling (logits * base/width)
|
||||
gives O(1/width) initial logits. muP preserves update magnitude, not init magnitude.
|
||||
"""
|
||||
slopes = self._compute_slopes(use_mup=True)
|
||||
for layer, slope in slopes.items():
|
||||
if layer == 'logits':
|
||||
continue # skip — output logits have expected width dependence at init
|
||||
assert abs(slope) < 0.2, \
|
||||
f"muP activation slope for '{layer}' is {slope:.4f}, expected near 0"
|
||||
|
||||
def test_sp_activations_width_dependent(self):
|
||||
"""Under SP, at least some activation slopes should be nonzero (sanity check)."""
|
||||
slopes = self._compute_slopes(use_mup=False)
|
||||
max_slope = max(abs(s) for s in slopes.values())
|
||||
assert max_slope > 0.1, \
|
||||
f"SP max slope is only {max_slope:.4f}, expected SP to show width dependence"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
|
||||
class TestMuPTransfer:
|
||||
"""Test that optimal LR transfers across widths under muP."""
|
||||
|
||||
WIDTHS = [128, 256]
|
||||
BASE_WIDTH = 128
|
||||
LR_MULTS = [0.25, 1.0, 4.0, 16.0]
|
||||
NUM_STEPS = 50
|
||||
|
||||
def _run_sweep(self, use_mup):
|
||||
"""Sweep LR multipliers across widths, return optimal mult per width."""
|
||||
device = torch.device('cuda')
|
||||
seq_len, vocab_size = 64, 256
|
||||
matrix_lr = 0.12
|
||||
embedding_lr = 6.0
|
||||
unembedding_lr = 0.12
|
||||
|
||||
torch.manual_seed(42)
|
||||
x = torch.randint(0, vocab_size, (8, seq_len), device=device)
|
||||
y = torch.roll(x, -1, dims=1)
|
||||
|
||||
optimal_mults = {}
|
||||
for width in self.WIDTHS:
|
||||
best_loss, best_mult = float('inf'), None
|
||||
for lr_mult in self.LR_MULTS:
|
||||
torch.manual_seed(42)
|
||||
mup_base = self.BASE_WIDTH if use_mup else 0
|
||||
model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
|
||||
optimizer = model.setup_optimizer(
|
||||
matrix_lr=matrix_lr * lr_mult,
|
||||
embedding_lr=embedding_lr * lr_mult,
|
||||
unembedding_lr=unembedding_lr * lr_mult,
|
||||
weight_decay=0.0,
|
||||
use_mup=use_mup,
|
||||
base_width=self.BASE_WIDTH,
|
||||
)
|
||||
model.train()
|
||||
for _ in range(self.NUM_STEPS):
|
||||
loss = model(x, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
final_loss = loss.item()
|
||||
if final_loss < best_loss:
|
||||
best_loss = final_loss
|
||||
best_mult = lr_mult
|
||||
|
||||
del model, optimizer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
optimal_mults[width] = best_mult
|
||||
|
||||
return optimal_mults
|
||||
|
||||
def test_mup_lr_transfer(self):
|
||||
"""Under muP, optimal LR multiplier should be similar across widths."""
|
||||
optimal = self._run_sweep(use_mup=True)
|
||||
mults = list(optimal.values())
|
||||
spread = np.log2(max(mults)) - np.log2(min(mults))
|
||||
assert spread <= 2.0, \
|
||||
f"muP LR spread is {spread:.1f} log2 (optimal mults: {optimal}), expected <= 2.0"
|