diff --git a/nanochat/optim.py b/nanochat/optim.py index 56e85e14..e62b9cf2 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -198,6 +198,20 @@ class MuonAdamW(torch.optim.Optimizer): AdamW update for each param in the group individually. Lazy init the state, fill in all 0-D tensors, call the fused kernel. """ + for p in group['params']: + if p.grad is not None: + break + else: + # Don't bother with the update if no params in this group have gradients + return + + # Pre-fill 0-D tensors with current values + self._adamw_lr_t.fill_(group['lr']) + self._adamw_beta1_t.fill_(group['betas'][0]) + self._adamw_beta2_t.fill_(group['betas'][1]) + self._adamw_eps_t.fill_(group['eps']) + self._adamw_wd_t.fill_(group['weight_decay']) + for p in group['params']: if p.grad is None: continue @@ -215,11 +229,6 @@ class MuonAdamW(torch.optim.Optimizer): # Fill 0-D tensors with current values self._adamw_step_t.fill_(state['step']) - self._adamw_lr_t.fill_(group['lr']) - self._adamw_beta1_t.fill_(group['betas'][0]) - self._adamw_beta2_t.fill_(group['betas'][1]) - self._adamw_eps_t.fill_(group['eps']) - self._adamw_wd_t.fill_(group['weight_decay']) # Fused update: weight_decay -> momentum -> bias_correction -> param_update adamw_step_fused( @@ -410,6 +419,14 @@ class DistMuonAdamW(torch.optim.Optimizer): def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None: """Wait for reduce, compute AdamW updates, launch gathers for large params.""" param_infos = info['param_infos'] + + # Fill 0-D tensors. + self._adamw_lr_t.fill_(group['lr']) + self._adamw_beta1_t.fill_(group['betas'][0]) + self._adamw_beta2_t.fill_(group['betas'][1]) + self._adamw_eps_t.fill_(group['eps']) + self._adamw_wd_t.fill_(group['weight_decay']) + for p in group['params']: pinfo = param_infos[p] pinfo['future'].wait() @@ -432,11 +449,6 @@ class DistMuonAdamW(torch.optim.Optimizer): # Fill 0-D tensors and run fused kernel self._adamw_step_t.fill_(state['step']) - self._adamw_lr_t.fill_(group['lr']) - self._adamw_beta1_t.fill_(group['betas'][0]) - self._adamw_beta2_t.fill_(group['betas'][1]) - self._adamw_eps_t.fill_(group['eps']) - self._adamw_wd_t.fill_(group['weight_decay']) adamw_step_fused( p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'], self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,