fuse adamw into a single torch compiled kernel similar to muon. it's about 1.7X faster, but overall it's so tiny that it's not making a major dent

This commit is contained in:
Andrej Karpathy 2026-01-15 23:30:44 +00:00
parent 255f8b9af6
commit 22a71aa3d3

View File

@ -1,11 +1,42 @@
"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
Distributed AdamW optimizer with a fused step function.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
"""
import torch
import torch.distributed as dist
from torch import Tensor
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor,
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
step_t: Tensor,
lr_t: Tensor,
beta1_t: Tensor,
beta2_t: Tensor,
eps_t: Tensor,
wd_t: Tensor,
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
class DistAdamW(torch.optim.Optimizer):
"""
@ -14,7 +45,26 @@ class DistAdamW(torch.optim.Optimizer):
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
rank = dist.get_rank()
world_size = dist.get_world_size()
# Validate
if rank == 0:
for group in param_groups:
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
for p in group['params']:
sliced = p.numel() >= 1024
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
if sliced: # large parameter tensors will be operated on in slices
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
@ -36,8 +86,7 @@ class DistAdamW(torch.optim.Optimizer):
grad_slices.append(grad)
else:
is_small.append(False)
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
rank_size = grad.shape[0] // world_size
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
@ -63,28 +112,27 @@ class DistAdamW(torch.optim.Optimizer):
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
t = state['step']
# weight decay
if wd != 0:
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
p_slice.mul_(1 - eff_weight_decay)
# update running averages
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
# bias corrections
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
# compute step
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
step_size = lr / bias1
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
# Fill 0-D tensors with current values
eff_wd = wd * getattr(p, "wd_mul", 1.0)
self._step_t.fill_(state['step'])
self._lr_t.fill_(lr)
self._beta1_t.fill_(beta1)
self._beta2_t.fill_(beta2)
self._eps_t.fill_(eps)
self._wd_t.fill_(eff_wd)
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p_slice, g_slice, exp_avg, exp_avg_sq,
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
)
# Only large params need all_gather
if not is_small[idx]: