From 7ac837cff8efc0e85502e2b3a934a35e2d937b8d Mon Sep 17 00:00:00 2001 From: Chris McCormick Date: Mon, 2 Feb 2026 08:58:34 -0800 Subject: [PATCH] Flip LR mult --- nanochat/optim.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..494c298 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -256,7 +256,12 @@ class MuonAdamW(torch.optim.Optimizer): # Fill all the 0-D tensors with current values self._muon_momentum_t.fill_(group["momentum"]) self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) - self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + # Shape-based LR scaling, ratio = rows/cols (flipped from the previous max(1, rows/cols)**0.5 rule): + # - Tall matrices (input projections like c_fc) and square matrices: 1x + # - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for a 1:4 (rows:cols) shape + ratio = shape[-2] / shape[-1] + lr_mult = 1.0 if ratio >= 1 else ratio**-0.5 + self._muon_lr_t.fill_(group["lr"] * lr_mult) self._muon_wd_t.fill_(group["weight_decay"]) # Single fused kernel: momentum -> polar_express -> variance_reduction -> update @@ -473,7 +478,12 @@ class DistMuonAdamW(torch.optim.Optimizer): # Fill 0-D tensors and run fused kernel self._muon_momentum_t.fill_(group["momentum"]) self._muon_beta2_t.fill_(group["beta2"]) - self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + # Shape-based LR scaling, ratio = rows/cols (flipped from the previous max(1, rows/cols)**0.5 rule): + # - Tall matrices (input projections like c_fc) and square matrices: 1x + # - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for a 1:4 (rows:cols) shape + ratio = shape[-2] / shape[-1] + lr_mult = 1.0 if ratio >= 1 else ratio**-0.5 + self._muon_lr_t.fill_(group["lr"] * lr_mult) self._muon_wd_t.fill_(group["weight_decay"]) muon_step_fused( grad_chunk[:num_owned], stacked_owned,