diff --git a/muP_changes.md b/muP_changes.md
index c886dd7..b76ac18 100644
--- a/muP_changes.md
+++ b/muP_changes.md
@@ -99,11 +99,13 @@ Exponents 0.0 and 1.0 give **identical spread** (2.0). The Muon LR exponent lite
 |-----------|-------|-------------------|--------|
 | Output logit scaling | `logits *= base/width` | Required | ✅ Correct |
 | Embedding LR | No width scaling | Constant with width | ✅ Correct |
-| lm_head init std | `0.001 × √(base/width)` | Width-scaled init | ✅ Correct |
+| lm_head init std | `0.02` (flat, no width scaling) | See note below | ✅ Correct |
 | Weight decay | Not width-scaled | Constant with width | ✅ Correct |
 | Momentum (Adam β₁, Muon) | Not width-scaled | Constant with width | ✅ Correct |
 | c_proj init | Non-zero uniform, std=√(3/n_embd) | Paper recommends zero | ⚠️ Intentional divergence |
 
+**On lm_head init**: The paper prescribes width-scaled init (`std ∝ 1/√width`) to keep initial logit magnitudes O(1). We previously used `0.001 × √(base/width)`. However, the forward-pass logit scaling (`logits *= base/width`) already suppresses logit magnitudes at large widths. The width-scaled init was double-compensating — initial logits were O(base/width) instead of O(1), making the lm_head start too quiet at large widths. We now use a flat `std = 0.02`, which, combined with the forward-pass scaling, produces well-behaved initial logits at all widths.
+
 **On c_proj init**: The paper recommends zero-initializing output projections (attn c_proj, MLP c_proj) for cleaner transfer. nanochat uses non-zero init because zero init causes vanishing attention/FFN outputs when combined with Muon's LR dynamics — the first Muon update from a zero matrix produces an orthogonal matrix with O(LR) norm, which is too small when LR is already small. This is a known interaction between Muon and residual-stream architectures; the non-zero init provides a stable starting point.
 
 ## Summary: muP for Muon+AdamW
@@ -113,13 +115,13 @@ For a mixed Muon+AdamW optimizer, muP simplifies dramatically:
 | Parameter group | muP prescription | Reason |
 |----------------|-----------------|--------|
 | **Output logits** | `logits *= base/width` in forward | The essential ingredient — makes loss landscape shape-invariant |
-| **lm_head init** | `std *= √(base/width)` | Keeps initial logit magnitudes O(1) |
+| **lm_head init** | `std = 0.02` (flat, no width scaling) | Forward-pass logit scaling already handles width independence; width-scaled init double-compensates |
 | **lm_head LR** | No width scaling | Logit scaling already propagates into gradient; Adam normalizes; additional LR scaling over-reduces |
 | **Muon (hidden) LR** | No width scaling | Polar Express makes `||update||_F ≈ 1` regardless of width |
 | **Embedding LR** | No width scaling | Standard muP (embeddings are lookup tables, not matrix multiplies) |
 | **Scalar LR** | No width scaling | Standard muP |
 
-**The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass (plus corresponding init adjustment). No LR scaling is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
+**The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass. No LR scaling or width-dependent init is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
 
 ## Verification
 
@@ -139,9 +141,9 @@ python -m pytest tests/test_mup.py -v
 
 | File | Changes |
 |------|---------|
-| `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); updated comments |
-| `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent` |
-| `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100→200 |
+| `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); lm_head init: `0.001 × √(base/width)` → flat `0.02` for muP (removed double-compensation); updated comments |
+| `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent`; switched to float32 (disabled bfloat16 autocast) for numerical precision |
+| `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100→200; switched to float32 for numerical precision |
 
 ## References
 
diff --git a/scripts/mup_transfer_check.py b/scripts/mup_transfer_check.py
index 4ea9d58..59246ac 100644
--- a/scripts/mup_transfer_check.py
+++ b/scripts/mup_transfer_check.py
@@ -484,8 +484,8 @@ def main():
                         help='Comma-separated list of widths to test')
     # Paper-style default: ~1000x range, 11 log-spaced points
     parser.add_argument('--lr-mults', type=str,
-                        default='0.03125,0.0625,0.125,0.25,0.5,1.0,2.0,4.0,8.0,16.0,32.0',
-                        help='Comma-separated LR multipliers to sweep (default: 1024x range, 11 points)')
+                        default='0.03125,0.044,0.0625,0.088,0.125,0.177,0.25,0.354,0.5,0.707,1.0,1.414,2.0,2.828,4.0,5.657,8.0',
+                        help='Comma-separated LR multipliers to sweep (default: 256x range, 17 points, ~sqrt(2) spacing)')
     parser.add_argument('--num-random-trials', type=int, default=0,
                         help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) '
                              'instead of the grid. Paper-style methodology (Section F).')
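
To make the "On lm_head init" note in the `muP_changes.md` hunk above concrete, here is a minimal standalone sketch — illustrative only, not nanochat code. The sizes (`base`, `vocab`, `batch`) are assumed toy values, and the lm_head input is idealized as unit-variance noise standing in for the normalized residual stream. It compares the initial logit RMS under the removed width-scaled init and the new flat `std = 0.02`, with the forward-pass `logits *= base/width` scaling applied in both cases; the old scheme's logits shrink roughly like `base/width` as width grows, which is the double-compensation the note describes.

```python
import torch

torch.manual_seed(0)
base, vocab, batch = 256, 1000, 64  # assumed toy sizes, not nanochat defaults

for width in (256, 512, 1024, 2048):
    # Idealized lm_head input: unit-variance activations (stand-in for the
    # normalized residual stream).
    x = torch.randn(batch, width)

    # Removed scheme: width-scaled init, std = 0.001 * sqrt(base/width).
    w_old = torch.randn(vocab, width) * 0.001 * (base / width) ** 0.5
    # New scheme: flat init, no width scaling.
    w_new = torch.randn(vocab, width) * 0.02

    scale = base / width  # muP forward-pass logit scaling
    rms_old = ((x @ w_old.T) * scale).pow(2).mean().sqrt().item()
    rms_new = ((x @ w_new.T) * scale).pow(2).mean().sqrt().item()
    print(f"width={width:4d}  old-init logit rms={rms_old:.4f}  new-init logit rms={rms_new:.4f}")
```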
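
For the `--lr-mults` change in `scripts/mup_transfer_check.py`, the new default is a roughly √2-spaced geometric grid. The snippet below is an illustrative NumPy sketch (not code from the script) showing how such a grid is constructed, plus the paper-style log-uniform random multipliers that `--num-random-trials` refers to.

```python
import numpy as np

# ~sqrt(2)-spaced geometric grid: 17 points from 0.03125 to 8.0 (256x range),
# matching the new --lr-mults default up to rounding (0.044, 0.177, 0.354, ...).
grid = 0.03125 * np.sqrt(2.0) ** np.arange(17)
print(", ".join(f"{m:.3f}" for m in grid))

# Paper-style alternative (--num-random-trials): log-uniform multipliers
# drawn from 10^Uniform(-1.5, 1.5).
rng = np.random.default_rng(0)
random_mults = np.sort(10.0 ** rng.uniform(-1.5, 1.5, size=8))
print(random_mults.round(3))
```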