mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-05 15:15:48 +00:00
Merge 2e0fda1893 into 2f09686724
This commit is contained in:
commit
0efaa80db3
|
|
@ -38,8 +38,9 @@ def adamw_step_fused(
|
|||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Explicitly cast the scalar to match the buffer's dtype
|
||||
exp_avg.lerp_(grad, (1 - beta1_t).to(exp_avg.dtype))
|
||||
exp_avg_sq.lerp_(grad.square(), (1 - beta2_t).to(exp_avg_sq.dtype))
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
|
|
@ -132,7 +133,12 @@ def muon_step_fused(
|
|||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
|
||||
# Explicitly cast the scalar (1 - beta2) to match the buffer's dtype
|
||||
second_momentum_buffer.lerp_(
|
||||
v_mean.to(dtype=second_momentum_buffer.dtype),
|
||||
(1 - beta2).to(second_momentum_buffer.dtype)
|
||||
)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user