mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-06 07:35:32 +00:00
Merge 8be3907514 into 1144d186ed
This commit is contained in:
commit
fb9a32a9f3
|
|
@@ -247,6 +247,7 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
# from NorMuon: https://arxiv.org/abs/2510.05491
|
||||
if "second_momentum_buffer" not in state:
|
||||
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
||||
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user