mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-05 23:25:35 +00:00
Flip LR mult
This commit is contained in:
parent
5129a34288
commit
7ac837cff8
|
|
@ -256,7 +256,12 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
# Fill all the 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
# Shape-based LR scaling (flipped from original):
|
||||
# - Tall or square matrices (rows >= cols; input projections like c_fc): 1x
|
||||
# - Wide matrices (rows < cols; output projections like c_proj): sqrt(cols/rows) → 2x for a 1:4 rows:cols shape
|
||||
ratio = shape[-2] / shape[-1]
|
||||
lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
|
||||
self._muon_lr_t.fill_(group["lr"] * lr_mult)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
|
|
@ -473,7 +478,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
# Fill 0-D tensors and run fused kernel
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"])
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
# Shape-based LR scaling (flipped from original):
|
||||
# - Tall or square matrices (rows >= cols; input projections like c_fc): 1x
|
||||
# - Wide matrices (rows < cols; output projections like c_proj): sqrt(cols/rows) → 2x for a 1:4 rows:cols shape
|
||||
ratio = shape[-2] / shape[-1]
|
||||
lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
|
||||
self._muon_lr_t.fill_(group["lr"] * lr_mult)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
muon_step_fused(
|
||||
grad_chunk[:num_owned], stacked_owned,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user