Fix epsilon scaling in DistAdamW to match standard AdamW

Fixes #304

The epsilon term was being added before bias correction of the second
moment, which differs from standard AdamW (PyTorch). This fix applies
epsilon after the bias correction, matching PyTorch's implementation.

Mathematical change:
- Before: denom = sqrt(exp_avg_sq) + eps
- After:  denom = sqrt(exp_avg_sq) / sqrt(bias2) + eps
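The practical effect of the change can be seen with a scalar walk-through. This is a sketch with illustrative numbers (not taken from the diff); algebraically, the old rule is the new rule with `eps` inflated by `1 / sqrt(bias2)`, so the two diverge most in early steps, when `bias2` is small and the second moment is near zero:

```python
import math

# Illustrative scalar values (assumptions, not from the diff):
lr, beta1, beta2, eps, t = 1e-3, 0.9, 0.999, 1e-8, 10
m, v = 0.01, 1e-16   # first/second moment estimates; v near zero, as early in training

bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t

# Before: eps added before bias correction, sqrt(bias2) folded into step_size.
step_before = lr * (math.sqrt(bias2) / bias1) * m / (math.sqrt(v) + eps)

# After: standard AdamW -- eps added after correcting the second moment.
step_after = lr / bias1 * m / (math.sqrt(v) / math.sqrt(bias2) + eps)

# The old rule equals the new one with eps scaled by 1 / sqrt(bias2),
# i.e. an effectively larger eps that damps early updates.
```

With these numbers the corrected rule takes a noticeably larger step, since the old rule's effectively inflated `eps` dominates the tiny `sqrt(v)`.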

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Charles Weill 2025-12-02 12:26:41 -08:00
parent 4a87a0d19f
commit 96ec37e5fd

@@ -66,9 +66,10 @@ class DistAdamW(torch.optim.Optimizer):
         # bias corrections
         bias1 = 1 - beta1 ** t
         bias2 = 1 - beta2 ** t
-        # compute step
-        denom = exp_avg_sq.sqrt().add_(eps)
-        step_size = lr * (torch.sqrt(bias2) / bias1)
+        # compute step (standard AdamW: apply eps after bias correction)
+        bias_correction2_sqrt = torch.sqrt(bias2)
+        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+        step_size = lr / bias1
         update = exp_avg.div(denom).mul_(step_size)
         p_slice.add_(other=update, alpha=-1.0)
         idx += 1
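For reference, the fixed update can be mirrored as a self-contained scalar function. This is a sketch of the per-element math only, assuming the moments `m` and `v` are already accumulated; the distributed sharding and decoupled weight decay of `DistAdamW` are omitted:

```python
import math

def adamw_step(p, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar AdamW step with the fixed eps placement (weight decay omitted)."""
    bias1 = 1 - beta1 ** t
    bias2 = 1 - beta2 ** t
    # eps is applied after the second-moment bias correction, as in PyTorch.
    bias_correction2_sqrt = math.sqrt(bias2)
    denom = math.sqrt(v) / bias_correction2_sqrt + eps
    step_size = lr / bias1
    return p - m / denom * step_size
```

A zero first moment leaves the parameter unchanged, and a nonzero one moves it against the gradient direction scaled by the bias-corrected RMS.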