Flip LR mult

This commit is contained in:
Chris McCormick 2026-02-02 08:58:34 -08:00
parent 5129a34288
commit 7ac837cff8

View File

@@ -256,7 +256,12 @@ class MuonAdamW(torch.optim.Optimizer):
# Fill all the 0-D tensors with current values
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
# Shape-based LR scaling (flipped from original):
# - Tall matrices (input projections like c_fc): 1x
# - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for 1:4
ratio = shape[-2] / shape[-1]
lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
self._muon_lr_t.fill_(group["lr"] * lr_mult)
self._muon_wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
@@ -473,7 +478,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
# Fill 0-D tensors and run fused kernel
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"])
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
# Shape-based LR scaling (flipped from original):
# - Tall matrices (input projections like c_fc): 1x
# - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for 1:4
ratio = shape[-2] / shape[-1]
lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
self._muon_lr_t.fill_(group["lr"] * lr_mult)
self._muon_wd_t.fill_(group["weight_decay"])
muon_step_fused(
grad_chunk[:num_owned], stacked_owned,