Add NorMuon paper link

This commit is contained in:
zichongli5 2026-02-03 21:53:43 -05:00 committed by GitHub
parent 542beb0c8c
commit 8be3907514
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -243,6 +243,7 @@ class MuonAdamW(torch.optim.Optimizer):
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
# from NorMuon: https://arxiv.org/abs/2510.05491
if "second_momentum_buffer" not in state:
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)