mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
changing mu_transfer check to float32 and updating .md
This commit is contained in:
parent
075e3bb476
commit
5c92dd02cb
@@ -99,11 +99,13 @@ Exponents 0.0 and 1.0 give **identical spread** (2.0). The Muon LR exponent lite
|-----------|-------|-------------------|--------|
| Output logit scaling | `logits *= base/width` | Required | ✅ Correct |
| Embedding LR | No width scaling | Constant with width | ✅ Correct |
- | lm_head init std | `0.001 × √(base/width)` | Width-scaled init | ✅ Correct |
+ | lm_head init std | `0.02` (flat, no width scaling) | See note below | ✅ Correct |
| Weight decay | Not width-scaled | Constant with width | ✅ Correct |
| Momentum (Adam β₁, Muon) | Not width-scaled | Constant with width | ✅ Correct |
| c_proj init | Non-zero uniform, std=√(3/n_embd) | Paper recommends zero | ⚠️ Intentional divergence |

+ **On lm_head init**: The paper prescribes width-scaled init (`std ∝ 1/√width`) to keep initial logit magnitudes O(1). We previously used `0.001 × √(base/width)`. However, the forward-pass logit scaling (`logits *= base/width`) already suppresses logit magnitudes at large widths. The width-scaled init was double-compensating — initial logits were O(base/width) instead of O(1), making the lm_head start too quiet at large widths. We now use a flat `std = 0.02` which, combined with the forward-pass scaling, produces well-behaved initial logits at all widths.
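The arithmetic in this note can be sanity-checked with a standalone NumPy sketch (illustrative only, not repo code; `base = 256`, a unit-variance post-norm hidden state, and a 1000-token toy vocab are assumptions). It compares initial logit spread under the old width-scaled init and the new flat `std = 0.02`, both after the `base/width` forward scaling:

```python
import numpy as np

rng = np.random.default_rng(0)
base, vocab = 256, 1000          # assumed base width and a small toy vocab
flat_std, scaled_std = {}, {}
for width in [256, 512, 1024, 2048]:
    h = rng.standard_normal(width)                  # ~unit-variance hidden state (post-norm)
    scale = base / width                            # forward-pass logit scaling
    W_new = rng.standard_normal((vocab, width)) * 0.02                           # flat init
    W_old = rng.standard_normal((vocab, width)) * 0.001 * np.sqrt(base / width)  # old width-scaled init
    flat_std[width] = float(np.std(W_new @ h)) * scale
    scaled_std[width] = float(np.std(W_old @ h)) * scale
    print(width, round(flat_std[width], 4), round(scaled_std[width], 6))
```

With the forward scaling applied, flat-init logits shrink only like `1/√width`, while the old scheme compounds to roughly `1/width`: the double-compensation described above.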

**On c_proj init**: The paper recommends zero-initializing output projections (attn c_proj, MLP c_proj) for cleaner transfer. nanochat uses non-zero init because zero init causes vanishing attention/FFN outputs when combined with Muon's LR dynamics — the first Muon update from a zero matrix produces an orthogonal matrix with O(LR) norm, which is too small when LR is already small. This is a known interaction between Muon and residual-stream architectures; the non-zero init provides a stable starting point.
## Summary: muP for Muon+AdamW
@@ -113,13 +115,13 @@ For a mixed Muon+AdamW optimizer, muP simplifies dramatically:

| Parameter group | muP prescription | Reason |
|----------------|-----------------|--------|
| **Output logits** | `logits *= base/width` in forward | The essential ingredient — makes loss landscape shape-invariant |
- | **lm_head init** | `std *= √(base/width)` | Keeps initial logit magnitudes O(1) |
+ | **lm_head init** | `std = 0.02` (flat, no width scaling) | Forward-pass logit scaling already handles width independence; width-scaled init double-compensates |
| **lm_head LR** | No width scaling | Logit scaling already propagates into gradient; Adam normalizes; additional LR scaling over-reduces |
| **Muon (hidden) LR** | No width scaling | Polar Express makes `||update||_F ≈ 1` regardless of width |
| **Embedding LR** | No width scaling | Standard muP (embeddings are lookup tables, not matrix multiplies) |
| **Scalar LR** | No width scaling | Standard muP |

- **The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass (plus corresponding init adjustment). No LR scaling is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
+ **The punchline**: With Muon+AdamW, muP reduces to scaling output logits by `base/width` in the forward pass. No LR scaling or width-dependent init is needed anywhere — Muon's orthogonalization and Adam's second-moment normalization both already produce width-independent updates.
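The Adam half of that claim can be seen in a one-dimensional toy (a sketch, not the repo's optimizer): shrinking every gradient by a `base/width`-like factor leaves the second-moment-normalized step essentially unchanged, which is why no extra LR rescaling is needed.

```python
import numpy as np

def adam_step(grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Return the final Adam step size after a stream of scalar gradients."""
    m = v = 0.0
    for t, g in enumerate(grads, 1):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        step = lr * (m / (1 - b1 ** t)) / (np.sqrt(v / (1 - b2 ** t)) + eps)
    return step

u_base = adam_step([0.1] * 50)
u_small = adam_step([0.1 * (256 / 4096)] * 50)  # same gradients, shrunk by a base/width-like factor
print(u_base, u_small)                          # both come out near lr: the scale cancels
```

Because `m` and `sqrt(v)` carry the same gradient scale, it divides out of the step; only `eps` breaks the invariance, and only for vanishingly small gradients.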
## Verification
@@ -139,9 +141,9 @@ python -m pytest tests/test_mup.py -v

| File | Changes |
|------|---------|
- | `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); updated comments |
- | `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent` |
- | `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100→200 |
+ | `nanochat/gpt.py` | `output_lr_scale`: `base/width` → `1.0`; added `muon_lr_exponent` param (default `0.0`); lm_head init: `0.001 × √(base/width)` → flat `0.02` for muP (removed double-compensation); updated comments |
+ | `scripts/mup_coord_check.py` | Added `--detailed` flag (grad norms, update norms, attn logit magnitudes), `--muon-lr-exponent`; switched to float32 (disabled bfloat16 autocast) for numerical precision |
+ | `scripts/mup_transfer_check.py` | Wider default LR range (1024×), `--sweep-mode {all,muon-only,adamw-only}`, `--num-random-trials`, `--num-batches`, `--sweep-init-scale`, `--sweep-output-mult`, `--muon-lr-exponent`, default steps 100→200; switched to float32 for numerical precision |
## References
@@ -484,8 +484,8 @@ def main():
help='Comma-separated list of widths to test')
# Paper-style default: ~1000x range, 11 log-spaced points
parser.add_argument('--lr-mults', type=str,
- default='0.03125,0.0625,0.125,0.25,0.5,1.0,2.0,4.0,8.0,16.0,32.0',
- help='Comma-separated LR multipliers to sweep (default: 1024x range, 11 points)')
+ default='0.03125,0.044,0.0625,0.088,0.125,0.177,0.25,0.354,0.5,0.707,1.0,1.414,2.0,2.828,4.0,5.657,8.0',
+ help='Comma-separated LR multipliers to sweep (default: 256x range, 17 points, ~sqrt(2) spacing)')
parser.add_argument('--num-random-trials', type=int, default=0,
help='If >0, use N log-uniform random LR multipliers from 10^Uniform(-1.5,1.5) '
'instead of the grid. Paper-style methodology (Section F).')
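For reference, both sweep modes can be regenerated in a few lines (a sketch of the spacing only, not code from the script): the new default grid is half-octave powers of two from `2**-5` to `2**3`, and the random mode draws log-uniformly from `10**Uniform(-1.5, 1.5)`.

```python
import numpy as np

# New default grid: 17 points at ~sqrt(2) spacing, 0.03125 .. 8.0 (a 256x range).
mults = 2.0 ** (np.arange(17) / 2.0 - 5.0)

# --num-random-trials mode: N log-uniform draws from 10**Uniform(-1.5, 1.5).
rng = np.random.default_rng(0)
random_mults = 10.0 ** rng.uniform(-1.5, 1.5, size=8)

print(len(mults), mults[0], mults[-1])
```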