mirror of https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00

Merge 5c92dd02cb into 1cd94d768f
This commit is contained in: commit 641b53b73f

muP_changes.md (new file, 153 lines)

@@ -0,0 +1,153 @@
# muP Adaptation for Muon+AdamW in nanochat

## Context

Standard muP (Yang et al., "Tensor Programs V", arXiv:2203.03466) was derived for SGD and Adam optimizers. nanochat uses a **mixed optimizer**: Muon (with Polar Express orthogonalization) for transformer hidden weights, and AdamW for embeddings, scalars, and the output head (lm_head). This document describes the adaptations required and the empirical evidence behind them.

## The One Essential muP Ingredient

The output logit scaling in the forward pass is the core of muP and remains unchanged:

```python
# gpt.py forward()
logits = self.lm_head(x)
if self.config.mup_base_width > 0:
    logits = logits * (self.config.mup_base_width / self.config.n_embd)
```

Without this, logit magnitudes grow as O(√width) because the lm_head dot product sums over `n_embd` terms. The multiplier `base_width / n_embd` (= `1/m_d`, where `m_d = width/base`) keeps logits O(1) at all widths. This is what enables hyperparameter transfer — it is the mechanism that makes the loss landscape shape-invariant across widths.
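To make the scaling concrete, here is a toy sketch (illustrative only, not code from the repo): once an lm_head row has become correlated with the activation `h`, as happens during training, the raw dot product grows roughly linearly with width, and the `base/width` multiplier restores an O(1) logit.

```python
import torch

torch.manual_seed(0)
base = 128
scaled_logits = {}
for width in (128, 256, 512, 1024):
    h = torch.randn(width)             # residual-stream activation, entries O(1)
    w = h + 0.5 * torch.randn(width)   # a trained lm_head row correlates with h
    raw = (h @ w).item()               # grows roughly linearly with width
    scaled_logits[width] = raw * (base / width)  # muP multiplier restores O(1)
    print(f"width={width:5d}  raw={raw:9.1f}  scaled={scaled_logits[width]:7.1f}")
```

The raw logit roughly doubles with each width doubling, while the scaled logit stays near a width-independent constant.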

## What We Changed (and Why)

### Change 1: Removed output LR scaling

**Before:**

```python
output_lr_scale = base_width / model_dim  # e.g., 128/1024 = 0.125
```

**After:**

```python
output_lr_scale = 1.0  # No width-dependent LR scaling for lm_head
```

#### Why the paper's prescription doesn't apply here

The paper (Table 8, "MUP" row) prescribes output layer LR ∝ `1/width`. The reasoning: in vanilla SGD, the lm_head gradient magnitude scales with width, so the LR must compensate. For Adam, the second moment normalizes gradients, but the paper still prescribes `1/width` because the *signal-to-noise ratio* of the Adam update changes with width.

However, this analysis doesn't account for the output logit scaling. Here's the interaction:

1. **Forward pass**: `logits = (base/width) × h @ W_out^T`
2. **Backward pass**: `∂L/∂W_out = (base/width) × (∂L/∂logits)^T @ h` — the gradient already carries a `base/width` factor from the chain rule through the output multiplier
3. **Adam step**: Adam normalizes by `√(E[grad²])`, which is O(base/width). The normalized step is O(1).
4. **LR application**: If the LR is also scaled by `base/width`, the effective update becomes O(base/width).
5. **Effect on logits**: `Δlogits = (base/width) × h @ ΔW^T`, contributing another `base/width` factor.

**Net effect**: The logit change per step picks up two `base/width` factors, one from the forward multiplier and one from the LR scaling. The forward factor is shared by the correctly parametrized setup, so relative to correct muP the extra LR scaling suppresses the effective output learning rate by `base/width`: at width=1024 with base=128, an **8× reduction**, exactly the factor the sweep below has to claw back. The lm_head is barely learning.
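The interaction between Adam's normalization and the LR scaling can be checked with a one-step toy experiment (illustrative only, not the repo's code): Adam's second moment cancels the `base/width` factor that the forward multiplier puts into the gradient, but an explicit `base/width` factor on the LR survives intact.

```python
import torch

def logit_after_one_adam_step(width, base=128, base_lr=0.01, extra_lr_scaling=True):
    """One Adam step on a single lm_head row, with the muP forward multiplier."""
    m = base / width
    torch.manual_seed(0)
    h = torch.randn(width)                      # fixed input activation
    W = torch.zeros(width, requires_grad=True)  # one row of lm_head, zero-init
    lr = base_lr * m if extra_lr_scaling else base_lr
    opt = torch.optim.Adam([W], lr=lr, eps=1e-10)
    (m * (W @ h)).backward()                    # gradient already carries the factor m
    opt.step()
    with torch.no_grad():
        return (m * (W @ h)).item()             # logit change after one step

suppressed = logit_after_one_adam_step(1024, extra_lr_scaling=True)
plain = logit_after_one_adam_step(1024, extra_lr_scaling=False)
print(suppressed / plain)  # 0.125 = base/width: the extra LR factor survives Adam intact
```

The two runs see identical gradients, so the per-step logit change differs by exactly the LR ratio `base/width`, matching the 8× over-suppression at width 1024.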

#### Empirical evidence

Using `--sweep-mode adamw-only` (sweep only the AdamW LR, holding Muon fixed):

| Width | Old muP optimal mult | Fixed muP optimal mult |
|-------|----------------------|------------------------|
| 128   | 32                   | 32                     |
| 256   | 64                   | 32                     |
| 512   | 128                  | 32                     |
| 1024  | 256                  | 32                     |

**Old muP**: The optimal multiplier doubles with each width doubling (spread = 3.0 log2). The sweep is perfectly compensating for the over-reduction — the optimizer needs `m_d` times more LR to undo the `1/m_d` scaling.

**Fixed muP**: Optimal multiplier = 32 at all widths (spread = 0.0 log2). Perfect transfer.

### Change 2: Set Muon LR exponent to 0

**Before:**

```python
hidden_lr_scale = base_width / model_dim  # 1/m_d scaling for Muon hidden weights
```

**After:**

```python
hidden_lr_scale = (base_width / model_dim) ** muon_lr_exponent  # default exponent = 0.0 → scale = 1.0
```

#### Why standard muP LR scaling is redundant for Muon

The paper prescribes hidden layer LR ∝ `1/fan_in` = `base/width` for Adam. This compensates for the way Adam updates interact with fan-in: each element of the Adam-normalized update is O(1), and because updates correlate with the incoming activations, the net change to each coordinate of the residual stream (a sum over n_embd terms) grows with n_embd. The `1/width` LR tames this.

**Muon doesn't have this problem.** Muon's Polar Express orthogonalization produces an update with `||update||_F ≈ 1` regardless of matrix dimensions. The update's Frobenius norm is O(1), and its contribution to the residual stream is also O(1) — it doesn't grow with width. Applying an additional `1/width` factor makes the update O(1/width), which *vanishes* at large width.
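One way to see the width-independence is through the polar factor that Polar Express approximates. The sketch below (illustrative; it uses an exact SVD-based polar factor rather than the repo's iteration) shows that every singular value of the orthogonalized update is 1, no matter how small the raw gradient is, so the update's scale is set by the orthogonalization rather than by width-dependent gradient statistics.

```python
import torch

def polar_factor(G):
    # Exact polar factor via SVD; Muon's Polar Express / Newton-Schulz
    # iterations approximate this matrix without computing an SVD.
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
for width in (128, 512, 1024):
    G = torch.randn(width, width) / width   # raw gradient, magnitude shrinks with width
    spec = torch.linalg.matrix_norm(polar_factor(G), ord=2).item()
    print(f"width={width:5d}  ||G||_F={G.norm().item():7.3f}  ||polar(G)||_2={spec:.4f}")
```

The raw gradient norm varies by an order of magnitude across widths, while the orthogonalized update's spectral norm stays pinned at 1.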

#### Empirical evidence

We tested three exponents with `--sweep-mode all`:

| muon_lr_exponent | muP optimal LR spread (log2) |
|------------------|------------------------------|
| 0.0              | 2.0                          |
| 0.5              | 3.0                          |
| 1.0              | 2.0                          |

Exponents 0.0 and 1.0 give **identical spread** (2.0). The Muon LR exponent has essentially no effect on transfer: Polar Express, not the LR scale factor, dominates the update magnitude. We default to 0.0 (no scaling) as the simplest correct choice.

(The spread of 2.0 in these experiments was caused by the output LR scaling bug, which was still active. After fixing Change 1, the overall spread dropped to 1.0 for all exponents.)

## What Remains Unchanged

| Component | Value | Paper requirement | Status |
|-----------|-------|-------------------|--------|
| Output logit scaling | `logits *= base/width` | Required | ✅ Correct |
| Embedding LR | No width scaling | Constant with width | ✅ Correct |
| lm_head init std | `0.02` (flat, no width scaling) | See note below | ✅ Correct |
| Weight decay | Not width-scaled | Constant with width | ✅ Correct |
| Momentum (Adam β₁, Muon) | Not width-scaled | Constant with width | ✅ Correct |
| c_proj init | Non-zero uniform, std = 1/√n_embd (bound √(3/n_embd)) | Paper recommends zero | ⚠️ Intentional divergence |

**On lm_head init**: The paper prescribes width-scaled init (`std ∝ 1/√width`) to keep initial logit magnitudes O(1). We previously used `0.001 × √(base/width)`. However, the forward-pass logit scaling (`logits *= base/width`) already suppresses logit magnitudes at large widths. The width-scaled init was double-compensating — initial logits were O(base/width) instead of O(1), making the lm_head start too quiet at large widths. We now use a flat `std = 0.02` which, combined with the forward-pass scaling, produces well-behaved initial logits at all widths.

**On c_proj init**: The paper recommends zero-initializing output projections (attn c_proj, MLP c_proj) for cleaner transfer. nanochat uses non-zero init because zero init causes vanishing attention/FFN outputs when combined with Muon's LR dynamics — the first Muon update from a zero matrix produces an orthogonal matrix with O(LR) norm, which is too small when LR is already small. This is a known interaction between Muon and residual-stream architectures; the non-zero init provides a stable starting point.

## Summary: muP for Muon+AdamW

For a mixed Muon+AdamW optimizer, muP simplifies dramatically:

| Parameter group | muP prescription | Reason |
|-----------------|------------------|--------|
| **Output logits** | `logits *= base/width` in forward | The essential ingredient — makes the loss landscape shape-invariant |
| **lm_head init** | `std = 0.02` (flat, no width scaling) | Forward-pass logit scaling already handles width independence; width-scaled init double-compensates |
| **lm_head LR** | No width scaling | Logit scaling already propagates into the gradient; Adam normalizes; additional LR scaling over-reduces |
| **Muon (hidden) LR** | No width scaling | Polar Express makes `\|\|update\|\|_F ≈ 1` regardless of width |
| **Embedding LR** | No width scaling | Standard muP (embeddings are lookup tables, not matrix multiplies) |
| **Scalar LR** | No width scaling | Standard muP |

**The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass. No LR scaling or width-dependent init is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
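The whole table collapses to a few lines of branching. The helper below is an illustrative condensation of the `setup_optimizer` changes in this commit (the function name and dict keys are made up for the sketch):

```python
def lr_scales(model_dim, use_mup, base_width=256, muon_lr_exponent=0.0):
    """Width-dependent LR multipliers per parameter group (sketch)."""
    if use_mup:
        width_ratio = base_width / model_dim
        return dict(
            embedding=1.0,                                # no width scaling (standard muP)
            muon_hidden=width_ratio ** muon_lr_exponent,  # default exponent 0.0 -> 1.0
            lm_head=1.0,                                  # forward logit scaling suffices
        )
    sp = (model_dim / 768) ** -0.5  # SP: AdamW LRs scale as 1/sqrt(dmodel), tuned at 768
    return dict(embedding=sp, muon_hidden=1.0, lm_head=sp)

print(lr_scales(2048, use_mup=True))  # every multiplier is 1.0 under muP defaults
```

Under muP with the default exponent, every multiplier is 1.0 at every width; the only width-dependent piece left in the model is the forward-pass logit scaling.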

## Verification

```bash
# Full transfer check (should show muP spread ≤ 1.0, SP spread ≥ 2.0)
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024,2048 \
    --steps 50 --num-batches 200 --save-dir temp/mup_transfer

# Coordinate check (activation magnitudes should be flat across widths for muP)
python -m scripts.mup_coord_check --compare --steps 10 --detailed --save-dir temp/mup_coord

# Automated tests
python -m pytest tests/test_mup.py -v
```

## Files Changed

| File | Changes |
|------|---------|
| `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); lm_head init: `0.001 × √(base/width)` → flat `0.02` for muP (removed double-compensation); updated comments |
| `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes) and `--muon-lr-exponent`; switched to float32 (disabled bfloat16 autocast) for numerical precision |
| `scripts/mup_transfer_check.py` | Wider default LR range (1024×); added `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`; default steps 100 → 200; switched to float32 for numerical precision |

## References

- Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer", arXiv:2203.03466 (2022). Sections B.1 (muP table), C.1 (Frobenius-normalizing optimizers), F (GPT-3 experiments).
- EleutherAI muP blog: https://blog.eleuther.ai/mutransfer/
- Polar Express: Amsel et al., arXiv:2505.16932 (2025).
- Muon: https://kellerjordan.github.io/posts/muon/

nanochat/gpt.py

@@ -37,6 +37,9 @@ class GPTConfig:
     # Characters: L=long (full context), S=short (quarter context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
     window_pattern: str = "SSSL"
+    # muP (Maximal Update Parametrization): set > 0 to enable. Value is the base/proxy width.
+    # Enables: non-zero c_proj init scaled as 1/sqrt(m_d), output logit scaling by base_width/n_embd.
+    mup_base_width: int = 0


 def norm(x):
@@ -203,20 +206,22 @@ class GPT(nn.Module):
         """
         Initialize the full model in this one function for maximum clarity.

-        wte (embedding): normal, std=1.0
+        wte (embedding): normal, std=0.8
         lm_head: normal, std=0.001
         for each block:
             attn.c_q: uniform, std=1/sqrt(n_embd)
             attn.c_k: uniform, std=1/sqrt(n_embd)
             attn.c_v: uniform, std=1/sqrt(n_embd)
-            attn.c_proj: zeros
-            mlp.c_fc: uniform, std=1/sqrt(n_embd)
-            mlp.c_proj: zeros
+            attn.c_proj: zeros (SP) or uniform (muP)
+            mlp.c_fc: uniform, std=0.5/sqrt(n_embd)
+            mlp.c_proj: zeros (SP) or uniform (muP)
         """

         # Embedding and unembedding
         torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
-        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+        # muP uses 0.02 for stronger initial logit signal; forward-pass scaling handles width independence
+        lm_head_std = 0.02 if self.config.mup_base_width > 0 else 0.001
+        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)

         # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
         n_embd = self.config.n_embd
@@ -225,9 +230,15 @@ class GPT(nn.Module):
             torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
             torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
             torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
-            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
             torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
-            torch.nn.init.zeros_(block.mlp.c_proj.weight)
+            if self.config.mup_base_width > 0:
+                # muP: output projections use same scale as hidden weights (std = sigma_base/sqrt(m_d))
+                # Zero init causes attn/FFN outputs to vanish as width increases with muP LR scaling
+                torch.nn.init.uniform_(block.attn.c_proj.weight, -s, s)
+                torch.nn.init.uniform_(block.mlp.c_proj.weight, -s, s)
+            else:
+                torch.nn.init.zeros_(block.attn.c_proj.weight) # SP: projections are zero
+                torch.nn.init.zeros_(block.mlp.c_proj.weight)

         # Per-layer scalars
         # Per-layer resid init: stronger residual at early layers, weaker at deep layers
@@ -261,7 +272,6 @@ class GPT(nn.Module):
             ve.to(dtype=COMPUTE_DTYPE)

     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
-        # TODO: bump base theta more? e.g. 100K is more common more recently
         # autodetect the device from model embeddings
         if device is None:
             device = self.transformer.wte.weight.device
@@ -366,7 +376,7 @@ class GPT(nn.Module):
             'total': total,
         }

-    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
+    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, use_mup=False, base_width=256, muon_lr_exponent=0.0):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()

@@ -380,16 +390,45 @@ class GPT(nn.Module):
         smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
         assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)

-        # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
-        dmodel_lr_scale = (model_dim / 768) ** -0.5
-        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
+        # Compute LR scaling factors based on mode
+        if use_mup:
+            # muP for mixed Muon+AdamW optimizer:
+            #
+            # Key insight: the output logit scaling (logits *= base/width in forward()) already
+            # propagates a base/width factor into the lm_head gradient. Adam normalizes gradient
+            # magnitude via its second moment, but applying ANOTHER base/width to the LR compounds
+            # the scaling, making lm_head updates O(base²/width²) instead of O(1). So we do NOT
+            # apply width-dependent LR scaling to the output layer — the logit scaling alone suffices.
+            #
+            # For Muon (hidden weights): Polar Express orthogonalization normalizes ||update||_F ≈ 1
+            # regardless of width, making the update already O(1). No width-dependent LR scaling
+            # is needed (empirically confirmed: exponent 0 and 1 give identical transfer behavior).
+            #
+            # The muon_lr_exponent parameter is kept for experimentation but defaults to 0.
+            width_ratio = base_width / model_dim  # e.g., 128/1024 = 0.125
+            dmodel_lr_scale = 1.0  # muP: no sqrt-based scaling (width scaling handled explicitly)
+            emb_lr_scale = 1.0  # Embeddings: NO width scaling (standard muP)
+            hidden_lr_scale = width_ratio ** muon_lr_exponent  # Hidden (Muon): default 0 = no scaling
+            output_lr_scale = 1.0  # Output (AdamW): NO LR scaling (logit scaling in forward suffices)
+            assert self.config.mup_base_width == base_width, \
+                f"mup_base_width mismatch: GPTConfig has {self.config.mup_base_width}, but setup_optimizer got base_width={base_width}. " \
+                f"Set mup_base_width={base_width} in GPTConfig at construction time."
+            print0(f"muP scaling: base_width={base_width}, model_dim={model_dim}, width_ratio={width_ratio:.6f}, muon_lr_exp={muon_lr_exponent}")
+        else:
+            # Standard (SP): scale AdamW params by 1/√dmodel (tuned for 768 dim model)
+            dmodel_lr_scale = (model_dim / 768) ** -0.5
+            emb_lr_scale = dmodel_lr_scale
+            hidden_lr_scale = 1.0  # Muon params: no scaling in SP mode
+            output_lr_scale = dmodel_lr_scale
+            print0(f"Standard scaling: dmodel_lr_scale={dmodel_lr_scale:.6f}")

         # Build param_groups with all required fields explicit
+        # Per-group betas and weight decay tuned by Andrej for SP; muP uses same base HPs
         param_groups = [
             # AdamW groups (embeddings, lm_head, scalars)
-            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
-            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
-            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
+            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * output_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
+            dict(kind='adamw', params=embedding_params, lr=embedding_lr * emb_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
+            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * emb_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
             dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
             dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
             dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
@@ -398,7 +437,7 @@ class GPT(nn.Module):
         for shape in sorted({p.shape for p in matrix_params}):
             group_params = [p for p in matrix_params if p.shape == shape]
             param_groups.append(dict(
-                kind='muon', params=group_params, lr=matrix_lr,
+                kind='muon', params=group_params, lr=matrix_lr * hidden_lr_scale,
                 momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
             ))

@@ -462,6 +501,11 @@ class GPT(nn.Module):
         # Forward the lm_head (compute logits)
         softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
         logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
+        if self.config.mup_base_width > 0:
+            # muP: scale output logits by base_width/n_embd (= 1/m_d).
+            # Without this, logits grow with width because the lm_head dot product sums over n_embd terms.
+            # 1/sqrt(m_d) only corrects at init; 1/m_d is required for all training steps (see Eleuther blog Fig 8-9).
+            logits = logits * (self.config.mup_base_width / self.config.n_embd)
         logits = logits[..., :self.config.vocab_size] # slice to remove padding
         logits = logits.float() # switch to fp32 for logit softcap and loss computation
         logits = softcap * torch.tanh(logits / softcap) # squash the logits
@@ -67,6 +67,8 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
 parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
+parser.add_argument("--use-mup", action="store_true", help="use muP (Maximal Update Parameterization) LR scaling")
+parser.add_argument("--base-width", type=int, default=256, help="base width for muP LR scaling (LRs tuned at this width)")
 parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
@@ -137,6 +139,7 @@ def build_model_meta(depth):
         sequence_len=args.max_seq_len, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
         window_pattern=args.window_pattern,
+        mup_base_width=args.base_width if args.use_mup else 0,
     )
     with torch.device("meta"):
         model_meta = GPT(config)
@@ -312,6 +315,9 @@ optimizer = model.setup_optimizer(
     # Muon hyperparameters
     matrix_lr=args.matrix_lr * batch_lr_scale,
     weight_decay=weight_decay_scaled,
+    # muP scaling
+    use_mup=args.use_mup,
+    base_width=args.base_width,
 )

 if resuming:

scripts/mup_coord_check.py (new file, 722 lines)

@@ -0,0 +1,722 @@
"""
|
||||
muP Coordinate Check for nanochat
|
||||
|
||||
This script validates muP implementation by checking that activation magnitudes
|
||||
are independent of model width. Based on EleutherAI's nanoGPT-mup and Microsoft's
|
||||
mup library.
|
||||
|
||||
Reference: https://blog.eleuther.ai/mutransfer/
|
||||
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
|
||||
Hyperparameter Transfer" (arXiv:2203.03466), Sections B.1 and F.
|
||||
|
||||
Usage:
|
||||
python -m scripts.mup_coord_check --widths 128,256,512,1024 --steps 10
|
||||
python -m scripts.mup_coord_check --use-mup --widths 128,256,512,1024
|
||||
python -m scripts.mup_coord_check --compare --detailed
|
||||
python -m scripts.mup_coord_check --compare --muon-lr-exponent 0.5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["NANOCHAT_DTYPE"] = "float32"
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import os
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
|
||||
|
||||
def load_batch(batch_size: int, seq_len: int, device: torch.device):
    """Load a single batch from the nanochat training pipeline.
    Falls back to random data if the tokenizer/dataset isn't available."""
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tokenizer = get_tokenizer()
        vocab_size = tokenizer.get_vocab_size()
        loader = tokenizing_distributed_data_loader_bos_bestfit(
            tokenizer, batch_size, seq_len, split="train", device=device,
        )
        x, y = next(loader)
        print(f"Loaded real training data (vocab_size={vocab_size})")
        return x, y, vocab_size
    except Exception as e:
        print(f"Could not load training data ({e}), using random tokens")
        vocab_size = 32768
        rng = torch.Generator(device=device)
        rng.manual_seed(42)
        x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, generator=rng)
        y = torch.roll(x, -1, dims=1)
        y[:, -1] = -1
        return x, y, vocab_size


@dataclass
class CoordCheckConfig:
    widths: List[int]
    num_steps: int = 10
    batch_size: int = 4
    seq_len: int = 128
    vocab_size: int = 32768
    n_layer: int = 2
    seed: int = 42
    use_mup: bool = False
    base_width: int = 128
    # Learning rates (tuned at base_width=128)
    matrix_lr: float = 0.12
    embedding_lr: float = 6.0
    unembedding_lr: float = 0.12
    # Detailed diagnostics
    detailed: bool = False
    # Muon LR exponent: 1.0 = base/width (standard muP), 0.5 = sqrt(base/width)
    # Paper Section C.1: Frobenius-normalizing optimizers may need exponent 0.5
    muon_lr_exponent: float = 0.0

class ActivationRecorder:
    """Records activation statistics during forward pass using hooks."""

    def __init__(self, detailed: bool = False):
        self.stats: Dict[str, List[float]] = defaultdict(list)
        self.hooks = []
        self.detailed = detailed

    def _get_stat(self, tensor: torch.Tensor) -> float:
        """Compute mean absolute value (l1 norm per element)."""
        if tensor is None:
            return 0.0
        if tensor.dtype == torch.bool:
            return tensor.float().abs().mean().item()
        return tensor.float().abs().mean().item()

    def _make_hook(self, name: str):
        """Create a forward hook that records output statistics."""
        def hook(module, input, output):
            if isinstance(output, tuple):
                output = output[0]
            if output is not None and isinstance(output, torch.Tensor):
                self.stats[name].append(self._get_stat(output))
        return hook

    def _make_attn_logit_hook(self, name: str, n_head: int, n_kv_head: int, head_dim: int):
        """Create a hook on c_k that computes pre-softmax attention logit magnitudes.

        We hook onto c_k's forward, then use the most recent c_q output to compute
        q @ k^T / sqrt(d) for a single batch element to measure attention logit scale.
        """
        # We'll store q output and compute logits when k is available
        self._last_q = None

        def q_hook(module, input, output):
            self._last_q = output.detach()

        def k_hook(module, input, output):
            if self._last_q is None:
                return
            q = self._last_q
            k = output.detach()
            B, T, _ = q.shape
            q = q[0:1].view(1, T, n_head, head_dim)
            k = k[0:1].view(1, T, n_kv_head, head_dim)
            # Apply QK norm (same as model)
            q = F.rms_norm(q, (q.size(-1),))
            k = F.rms_norm(k, (k.size(-1),))
            # Expand k for GQA
            if n_head != n_kv_head:
                k = k.repeat_interleave(n_head // n_kv_head, dim=2)
            # Compute logits: q @ k^T / sqrt(d) — just for first few positions
            T_sub = min(T, 32)
            q_sub = q[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            k_sub = k[:, :T_sub].transpose(1, 2)  # (1, H, T_sub, D)
            logits = torch.matmul(q_sub, k_sub.transpose(-2, -1)) / (head_dim ** 0.5)
            self.stats[name].append(logits.float().abs().mean().item())
            self._last_q = None

        return q_hook, k_hook

    def register_hooks(self, model: GPT) -> None:
        """Register forward hooks on key layers."""
        # Embedding
        h = model.transformer.wte.register_forward_hook(self._make_hook('word embedding'))
        self.hooks.append(h)

        # Each transformer block
        for i, block in enumerate(model.transformer.h):
            # Attention output
            h = block.attn.c_proj.register_forward_hook(self._make_hook(f'attn output.{i}'))
            self.hooks.append(h)
            # MLP output
            h = block.mlp.c_proj.register_forward_hook(self._make_hook(f'FFN output.{i}'))
            self.hooks.append(h)

            # Detailed: attention logit magnitudes
            if self.detailed:
                n_head = block.attn.n_head
                n_kv_head = block.attn.n_kv_head
                head_dim = block.attn.head_dim
                q_hook, k_hook = self._make_attn_logit_hook(
                    f'attn logits.{i}', n_head, n_kv_head, head_dim)
                h1 = block.attn.c_q.register_forward_hook(q_hook)
                h2 = block.attn.c_k.register_forward_hook(k_hook)
                self.hooks.extend([h1, h2])

        # Output logits: hook on lm_head, but apply muP scaling to match what forward() does
        mup_base = model.config.mup_base_width
        n_embd = model.config.n_embd
        def logit_hook(module, input, output):
            if output is not None and isinstance(output, torch.Tensor):
                scaled = output
                if mup_base > 0:
                    scaled = output * (mup_base / n_embd)
                self.stats['output logits'].append(self._get_stat(scaled))
        h = model.lm_head.register_forward_hook(logit_hook)
        self.hooks.append(h)

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def get_step_stats(self) -> Dict[str, float]:
        """Get mean stats for the current step and reset."""
        step_stats = {}
        for name, values in self.stats.items():
            if values:
                step_stats[name] = np.mean(values)
        self.stats = defaultdict(list)
        return step_stats

def create_model(width: int, config: CoordCheckConfig, device: torch.device, mup_base_width: int = 0) -> Tuple[GPT, GPTConfig]:
    """Create a model with the specified width."""
    head_dim = 64
    n_head = max(1, width // head_dim)
    actual_width = n_head * head_dim

    gpt_config = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=n_head,
        n_kv_head=n_head,
        n_embd=actual_width,
        window_pattern="L",
        mup_base_width=mup_base_width,
    )

    with torch.device('meta'):
        model = GPT(gpt_config)
    model.to_empty(device=device)
    model.init_weights()

    return model, gpt_config


def setup_optimizer_mup(model: GPT, config: CoordCheckConfig, width: int):
    """Set up optimizer with muP scaling using the native use_mup flag."""
    optimizer = model.setup_optimizer(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=True,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )
    return optimizer


def setup_optimizer_sp(model: GPT, config: CoordCheckConfig, width: int):
    """Set up optimizer with standard parameterization (current nanochat)."""
    optimizer = model.setup_optimizer(
        unembedding_lr=config.unembedding_lr,
        embedding_lr=config.embedding_lr,
        matrix_lr=config.matrix_lr,
        weight_decay=0.0,
        use_mup=False,
    )
    return optimizer

def record_detailed_stats(model: GPT, results: Dict, width: int, step: int):
    """Record weight update norms and gradient norms per parameter group."""
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        # Simplify name for display
        short_name = name.replace('transformer.', '').replace('.weight', '')
        # Gradient norm
        grad_norm = p.grad.float().norm().item()
        results['detailed_stats'][width][f'grad norm: {short_name}'].append(grad_norm)


def record_weight_update_norms(model: GPT, params_before: Dict[str, torch.Tensor],
                               results: Dict, width: int):
    """Record ||delta_W|| for each parameter after optimizer step."""
    for name, p in model.named_parameters():
        if name not in params_before:
            continue
        short_name = name.replace('transformer.', '').replace('.weight', '')
        delta = (p.data.float() - params_before[name]).norm().item()
        results['detailed_stats'][width][f'update norm: {short_name}'].append(delta)


def run_coord_check(config: CoordCheckConfig, device: torch.device,
                    x: torch.Tensor, y: torch.Tensor) -> Dict:
    """Run coordinate check across all widths."""
    results = {
        'widths': [],
        'steps': list(range(config.num_steps)),
        'stats': defaultdict(lambda: defaultdict(list)),
        'losses': defaultdict(list),
        'detailed_stats': defaultdict(lambda: defaultdict(list)),
    }

    for width in config.widths:
        print(f"\nTraining width={width}...")

        torch.manual_seed(config.seed)

        mup_base_width = config.base_width if config.use_mup else 0
        model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width)
        actual_width = gpt_config.n_embd
        results['widths'].append(actual_width)

        if config.use_mup:
            optimizer = setup_optimizer_mup(model, config, actual_width)
        else:
            optimizer = setup_optimizer_sp(model, config, actual_width)

        recorder = ActivationRecorder(detailed=config.detailed)
        recorder.register_hooks(model)

        model.train()

        for step in range(config.num_steps):
            with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
                loss = model(x, y)

            results['losses'][actual_width].append(loss.item())

            step_stats = recorder.get_step_stats()
            for layer, value in step_stats.items():
                results['stats'][actual_width][layer].append(value)

            if step == 0:
                print(f"  Step {step}: loss={loss.item():.4f}, layers={list(step_stats.keys())}")

            # Record gradient norms before step (detailed mode)
            loss.backward()

            if config.detailed:
                record_detailed_stats(model, results, actual_width, step)
# Snapshot params before optimizer step to compute update norms
|
||||
params_before = {name: p.data.float().clone()
|
||||
for name, p in model.named_parameters()
|
||||
if p.grad is not None}
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if config.detailed:
|
||||
record_weight_update_norms(model, params_before, results, actual_width)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
print(f" Final loss: {loss.item():.4f}")
|
||||
|
||||
recorder.remove_hooks()
|
||||
del model, optimizer
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
|
||||
return results
|
||||


def plot_coord_check(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot coordinate check: one subplot per layer, x=width (log2), y=mean |activation|, lines=steps."""
    widths = results['widths']
    steps = results['steps']
    stats = results['stats']

    layer_names = list(stats[widths[0]].keys())
    n_layers = len(layer_names)
    n_cols = 4
    n_rows = (n_layers + n_cols - 1) // n_cols

    param_type = "muP" if config.use_mup else "SP"
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = np.array(axes).flatten()

    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))

    for i, layer in enumerate(layer_names):
        ax = axes[i]
        for s, step in enumerate(steps):
            values = [stats[w][layer][s] for w in widths]
            ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                    label=f'step {step}' if i == 0 else None)
        ax.set_xscale('log', base=2)
        ax.set_xticks(widths)
        ax.set_xticklabels(widths, fontsize=7)
        ax.set_title(layer, fontsize=9)
        ax.set_xlabel('Width')
        ax.set_ylabel('Mean |activation|')
        ax.grid(True, alpha=0.3)

    axes[0].legend(fontsize=7, loc='best')

    for i in range(n_layers, len(axes)):
        axes[i].set_visible(False)

    fig.suptitle(f'Coordinate Check ({param_type}): Activation Magnitude vs Width', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")

    plt.show()


def plot_loss_curves(results: Dict, config: CoordCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot loss curves across widths to verify HP transfer."""
    widths = results['widths']
    steps = results['steps']
    losses = results['losses']

    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

    for i, w in enumerate(widths):
        ax.plot(steps, losses[w], label=f'width={w}', color=colors[i], linewidth=2)

    ax.set_yscale('log')
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves Across Widths{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Add annotation for final loss spread
    final_losses = [losses[w][-1] for w in widths]
    spread = max(final_losses) - min(final_losses)
    ax.annotate(f'Final loss spread: {spread:.4f}', xy=(0.7, 0.95), xycoords='axes fraction', fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved loss curves to {save_path}")

    plt.show()


def plot_comparison(results_sp: Dict, results_mup: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot SP vs muP: one subplot per layer (left=SP, right=muP), x=width (log2), y=mean |activation|, lines=steps."""
    widths = results_sp['widths']
    steps = results_sp['steps']

    layer_names = list(results_sp['stats'][widths[0]].keys())
    n_layers = len(layer_names)

    # n_layers activation rows + 1 loss row, 2 cols (SP | muP)
    n_rows, n_cols = n_layers + 1, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows))

    step_colors = plt.cm.plasma(np.linspace(0, 1, len(steps)))
    width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

    for row, layer in enumerate(layer_names):
        # Shared y-axis range across SP and muP for this layer
        all_vals = [results_sp['stats'][w][layer][s] for w in widths for s in range(len(steps))] + \
                   [results_mup['stats'][w][layer][s] for w in widths for s in range(len(steps))]
        y_min, y_max = min(all_vals) * 0.9, max(all_vals) * 1.1

        for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
            ax = axes[row, col]
            for s, step in enumerate(steps):
                values = [results['stats'][w][layer][s] for w in widths]
                ax.plot(widths, values, 'o-', color=step_colors[s], linewidth=1.5,
                        label=f'step {step}' if (row == 0 and col == 0) else None)
            ax.set_xscale('log', base=2)
            ax.set_xticks(widths)
            ax.set_xticklabels(widths, fontsize=7)
            ax.set_ylim(y_min, y_max)
            ax.set_title(f'{label}: {layer}', fontsize=9)
            ax.set_xlabel('Width')
            ax.set_ylabel('Mean |activation|')
            ax.grid(True, alpha=0.3)

    axes[0, 0].legend(fontsize=7, loc='best')

    # Loss curves row (log scale so low-loss detail is visible)
    all_losses = [v for r in (results_sp, results_mup) for w in widths for v in r['losses'][w]]
    loss_min, loss_max = min(all_losses) * 0.9, max(all_losses) * 1.1

    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[n_layers, col]
        for j, w in enumerate(widths):
            ax.plot(steps, results['losses'][w], label=f'w={w}', color=width_colors[j], linewidth=2)
        ax.set_yscale('log')
        ax.set_ylim(loss_min, loss_max)
        ax.set_xlabel('Step')
        ax.set_ylabel('Loss')
        ax.set_title(f'{label}: Loss Curves')
        ax.legend(fontsize=7)
        ax.grid(True, alpha=0.3)
        final_losses = [results['losses'][w][-1] for w in widths]
        spread = max(final_losses) - min(final_losses)
        ax.annotate(f'Spread: {spread:.4f}', xy=(0.65, 0.95), xycoords='axes fraction', fontsize=9)

    fig.suptitle('Coordinate Check: SP vs muP', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")

    plt.show()


def plot_detailed(results: Dict, config: CoordCheckConfig, save_path: Optional[str] = None):
    """Plot detailed diagnostics: gradient norms, weight update norms, attention logits."""
    widths = results['widths']
    detailed = results['detailed_stats']
    if not detailed or not detailed[widths[0]]:
        print("No detailed stats recorded. Use --detailed flag.")
        return

    # Collect all detailed metric names
    metric_names = sorted(detailed[widths[0]].keys())

    # Group by category
    categories = defaultdict(list)
    for name in metric_names:
        if name.startswith('grad norm:'):
            categories['Gradient Norms'].append(name)
        elif name.startswith('update norm:'):
            categories['Weight Update Norms'].append(name)
        elif name.startswith('attn logits'):
            categories['Attention Logit Magnitudes'].append(name)

    for cat_name, names in categories.items():
        n = len(names)
        n_cols = min(4, n)
        n_rows = (n + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
        if n == 1:
            axes = np.array([axes])
        axes = np.array(axes).flatten()

        steps = results['steps']
        width_colors = plt.cm.viridis(np.linspace(0, 1, len(widths)))

        for i, name in enumerate(names):
            ax = axes[i]
            for j, w in enumerate(widths):
                values = detailed[w].get(name, [])
                if values:
                    ax.plot(range(len(values)), values, color=width_colors[j],
                            linewidth=1.5, label=f'w={w}' if i == 0 else None)
            ax.set_title(name.split(': ', 1)[-1] if ': ' in name else name, fontsize=8)
            ax.set_xlabel('Step')
            ax.set_ylabel('Norm')
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')

        for i in range(n, len(axes)):
            axes[i].set_visible(False)

        axes[0].legend(fontsize=7, loc='best')
        param_type = "muP" if config.use_mup else "SP"
        fig.suptitle(f'{cat_name} ({param_type})', fontsize=14)
        plt.tight_layout()

        if save_path:
            cat_slug = cat_name.lower().replace(' ', '_')
            path = save_path.replace('.png', f'_{cat_slug}.png')
            plt.savefig(path, dpi=150, bbox_inches='tight')
            print(f"Saved {cat_name} plot to {path}")

        plt.show()


def compute_width_dependence(results: Dict) -> Dict[str, float]:
    """Compute how much activations scale with width (slope on log-log plot)."""
    widths = np.array(results['widths'])
    log_widths = np.log2(widths)
    final_step = len(results['steps']) - 1

    slopes = {}
    for layer in results['stats'][widths[0]].keys():
        values = [results['stats'][w][layer][final_step] for w in widths]
        log_values = np.log2(np.array(values) + 1e-10)
        slope, _ = np.polyfit(log_widths, log_values, 1)
        slopes[layer] = slope

    return slopes


def main():
    parser = argparse.ArgumentParser(description='muP Coordinate Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    parser.add_argument('--steps', type=int, default=10,
                        help='Number of training steps')
    parser.add_argument('--batch-size', type=int, default=4,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--detailed', action='store_true',
                        help='Record detailed diagnostics: gradient norms, weight update norms, '
                             'attention logit magnitudes')
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard muP), '
                             '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers, '
                             'see Yang et al. Section C.1)')

    args = parser.parse_args()

    # Parse widths
    widths = [int(w) for w in args.widths.split(',')]

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load a single batch of real training data (reused every step)
    x, y, vocab_size = load_batch(args.batch_size, args.seq_len, device)

    # Create config
    config = CoordCheckConfig(
        widths=widths,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        detailed=args.detailed,
        muon_lr_exponent=args.muon_lr_exponent,
    )

    if args.compare:
        # Run both SP and muP
        print("\n" + "="*60)
        print("Running Standard Parameterization (SP)")
        print("="*60)
        config.use_mup = False
        results_sp = run_coord_check(config, device, x, y)

        print("\n" + "="*60)
        print("Running muP")
        if config.muon_lr_exponent != 1.0:
            print(f" (Muon LR exponent: {config.muon_lr_exponent})")
        print("="*60)
        config.use_mup = True
        results_mup = run_coord_check(config, device, x, y)

        # Compute slopes
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected: ~0 for width-independent, positive = grows with width")
        print("="*60)

        slopes_sp = compute_width_dependence(results_sp)
        slopes_mup = compute_width_dependence(results_mup)

        print(f"\n{'Layer':<20} {'SP Slope':>12} {'muP Slope':>12}")
        print("-"*46)
        for layer in slopes_sp:
            print(f"{layer:<20} {slopes_sp[layer]:>12.4f} {slopes_mup[layer]:>12.4f}")

        # Plot comparison
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'coord_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)

        # Plot detailed diagnostics if requested
        if config.detailed:
            for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
                config.use_mup = (label == 'muP')
                detail_save = None
                if args.save_dir:
                    detail_save = os.path.join(args.save_dir, f'detailed_{label.lower()}.png')
                plot_detailed(results, config, detail_save)

    else:
        # Run single mode
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Coordinate Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"Steps: {config.num_steps}")
        print(f"Base width: {config.base_width}")
        if config.use_mup and config.muon_lr_exponent != 1.0:
            print(f"Muon LR exponent: {config.muon_lr_exponent}")

        results = run_coord_check(config, device, x, y)

        # Compute slopes
        slopes = compute_width_dependence(results)
        print("\n" + "="*60)
        print("Width Dependence (slope on log-log plot)")
        print("Expected for muP: ~0 (width-independent)")
        print("="*60)
        for layer, slope in slopes.items():
            status = "OK" if abs(slope) < 0.1 else "WARN"
            print(f" {layer}: {slope:+.4f} [{status}]")

        # Loss curve analysis
        final_losses = [results['losses'][w][-1] for w in results['widths']]
        loss_spread = max(final_losses) - min(final_losses)
        print(f"\nFinal loss spread across widths: {loss_spread:.4f}")
        print(f"Expected for muP: low spread (similar losses across widths)")

        # Plot activations
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'coord_check_{param_type.lower()}.png')
        plot_coord_check(results, config, save_path)

        # Plot loss curves
        loss_save_path = None
        if args.save_dir:
            loss_save_path = os.path.join(args.save_dir, f'loss_curves_{param_type.lower()}.png')
        plot_loss_curves(results, config, title=param_type, save_path=loss_save_path)

        # Plot detailed diagnostics if requested
        if config.detailed:
            detail_save = None
            if args.save_dir:
                detail_save = os.path.join(args.save_dir, f'detailed_{param_type.lower()}.png')
            plot_detailed(results, config, detail_save)


if __name__ == '__main__':
    main()
669
scripts/mup_transfer_check.py
Normal file
@@ -0,0 +1,669 @@
"""
|
||||
muP Transfer Check for nanochat
|
||||
|
||||
Validates that optimal learning rates transfer across model widths under muP.
|
||||
For each width, sweeps over LR multipliers and records final loss. Under correct
|
||||
muP, the optimal LR multiplier should be ~1.0 at all widths (i.e., the same LR
|
||||
works everywhere). Under SP, the optimal LR typically shifts with width.
|
||||
|
||||
Reference: https://blog.eleuther.ai/mutransfer/
|
||||
Reference: Yang et al., "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot
|
||||
Hyperparameter Transfer" (arXiv:2203.03466), Section F.
|
||||
|
||||
Usage:
|
||||
# Quick check (~2 min on GPU)
|
||||
python -m scripts.mup_transfer_check
|
||||
|
||||
# Compare SP vs muP side-by-side
|
||||
python -m scripts.mup_transfer_check --compare
|
||||
|
||||
# Wide LR sweep (paper-style, ~1000x range)
|
||||
python -m scripts.mup_transfer_check --compare --widths 128,256,512,1024 --steps 200
|
||||
|
||||
# Random log-uniform LR trials (paper-style methodology)
|
||||
python -m scripts.mup_transfer_check --compare --num-random-trials 20
|
||||
|
||||
# Multi-HP sweep (init scale + output multiplier)
|
||||
python -m scripts.mup_transfer_check --compare --sweep-init-scale --sweep-output-mult
|
||||
|
||||
# Save plots
|
||||
python -m scripts.mup_transfer_check --compare --save-dir plots/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["NANOCHAT_DTYPE"] = "float32"
|
||||
import torch
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.disable = True
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional

from nanochat.gpt import GPT, GPTConfig


@dataclass
class TransferCheckConfig:
    widths: List[int]
    lr_multipliers: List[float]
    num_steps: int = 200
    batch_size: int = 8
    seq_len: int = 128
    vocab_size: int = 32768
    n_layer: int = 2
    seed: int = 42
    use_mup: bool = False
    base_width: int = 128
    # Base learning rates (tuned at base_width=128)
    matrix_lr: float = 0.12
    embedding_lr: float = 6.0
    unembedding_lr: float = 0.12
    # Multi-HP sweeps
    sweep_init_scale: bool = False
    sweep_output_mult: bool = False
    # Data diversity
    num_batches: int = 1
    # Muon LR exponent for muP (1.0=standard, 0.5=Frobenius-norm optimizers)
    muon_lr_exponent: float = 0.0
    # Sweep mode: which optimizer groups the LR multiplier applies to
    # "all" = multiply all LRs (default), "muon-only" = only matrix_lr,
    # "adamw-only" = only embedding_lr/unembedding_lr
    sweep_mode: str = "all"


def load_batches(num_batches: int, batch_size: int, seq_len: int, device: torch.device):
    """Load multiple batches from the nanochat training pipeline.
    Falls back to random data if the tokenizer/dataset isn't available."""
    try:
        from nanochat.tokenizer import get_tokenizer
        from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
        tokenizer = get_tokenizer()
        vocab_size = tokenizer.get_vocab_size()
        loader = tokenizing_distributed_data_loader_bos_bestfit(
            tokenizer, batch_size, seq_len, split="train", device=device,
        )
        batches = []
        for i, (x, y) in enumerate(loader):
            batches.append((x, y))
            if len(batches) >= num_batches:
                break
        print(f"Loaded {len(batches)} real training batch(es) (vocab_size={vocab_size})")
        return batches, vocab_size
    except Exception as e:
        print(f"Could not load training data ({e}), using random tokens")
        vocab_size = 32768
        batches = []
        for i in range(num_batches):
            rng = torch.Generator(device=device)
            rng.manual_seed(42 + i)
            x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, generator=rng)
            y = torch.roll(x, -1, dims=1)
            y[:, -1] = -1
            batches.append((x, y))
        return batches, vocab_size


def create_model(width: int, config: TransferCheckConfig, device: torch.device,
                 mup_base_width: int = 0, init_scale: float = 1.0,
                 output_mult: float = 1.0):
    """Create a model with the specified width and optional HP overrides."""
    head_dim = 64
    n_head = max(1, width // head_dim)
    actual_width = n_head * head_dim

    gpt_config = GPTConfig(
        sequence_len=config.seq_len,
        vocab_size=config.vocab_size,
        n_layer=config.n_layer,
        n_head=n_head,
        n_kv_head=n_head,
        n_embd=actual_width,
        window_pattern="L",
        mup_base_width=mup_base_width,
    )

    with torch.device('meta'):
        model = GPT(gpt_config)
    model.to_empty(device=device)
    model.init_weights()

    # Apply init_scale: multiply all parameter inits by scalar
    if init_scale != 1.0:
        with torch.no_grad():
            for p in model.parameters():
                p.mul_(init_scale)

    return model, gpt_config


def train_model(width: int, lr_mult: float, config: TransferCheckConfig,
                device: torch.device, batches: List,
                init_scale: float = 1.0, output_mult: float = 1.0):
    """Train a model at given width and LR multiplier, return loss history."""
    torch.manual_seed(config.seed)

    mup_base_width = config.base_width if config.use_mup else 0
    model, gpt_config = create_model(width, config, device, mup_base_width=mup_base_width,
                                     init_scale=init_scale, output_mult=output_mult)
    actual_width = gpt_config.n_embd

    # Scale the learning rates by the multiplier, respecting sweep_mode
    if config.sweep_mode == "muon-only":
        muon_mult, adamw_mult = lr_mult, 1.0
    elif config.sweep_mode == "adamw-only":
        muon_mult, adamw_mult = 1.0, lr_mult
    else:  # "all"
        muon_mult, adamw_mult = lr_mult, lr_mult

    optimizer = model.setup_optimizer(
        unembedding_lr=config.unembedding_lr * adamw_mult,
        embedding_lr=config.embedding_lr * adamw_mult,
        matrix_lr=config.matrix_lr * muon_mult,
        weight_decay=0.0,
        use_mup=config.use_mup,
        base_width=config.base_width,
        muon_lr_exponent=config.muon_lr_exponent,
    )

    model.train()
    losses = []
    num_batches = len(batches)

    for step in range(config.num_steps):
        x, y = batches[step % num_batches]
        with torch.amp.autocast(device_type='cuda', dtype=torch.float32, enabled=False):
            loss = model(x, y)

        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    del model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return losses, actual_width


def run_transfer_check(config: TransferCheckConfig, device: torch.device,
                       batches: List) -> Dict:
    """Run LR sweep across all widths."""
    results = {
        'widths': [],
        'lr_multipliers': config.lr_multipliers,
        'losses': {},  # losses[(width, lr_mult)] = [loss_step0, ...]
        'final_losses': defaultdict(dict),  # final_losses[width][lr_mult] = final_loss
    }

    for width in config.widths:
        actual_width = None
        for lr_mult in config.lr_multipliers:
            print(f" width={width}, lr_mult={lr_mult:.4f}...", end=" ", flush=True)

            losses, actual_width = train_model(width, lr_mult, config, device, batches)
            results['losses'][(actual_width, lr_mult)] = losses
            results['final_losses'][actual_width][lr_mult] = losses[-1]
            print(f"final_loss={losses[-1]:.4f}")

        if actual_width not in results['widths']:
            results['widths'].append(actual_width)

    return results
|
||||
|
||||
def run_hp_sweep(config: TransferCheckConfig, device: torch.device,
|
||||
batches: List, hp_name: str, hp_values: List[float]) -> Dict:
|
||||
"""Run a sweep over a single HP (init_scale or output_mult) at fixed LR."""
|
||||
results = {
|
||||
'widths': [],
|
||||
'hp_values': hp_values,
|
||||
'hp_name': hp_name,
|
||||
'final_losses': defaultdict(dict),
|
||||
}
|
||||
|
||||
for width in config.widths:
|
||||
actual_width = None
|
||||
for hp_val in hp_values:
|
||||
init_scale = hp_val if hp_name == 'init_scale' else 1.0
|
||||
output_mult = hp_val if hp_name == 'output_mult' else 1.0
|
||||
print(f" width={width}, {hp_name}={hp_val:.4f}...", end=" ", flush=True)
|
||||
|
||||
losses, actual_width = train_model(
|
||||
width, 1.0, config, device, batches,
|
||||
init_scale=init_scale, output_mult=output_mult)
|
||||
results['final_losses'][actual_width][hp_val] = losses[-1]
|
||||
print(f"final_loss={losses[-1]:.4f}")
|
||||
|
||||
if actual_width not in results['widths']:
|
||||
results['widths'].append(actual_width)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def find_optimal_lr(final_losses: Dict[float, float]) -> float:
|
||||
"""Find the LR multiplier with the lowest final loss."""
|
||||
return min(final_losses, key=final_losses.get)
|
||||
|
||||
|
||||
def plot_lr_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
|
||||
"""Plot LR sweep: final loss vs LR multiplier for each width."""
|
||||
widths = results['widths']
|
||||
lr_mults = results['lr_multipliers']
|
||||
final_losses = results['final_losses']
|
||||
|
||||
n_cols = 2
|
||||
fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))
|
||||
colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
|
||||
|
||||
# Left: final loss vs LR multiplier
|
||||
ax = axes[0]
|
||||
for i, w in enumerate(widths):
|
||||
losses = [final_losses[w][m] for m in lr_mults]
|
||||
ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'width={w}')
|
||||
opt_mult = find_optimal_lr(final_losses[w])
|
||||
opt_loss = final_losses[w][opt_mult]
|
||||
ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
|
||||
|
||||
ax.set_xscale('log', base=2)
|
||||
ax.set_yscale('log')
|
||||
ax.set_xlabel('LR Multiplier')
|
||||
ax.set_ylabel('Final Loss')
|
||||
ax.set_title(f'LR Sweep{" - " + title if title else ""}')
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Right: optimal LR multiplier vs width
|
||||
ax = axes[1]
|
||||
opt_mults = [find_optimal_lr(final_losses[w]) for w in widths]
|
||||
ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color='tab:blue')
|
||||
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
|
||||
ax.set_xscale('log', base=2)
|
||||
ax.set_yscale('log', base=2)
|
||||
ax.set_xticks(widths)
|
||||
ax.set_xticklabels(widths)
|
||||
ax.set_xlabel('Width')
|
||||
ax.set_ylabel('Optimal LR Multiplier')
|
||||
ax.set_title(f'Optimal LR vs Width{" - " + title if title else ""}')
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved plot to {save_path}")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_comparison(results_sp: Dict, results_mup: Dict, config: TransferCheckConfig, save_path: Optional[str] = None):
|
||||
"""Plot SP vs muP comparison side by side."""
|
||||
widths = results_sp['widths']
|
||||
lr_mults = results_sp['lr_multipliers']
|
||||
|
||||
n_rows, n_cols = 2, 2
|
||||
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
|
||||
colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))
|
||||
|
||||
# Top row: LR sweep curves (log scale for loss detail)
|
||||
    for col, (results, label) in enumerate([(results_sp, 'SP'), (results_mup, 'muP')]):
        ax = axes[0, col]
        for i, w in enumerate(widths):
            losses = [results['final_losses'][w][m] for m in lr_mults]
            ax.plot(lr_mults, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
            opt_mult = find_optimal_lr(results['final_losses'][w])
            opt_loss = results['final_losses'][w][opt_mult]
            ax.plot(opt_mult, opt_loss, '*', color=colors[i], markersize=15, zorder=5)
        ax.set_xscale('log', base=2)
        ax.set_yscale('log')
        ax.set_xlabel('LR Multiplier')
        ax.set_ylabel('Final Loss')
        ax.set_title(f'{label}: Final Loss vs LR Multiplier')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Shared y-axis for top row
    all_losses_flat = [results_sp['final_losses'][w][m] for m in lr_mults for w in widths] + \
                      [results_mup['final_losses'][w][m] for m in lr_mults for w in widths]
    y_min, y_max = min(all_losses_flat) * 0.9, max(all_losses_flat) * 1.1
    axes[0, 0].set_ylim(y_min, y_max)
    axes[0, 1].set_ylim(y_min, y_max)

    # Bottom left: optimal LR vs width for both
    ax = axes[1, 0]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_mults = [find_optimal_lr(results['final_losses'][w]) for w in widths]
        ax.plot(widths, opt_mults, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Optimal LR Multiplier')
    ax.set_title('Optimal LR Multiplier vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Bottom right: loss at optimal LR vs width
    ax = axes[1, 1]
    for results, label, color in [(results_sp, 'SP', 'tab:red'), (results_mup, 'muP', 'tab:blue')]:
        opt_losses = [results['final_losses'][w][find_optimal_lr(results['final_losses'][w])] for w in widths]
        ax.plot(widths, opt_losses, 'o-', linewidth=2, markersize=8, color=color, label=label)
    ax.set_xscale('log', base=2)
    ax.set_xticks(widths)
    ax.set_xticklabels(widths)
    ax.set_xlabel('Width')
    ax.set_ylabel('Best Final Loss')
    ax.set_title('Best Achievable Loss vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.suptitle('muP Transfer Check: SP vs muP', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")
    plt.show()


def plot_hp_sweep(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot HP sweep: final loss vs HP value for each width."""
    widths = results['widths']
    hp_values = results['hp_values']
    hp_name = results['hp_name']
    final_losses = results['final_losses']

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    ax = axes[0]
    for i, w in enumerate(widths):
        losses = [final_losses[w][v] for v in hp_values]
        ax.plot(hp_values, losses, 'o-', color=colors[i], linewidth=2, label=f'w={w}')
        opt_v = min(final_losses[w], key=final_losses[w].get)
        ax.plot(opt_v, final_losses[w][opt_v], '*', color=colors[i], markersize=15, zorder=5)
    ax.set_xscale('log', base=2)
    ax.set_yscale('log')
    ax.set_xlabel(hp_name)
    ax.set_ylabel('Final Loss')
    ax.set_title(f'{hp_name} Sweep{" - " + title if title else ""}')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    ax = axes[1]
    opt_vals = [min(final_losses[w], key=final_losses[w].get) for w in widths]
    ax.plot(widths, opt_vals, 'o-', linewidth=2, markersize=8, color='tab:blue')
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='target (1.0)')
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=2)
    ax.set_xlabel('Width')
    ax.set_ylabel(f'Optimal {hp_name}')
    ax.set_title(f'Optimal {hp_name} vs Width')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()


def plot_loss_curves_at_optimal(results: Dict, config: TransferCheckConfig, title: str = "", save_path: Optional[str] = None):
    """Plot full loss curves at the optimal LR for each width."""
    widths = results['widths']

    fig, ax = plt.subplots(figsize=(5 * 2, 4))
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(widths)))

    for i, w in enumerate(widths):
        opt_mult = find_optimal_lr(results['final_losses'][w])
        losses = results['losses'][(w, opt_mult)]
        ax.plot(losses, color=colors[i], linewidth=2, label=f'w={w} (lr_mult={opt_mult})')

    ax.set_yscale('log')
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves at Optimal LR{" - " + title if title else ""}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.show()


def print_summary(results: Dict, label: str):
    """Print a summary table of the LR sweep results."""
    widths = results['widths']
    lr_mults = results['lr_multipliers']
    final_losses = results['final_losses']

    print(f"\n{'='*70}")
    print(f" {label}: LR Sweep Results")
    print(f"{'='*70}")

    # Header
    header = f"{'Width':>8}"
    for m in lr_mults:
        header += f" | {m:>7.3f}"
    header += f" | {'Best':>7} | {'Opt LR':>7}"
    print(header)
    print("-" * len(header))

    opt_mults = []
    for w in widths:
        row = f"{w:>8}"
        for m in lr_mults:
            loss = final_losses[w][m]
            row += f" | {loss:>7.4f}"
        opt_m = find_optimal_lr(final_losses[w])
        opt_mults.append(opt_m)
        opt_loss = final_losses[w][opt_m]
        row += f" | {opt_loss:>7.4f} | {opt_m:>7.3f}"
        print(row)

    # Transfer quality metric: how much does the optimal LR shift?
    opt_mults_arr = np.array(opt_mults)
    log_opt = np.log2(opt_mults_arr)
    spread = log_opt.max() - log_opt.min()  # spread in log2 space
    print(f"\nOptimal LR spread (log2): {spread:.3f}")
    print(" (0 = perfect transfer, >1 = poor transfer)")


def main():
    parser = argparse.ArgumentParser(description='muP Transfer Check')
    parser.add_argument('--widths', type=str, default='128,256,512,1024',
                        help='Comma-separated list of widths to test')
    # Default grid: 256x range, 17 points at ~sqrt(2) spacing.
    # For the paper-style methodology (~1000x range), see --num-random-trials below.
    parser.add_argument('--lr-mults', type=str,
                        default='0.03125,0.044,0.0625,0.088,0.125,0.177,0.25,0.354,0.5,0.707,1.0,1.414,2.0,2.828,4.0,5.657,8.0',
                        help='Comma-separated LR multipliers to sweep (default: 256x range, 17 points, ~sqrt(2) spacing)')
    parser.add_argument('--num-random-trials', type=int, default=0,
                        help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) '
                             'instead of the grid. Paper-style methodology (Section F).')
    parser.add_argument('--steps', type=int, default=200,
                        help='Number of training steps per run')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--seq-len', type=int, default=128,
                        help='Sequence length')
    parser.add_argument('--n-layer', type=int, default=2,
                        help='Number of transformer layers')
    parser.add_argument('--use-mup', action='store_true',
                        help='Use muP learning rate scaling')
    parser.add_argument('--base-width', type=int, default=128,
                        help='Base width for muP scaling')
    parser.add_argument('--compare', action='store_true',
                        help='Run both SP and muP and compare')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save plots')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--num-batches', type=int, default=1,
                        help='Number of data batches to cycle through (default 1 for backward compat, '
                             'recommend 8 for thorough checks)')
    # Multi-HP sweeps
    parser.add_argument('--sweep-init-scale', action='store_true',
                        help='Also sweep init scale multiplier (sampled from 10^Uniform(-1,1))')
    parser.add_argument('--sweep-output-mult', action='store_true',
                        help='Also sweep output logit multiplier (sampled from 4^Uniform(-1,1))')
    parser.add_argument('--muon-lr-exponent', type=float, default=0.0,
                        help='Muon LR exponent for muP: 1.0 = (base/width)^1 (standard), '
                             '0.5 = (base/width)^0.5 (for Frobenius-normalizing optimizers like Muon)')
    parser.add_argument('--sweep-mode', type=str, default='all',
                        choices=['all', 'muon-only', 'adamw-only'],
                        help='Which optimizer groups the LR multiplier applies to: '
                             '"all" = all LRs (default), "muon-only" = only Muon/matrix LR, '
                             '"adamw-only" = only AdamW/embedding/output LR')

    args = parser.parse_args()

    widths = [int(w) for w in args.widths.split(',')]

    # Generate LR multipliers
    if args.num_random_trials > 0:
        # Log-uniform random sampling: 10^Uniform(-1.5, 1.5)
        rng = np.random.RandomState(args.seed)
        lr_mults = sorted(10 ** rng.uniform(-1.5, 1.5, args.num_random_trials))
        lr_mults = [round(float(m), 6) for m in lr_mults]
        print(f"Using {args.num_random_trials} random log-uniform LR multipliers: "
              f"[{lr_mults[0]:.4f}, ..., {lr_mults[-1]:.4f}]")
    else:
        lr_mults = sorted(float(m) for m in args.lr_mults.split(','))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load batches
    batches, vocab_size = load_batches(args.num_batches, args.batch_size, args.seq_len, device)

    config = TransferCheckConfig(
        widths=widths,
        lr_multipliers=lr_mults,
        num_steps=args.steps,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        seed=args.seed,
        use_mup=args.use_mup,
        base_width=args.base_width,
        sweep_init_scale=args.sweep_init_scale,
        sweep_output_mult=args.sweep_output_mult,
        num_batches=args.num_batches,
        muon_lr_exponent=args.muon_lr_exponent,
        sweep_mode=args.sweep_mode,
    )

    if args.compare:
        # Run SP
        print("\n" + "=" * 60)
        print("Running Standard Parameterization (SP)")
        print("=" * 60)
        config.use_mup = False
        results_sp = run_transfer_check(config, device, batches)
        print_summary(results_sp, "SP")

        # Run muP
        print("\n" + "=" * 60)
        print("Running muP")
        print("=" * 60)
        config.use_mup = True
        results_mup = run_transfer_check(config, device, batches)
        print_summary(results_mup, "muP")

        # Compare
        print("\n" + "=" * 60)
        print("COMPARISON")
        print("=" * 60)
        sp_opts = [find_optimal_lr(results_sp['final_losses'][w]) for w in results_sp['widths']]
        mup_opts = [find_optimal_lr(results_mup['final_losses'][w]) for w in results_mup['widths']]
        sp_spread = np.log2(max(sp_opts)) - np.log2(min(sp_opts))
        mup_spread = np.log2(max(mup_opts)) - np.log2(min(mup_opts))
        print(f"SP optimal LR spread (log2): {sp_spread:.3f}")
        print(f"muP optimal LR spread (log2): {mup_spread:.3f}")
        if mup_spread < sp_spread:
            print(f"muP shows {sp_spread/max(mup_spread, 0.001):.1f}x better LR transfer!")
        else:
            print("muP does NOT show better LR transfer (check implementation)")

        # Plot
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, 'transfer_check_comparison.png')
        plot_comparison(results_sp, results_mup, config, save_path)

        # Also plot loss curves at optimal LR
        for results, label in [(results_sp, 'SP'), (results_mup, 'muP')]:
            lc_save = None
            if args.save_dir:
                lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{label.lower()}.png')
            plot_loss_curves_at_optimal(results, config, title=label, save_path=lc_save)

        # Multi-HP sweeps (only for muP, to demonstrate transfer)
        if args.sweep_init_scale or args.sweep_output_mult:
            config.use_mup = True

        if args.sweep_init_scale:
            print("\n" + "=" * 60)
            print("muP: Init Scale Sweep")
            print("=" * 60)
            # 10^Uniform(-1, 1) => range [0.1, 10]
            init_scales = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
            init_results = run_hp_sweep(config, device, batches, 'init_scale', init_scales)
            save_hp = None
            if args.save_dir:
                save_hp = os.path.join(args.save_dir, 'init_scale_sweep.png')
            plot_hp_sweep(init_results, config, title="muP", save_path=save_hp)

        if args.sweep_output_mult:
            print("\n" + "=" * 60)
            print("muP: Output Multiplier Sweep")
            print("=" * 60)
            # 4^Uniform(-1, 1) => range [0.25, 4]
            output_mults = [0.25, 0.5, 1.0, 2.0, 4.0]
            output_results = run_hp_sweep(config, device, batches, 'output_mult', output_mults)
            save_hp = None
            if args.save_dir:
                save_hp = os.path.join(args.save_dir, 'output_mult_sweep.png')
            plot_hp_sweep(output_results, config, title="muP", save_path=save_hp)

    else:
        param_type = "muP" if config.use_mup else "SP"
        print(f"\n{'='*60}")
        print(f"Running Transfer Check ({param_type})")
        print(f"{'='*60}")
        print(f"Widths: {widths}")
        print(f"LR multipliers: {lr_mults}")
        print(f"Steps: {config.num_steps}")
        if config.sweep_mode != "all":
            print(f"Sweep mode: {config.sweep_mode}")

        results = run_transfer_check(config, device, batches)
        print_summary(results, param_type)

        # Plot LR sweep
        save_path = None
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f'transfer_check_{param_type.lower()}.png')
        plot_lr_sweep(results, config, title=param_type, save_path=save_path)

        # Plot loss curves at optimal LR
        lc_save = None
        if args.save_dir:
            lc_save = os.path.join(args.save_dir, f'optimal_loss_curves_{param_type.lower()}.png')
        plot_loss_curves_at_optimal(results, config, title=param_type, save_path=lc_save)


if __name__ == '__main__':
    main()
196
tests/test_mup.py
Normal file

@@ -0,0 +1,196 @@
"""
Test muP implementation: coordinate check and transfer check.

Verifies that:
1. Activation magnitudes are width-independent under muP (coord check)
2. Optimal learning rates transfer across widths under muP (transfer check)

python -m pytest tests/test_mup.py -v
"""

import pytest
import torch
import torch._dynamo
torch._dynamo.config.disable = True
import numpy as np
from collections import defaultdict

from nanochat.gpt import GPT, GPTConfig


def create_model(width, seq_len=64, n_layer=2, vocab_size=256, mup_base_width=0):
    """Create a small model at the given width."""
    head_dim = 64
    n_head = max(1, width // head_dim)
    actual_width = n_head * head_dim
    config = GPTConfig(
        sequence_len=seq_len, vocab_size=vocab_size,
        n_layer=n_layer, n_head=n_head, n_kv_head=n_head,
        n_embd=actual_width, window_pattern="L",
        mup_base_width=mup_base_width,
    )
    with torch.device('meta'):
        model = GPT(config)
    model.to_empty(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    model.init_weights()
    return model, config


def get_activation_stats(model, x, y):
    """Run one forward pass and return mean |activation| for key layers."""
    stats = {}

    def make_hook(name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                output = output[0]
            if output is not None and isinstance(output, torch.Tensor):
                stats[name] = output.float().abs().mean().item()
        return hook

    hooks = []
    hooks.append(model.transformer.wte.register_forward_hook(make_hook('embedding')))
    for i, block in enumerate(model.transformer.h):
        hooks.append(block.attn.c_proj.register_forward_hook(make_hook(f'attn.{i}')))
        hooks.append(block.mlp.c_proj.register_forward_hook(make_hook(f'ffn.{i}')))

    # Output logits with muP scaling applied
    mup_base = model.config.mup_base_width
    n_embd = model.config.n_embd

    def logit_hook(module, input, output):
        if output is not None and isinstance(output, torch.Tensor):
            scaled = output * (mup_base / n_embd) if mup_base > 0 else output
            stats['logits'] = scaled.float().abs().mean().item()
    hooks.append(model.lm_head.register_forward_hook(logit_hook))

    model.eval()
    with torch.no_grad():
        model(x, y)

    for h in hooks:
        h.remove()
    return stats


@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
class TestMuPCoordCheck:
    """Test that muP activations are width-independent."""

    WIDTHS = [128, 256, 512]
    BASE_WIDTH = 128

    def _compute_slopes(self, use_mup):
        """Run coord check and return log-log slopes for each layer."""
        device = torch.device('cuda')
        seq_len, vocab_size = 64, 256
        torch.manual_seed(42)
        x = torch.randint(0, vocab_size, (4, seq_len), device=device)
        y = torch.roll(x, -1, dims=1)

        all_stats = {}
        for width in self.WIDTHS:
            torch.manual_seed(42)
            mup_base = self.BASE_WIDTH if use_mup else 0
            model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
            all_stats[width] = get_activation_stats(model, x, y)
            del model
            torch.cuda.empty_cache()

        # Compute slopes on log-log plot
        log_widths = np.log2(np.array(self.WIDTHS, dtype=float))
        slopes = {}
        for layer in all_stats[self.WIDTHS[0]]:
            values = [all_stats[w][layer] for w in self.WIDTHS]
            log_values = np.log2(np.array(values) + 1e-10)
            slope, _ = np.polyfit(log_widths, log_values, 1)
            slopes[layer] = slope

        return slopes

    def test_mup_activations_width_independent(self):
        """Under muP, internal activation slopes should be near zero.

        Note: output logits are excluded because at init (no training steps),
        muP logits are expected to decrease with width — the init scaling
        (std * sqrt(base/width)) combined with forward scaling (logits * base/width)
        gives O(1/width) initial logits. muP preserves update magnitude, not init magnitude.
        """
        slopes = self._compute_slopes(use_mup=True)
        for layer, slope in slopes.items():
            if layer == 'logits':
                continue  # skip — output logits have expected width dependence at init
            assert abs(slope) < 0.2, \
                f"muP activation slope for '{layer}' is {slope:.4f}, expected near 0"

    def test_sp_activations_width_dependent(self):
        """Under SP, at least some activation slopes should be nonzero (sanity check)."""
        slopes = self._compute_slopes(use_mup=False)
        max_slope = max(abs(s) for s in slopes.values())
        assert max_slope > 0.1, \
            f"SP max slope is only {max_slope:.4f}, expected SP to show width dependence"


@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
class TestMuPTransfer:
    """Test that optimal LR transfers across widths under muP."""

    WIDTHS = [128, 256]
    BASE_WIDTH = 128
    LR_MULTS = [0.25, 1.0, 4.0, 16.0]
    NUM_STEPS = 50

    def _run_sweep(self, use_mup):
        """Sweep LR multipliers across widths, return optimal mult per width."""
        device = torch.device('cuda')
        seq_len, vocab_size = 64, 256
        matrix_lr = 0.12
        embedding_lr = 6.0
        unembedding_lr = 0.12

        torch.manual_seed(42)
        x = torch.randint(0, vocab_size, (8, seq_len), device=device)
        y = torch.roll(x, -1, dims=1)

        optimal_mults = {}
        for width in self.WIDTHS:
            best_loss, best_mult = float('inf'), None
            for lr_mult in self.LR_MULTS:
                torch.manual_seed(42)
                mup_base = self.BASE_WIDTH if use_mup else 0
                model, _ = create_model(width, seq_len, vocab_size=vocab_size, mup_base_width=mup_base)
                optimizer = model.setup_optimizer(
                    matrix_lr=matrix_lr * lr_mult,
                    embedding_lr=embedding_lr * lr_mult,
                    unembedding_lr=unembedding_lr * lr_mult,
                    weight_decay=0.0,
                    use_mup=use_mup,
                    base_width=self.BASE_WIDTH,
                )
                model.train()
                for _ in range(self.NUM_STEPS):
                    loss = model(x, y)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                final_loss = loss.item()
                if final_loss < best_loss:
                    best_loss = final_loss
                    best_mult = lr_mult

                del model, optimizer
                torch.cuda.empty_cache()

            optimal_mults[width] = best_mult

        return optimal_mults

    def test_mup_lr_transfer(self):
        """Under muP, optimal LR multiplier should be similar across widths."""
        optimal = self._run_sweep(use_mup=True)
        mults = list(optimal.values())
        spread = np.log2(max(mults)) - np.log2(min(mults))
        assert spread <= 2.0, \
            f"muP LR spread is {spread:.1f} log2 (optimal mults: {optimal}), expected <= 2.0"