diff --git a/nanochat/adamw.py b/nanochat/adamw.py
index db591de..61cba75 100644
--- a/nanochat/adamw.py
+++ b/nanochat/adamw.py
@@ -66,9 +66,10 @@ class DistAdamW(torch.optim.Optimizer):
                 # bias corrections
                 bias1 = 1 - beta1 ** t
                 bias2 = 1 - beta2 ** t
-                # compute step
-                denom = exp_avg_sq.sqrt().add_(eps)
-                step_size = lr * (torch.sqrt(bias2) / bias1)
+                # compute step (standard AdamW: apply eps after bias correction)
+                bias_correction2_sqrt = torch.sqrt(bias2)
+                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+                step_size = lr / bias1
                 update = exp_avg.div(denom).mul_(step_size)
                 p_slice.add_(other=update, alpha=-1.0)
                 idx += 1