mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Fix epsilon scaling in DistAdamW to match standard AdamW
Fixes #304

The epsilon term was being added before bias correction of the second moment, which differs from standard AdamW (PyTorch). This fix applies epsilon after the bias correction, matching PyTorch's implementation.

Mathematical change:
- Before: denom = sqrt(exp_avg_sq) + eps
- After: denom = sqrt(exp_avg_sq) / sqrt(bias2) + eps

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent 4a87a0d19f
commit 96ec37e5fd
```diff
@@ -66,9 +66,10 @@ class DistAdamW(torch.optim.Optimizer):
                 # bias corrections
                 bias1 = 1 - beta1 ** t
                 bias2 = 1 - beta2 ** t
-                # compute step
-                denom = exp_avg_sq.sqrt().add_(eps)
-                step_size = lr * (torch.sqrt(bias2) / bias1)
+                # compute step (standard AdamW: apply eps after bias correction)
+                bias_correction2_sqrt = torch.sqrt(bias2)
+                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+                step_size = lr / bias1
                 update = exp_avg.div(denom).mul_(step_size)
                 p_slice.add_(other=update, alpha=-1.0)
                 idx += 1
```
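The before/after formulas in the commit message can be checked numerically. A minimal sketch with toy scalar values (assumed for illustration, not taken from the repo) showing why the two placements of eps differ: in the old code, the division by sqrt(bias2) folded into step_size also scales eps, effectively inflating it by 1/sqrt(bias2) early in training.

```python
import math

# Toy hyperparameters and state (assumed values, for illustration only)
beta2 = 0.999
eps = 1e-8
t = 10                      # early step, where bias2 is small
exp_avg_sq = 1e-12          # tiny second moment, where eps placement matters
bias2 = 1 - beta2 ** t

# Before: eps added to the *uncorrected* sqrt of the second moment.
# The old step_size multiplied by sqrt(bias2), which is equivalent to
# dividing the whole denominator (including eps) by sqrt(bias2).
denom_before = math.sqrt(exp_avg_sq) + eps
effective_before = denom_before / math.sqrt(bias2)

# After: bias-correct the second moment first, then add eps,
# matching torch.optim.AdamW's update rule.
denom_after = math.sqrt(exp_avg_sq) / math.sqrt(bias2) + eps

# The gap between the two is exactly eps * (1/sqrt(bias2) - 1),
# so the old form over-damps the update when bias2 << 1.
print(effective_before, denom_after)
```

With these values bias2 ≈ 0.01, so the old effective eps is roughly ten times larger than intended; the two formulas converge as t grows and bias2 approaches 1.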