Pass p as tensor to fused adam

We can avoid a couple of recompiles by passing the underlying tensor for a parameter instead of the parameter object.
This commit is contained in:
Chris McCormick 2026-01-30 18:03:20 -08:00
parent 3c3a3d7042
commit 9b9ef3ef38

View File

@ -217,7 +217,7 @@ class MuonAdamW(torch.optim.Optimizer):
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
p.data, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)