fix(optim): cast scalars to buffer dtype for torch.compile compat

This commit is contained in:
suraj-self 2026-02-07 20:50:21 +05:30
parent aeff095e97
commit 2e0fda1893

View File

@@ -38,8 +38,9 @@ def adamw_step_fused(
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Explicitly cast the scalar to match the buffer's dtype
exp_avg.lerp_(grad, (1 - beta1_t).to(exp_avg.dtype))
exp_avg_sq.lerp_(grad.square(), (1 - beta2_t).to(exp_avg_sq.dtype))
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
@@ -132,7 +133,12 @@ def muon_step_fused(
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
# Explicitly cast the scalar (1 - beta2) to match the buffer's dtype
second_momentum_buffer.lerp_(
v_mean.to(dtype=second_momentum_buffer.dtype),
(1 - beta2).to(second_momentum_buffer.dtype)
)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()