diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..ea623fa 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -217,7 +217,7 @@ class MuonAdamW(torch.optim.Optimizer): # Fused update: weight_decay -> momentum -> bias_correction -> param_update adamw_step_fused( - p, grad, exp_avg, exp_avg_sq, + p.data, grad, exp_avg, exp_avg_sq, self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t, )