From 22a71aa3d30b50d6e7659ca0e8bcad9f3b7bfd98 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 15 Jan 2026 23:30:44 +0000
Subject: [PATCH] fuse adamw into a single torch compiled kernel similar to
 muon. it's about 1.7X faster, but overall it's so tiny that it's not making
 a major dent

---
 nanochat/adamw.py | 90 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 69 insertions(+), 21 deletions(-)

diff --git a/nanochat/adamw.py b/nanochat/adamw.py
index 48945b3..70ccf7b 100644
--- a/nanochat/adamw.py
+++ b/nanochat/adamw.py
@@ -1,11 +1,42 @@
 """
-Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
-Not a general optimizer! But works for our specific use.
+Distributed AdamW optimizer with a fused step function.
+A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
 """
 import torch
 import torch.distributed as dist
 from torch import Tensor
 
+@torch.compile(dynamic=False, fullgraph=True)
+def adamw_step_fused(
+    p: Tensor,
+    grad: Tensor,
+    exp_avg: Tensor,
+    exp_avg_sq: Tensor,
+    step_t: Tensor,
+    lr_t: Tensor,
+    beta1_t: Tensor,
+    beta2_t: Tensor,
+    eps_t: Tensor,
+    wd_t: Tensor,
+) -> None:
+    """
+    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
+    All in one compiled graph to eliminate Python overhead between ops.
+    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
+    """
+    # Weight decay (decoupled, applied before the update)
+    p.mul_(1 - lr_t * wd_t)
+    # Update running averages (lerp_ is cleaner and fuses well)
+    exp_avg.lerp_(grad, 1 - beta1_t)
+    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
+    # Bias corrections
+    bias1 = 1 - beta1_t ** step_t
+    bias2 = 1 - beta2_t ** step_t
+    # Compute update and apply
+    denom = (exp_avg_sq / bias2).sqrt() + eps_t
+    step_size = lr_t / bias1
+    p.add_(exp_avg / denom, alpha=-step_size)
+
 
 class DistAdamW(torch.optim.Optimizer):
     """
@@ -14,7 +45,26 @@ class DistAdamW(torch.optim.Optimizer):
     """
     def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        # Validate
+        if rank == 0:
+            for group in param_groups:
+                assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
+                assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
+                for p in group['params']:
+                    sliced = p.numel() >= 1024
+                    print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
+                    if sliced: # large parameter tensors will be operated on in slices
+                        assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
 
     @torch.no_grad()
     def step(self):
@@ -36,8 +86,7 @@ class DistAdamW(torch.optim.Optimizer):
                     grad_slices.append(grad)
                 else:
                     is_small.append(False)
-                    assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
-                    rank_size = grad.shape[0] // world_size
+                    rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
                     grad_slice = torch.empty_like(grad[:rank_size])
                     reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                     grad_slices.append(grad_slice)
@@ -63,28 +112,27 @@ class DistAdamW(torch.optim.Optimizer):
 
                 # State init
                 if not state:
-                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
+                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p_slice)
                     state['exp_avg_sq'] = torch.zeros_like(p_slice)
                 exp_avg = state['exp_avg']
                 exp_avg_sq = state['exp_avg_sq']
                 state['step'] += 1
-                t = state['step']
-                # weight decay
-                if wd != 0:
-                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
-                    p_slice.mul_(1 - eff_weight_decay)
-                # update running averages
-                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
-                # bias corrections
-                bias1 = 1 - beta1 ** t
-                bias2 = 1 - beta2 ** t
-                # compute step
-                denom = (exp_avg_sq / bias2).sqrt().add_(eps)
-                step_size = lr / bias1
-                update = exp_avg.div(denom).mul_(step_size)
-                p_slice.add_(other=update, alpha=-1.0)
+
+                # Fill 0-D tensors with current values
+                eff_wd = wd * getattr(p, "wd_mul", 1.0)
+                self._step_t.fill_(state['step'])
+                self._lr_t.fill_(lr)
+                self._beta1_t.fill_(beta1)
+                self._beta2_t.fill_(beta2)
+                self._eps_t.fill_(eps)
+                self._wd_t.fill_(eff_wd)
+
+                # Fused update: weight_decay -> momentum -> bias_correction -> param_update
+                adamw_step_fused(
+                    p_slice, g_slice, exp_avg, exp_avg_sq,
+                    self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
+                )
 
                 # Only large params need all_gather
                 if not is_small[idx]:
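
Below, a minimal sketch (not part of the commit) of driving adamw_step_fused by hand, assuming nanochat is importable and torch.compile is usable in the environment. It mirrors how DistAdamW.step() refills the 0-D CPU tensors before each call, which is what lets hyperparameter values change between steps without triggering a recompile; the real optimizer invokes the function on per-rank GPU parameter slices after the reduce_scatter.

import torch
from nanochat.adamw import adamw_step_fused  # module path as modified by this patch

# Dummy parameter, gradient, and AdamW state (the real code uses sharded GPU slices)
p = torch.randn(1024)
grad = torch.randn(1024)
exp_avg = torch.zeros_like(p)
exp_avg_sq = torch.zeros_like(p)

# 0-D CPU tensors for the step count and hyperparameters, as in DistAdamW.__init__
step_t = torch.tensor(0.0)
lr_t, beta1_t, beta2_t = torch.tensor(0.0), torch.tensor(0.9), torch.tensor(0.999)
eps_t, wd_t = torch.tensor(1e-8), torch.tensor(0.01)

for step in range(1, 4):
    step_t.fill_(step)
    lr_t.fill_(1e-3 * step)  # new value, same 0-D tensor: no recompilation
    adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                     step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t)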