diff --git a/nanochat/optim.py b/nanochat/optim.py index 42d862b..dac61dc 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -38,8 +38,9 @@ def adamw_step_fused( # Weight decay (decoupled, applied before the update) p.mul_(1 - lr_t * wd_t) # Update running averages (lerp_ is cleaner and fuses well) - exp_avg.lerp_(grad, 1 - beta1_t) - exp_avg_sq.lerp_(grad.square(), 1 - beta2_t) + # Explicitly cast the scalar to match the buffer's dtype + exp_avg.lerp_(grad, (1 - beta1_t).to(exp_avg.dtype)) + exp_avg_sq.lerp_(grad.square(), (1 - beta2_t).to(exp_avg_sq.dtype)) # Bias corrections bias1 = 1 - beta1_t ** step_t bias2 = 1 - beta2_t ** step_t @@ -132,7 +133,12 @@ def muon_step_fused( red_dim_size = g.size(red_dim) v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size v_norm = v_norm_sq.sqrt() - second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2) + + # Explicitly cast the scalar (1 - beta2) to match the buffer's dtype + second_momentum_buffer.lerp_( + v_mean.to(dtype=second_momentum_buffer.dtype), + (1 - beta2).to(second_momentum_buffer.dtype) + ) step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt() scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square() v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()