mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 18:00:17 +00:00
correctly reference NorMuon and fix misleading terms that i may have hastily ported over from modded-nanogpt
This commit is contained in:
parent
542beb0c8c
commit
718e5e9d67
|
|
@ -733,8 +733,8 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i
|
|||
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
|
||||
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
|
||||
|
||||
**2. Variance Reduction (NorMuon-style)**
|
||||
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
|
||||
**2. NorMuon Variance Reduction**
|
||||
- Added per-neuron/column adaptive learning rate from NorMuon ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
|
||||
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
|
||||
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
|
||||
- Memory overhead: ~1/max(rows, cols) per param, negligible
|
||||
|
|
@ -776,7 +776,7 @@ Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
|
|||
|
||||
### Summary
|
||||
|
||||
Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
|
||||
Muon was changed to use Polar Express, added NorMuon variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ Polar Express Sign Method for orthogonalization.
|
|||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
|
||||
update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
|
||||
https://arxiv.org/pdf/2510.05491
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user