diff --git a/nanochat/optim.py b/nanochat/optim.py index 4cc2a1f..baf95ad 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -247,6 +247,7 @@ class MuonAdamW(torch.optim.Optimizer): momentum_buffer = state["momentum_buffer"] # Second momentum buffer is factored, either per-row or per-column + # from NorMuon: https://arxiv.org/abs/2510.05491 if "second_momentum_buffer" not in state: state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1]) state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)