diff --git a/nanochat/optim.py b/nanochat/optim.py
index 190a1ed..494c298 100644
--- a/nanochat/optim.py
+++ b/nanochat/optim.py
@@ -256,7 +256,12 @@ class MuonAdamW(torch.optim.Optimizer):
             # Fill all the 0-D tensors with current values
             self._muon_momentum_t.fill_(group["momentum"])
             self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
-            self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+            # Shape-based LR scaling (flipped from original):
+            # - Tall matrices (input projections like c_fc): 1x
+            # - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for 1:4
+            ratio = shape[-2] / shape[-1]
+            lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
+            self._muon_lr_t.fill_(group["lr"] * lr_mult)
             self._muon_wd_t.fill_(group["weight_decay"])
             # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
@@ -473,7 +478,12 @@ class DistMuonAdamW(torch.optim.Optimizer):
             # Fill 0-D tensors and run fused kernel
             self._muon_momentum_t.fill_(group["momentum"])
             self._muon_beta2_t.fill_(group["beta2"])
-            self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
+            # Shape-based LR scaling (flipped from original):
+            # - Tall matrices (input projections like c_fc): 1x
+            # - Wide matrices (output projections like c_proj): sqrt(cols/rows) → 2x for 1:4
+            ratio = shape[-2] / shape[-1]
+            lr_mult = 1.0 if ratio >= 1 else ratio**-0.5
+            self._muon_lr_t.fill_(group["lr"] * lr_mult)
             self._muon_wd_t.fill_(group["weight_decay"])
             muon_step_fused(
                 grad_chunk[:num_owned], stacked_owned,