Setup optimization in MuonAdamW and DistMuonAdamW.

This commit is contained in:
Crutcher Dunnavant 2026-04-10 17:49:26 -07:00
parent a445144d39
commit 758af69e25

View File

@ -198,6 +198,20 @@ class MuonAdamW(torch.optim.Optimizer):
AdamW update for each param in the group individually.
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
for p in group['params']:
if p.grad is not None:
break
else:
# Don't bother with the update if no params in this group have gradients
return
# Pre-fill 0-D tensors with current values
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
for p in group['params']:
if p.grad is None:
continue
@ -215,11 +229,6 @@ class MuonAdamW(torch.optim.Optimizer):
# Fill 0-D tensors with current values
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
@ -410,6 +419,14 @@ class DistMuonAdamW(torch.optim.Optimizer):
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
param_infos = info['param_infos']
# Fill 0-D tensors.
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
for p in group['params']:
pinfo = param_infos[p]
pinfo['future'].wait()
@ -432,11 +449,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
# Fill 0-D tensors and run fused kernel
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
adamw_step_fused(
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,