mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-23 11:54:16 +00:00
fuse adamw into a single torch-compiled kernel, similar to muon. the fused optimizer step is about 1.7X faster, but the optimizer is such a tiny slice of total step time that it doesn't make a major dent overall
This commit is contained in:
parent 255f8b9af6
commit 22a71aa3d3
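
For context on what "fuse into a single torch-compiled kernel" buys: an eager AdamW update is a chain of small elementwise ops, each costing a separate kernel launch plus Python overhead, and torch.compile can fuse that chain into a handful of kernels. The sketch below is not part of the commit (shapes, sizes and the timing loop are made up for illustration) but shows how one might measure the eager vs. compiled gap for a single update:

    import time
    import torch

    def adamw_update(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
        # Eager chain of elementwise ops: each line launches at least one kernel.
        p.mul_(1 - lr * wd)
        m.lerp_(g, 1 - beta1)
        v.lerp_(g.square(), 1 - beta2)
        denom = (v / (1 - beta2 ** step)).sqrt() + eps
        p.add_(m / denom, alpha=-(lr / (1 - beta1 ** step)))

    compiled_update = torch.compile(adamw_update)  # same math, fused by the compiler

    def bench(fn, iters=200):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        p = torch.randn(1024, 1024, device=device)
        g = torch.randn_like(p)
        m, v = torch.zeros_like(p), torch.zeros_like(p)
        fn(p, g, m, v, 1)  # warmup (first call compiles the graph for the compiled variant)
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(p, g, m, v, 1)  # step held constant to keep the sketch recompilation-free
        if device == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter() - t0

    print("eager   :", bench(adamw_update))
    print("compiled:", bench(compiled_update))
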
@@ -1,11 +1,42 @@
 """
-Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
-Not a general optimizer! But works for our specific use.
+Distributed AdamW optimizer with a fused step function.
+A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
 """
 import torch
 import torch.distributed as dist
 from torch import Tensor
 
+@torch.compile(dynamic=False, fullgraph=True)
+def adamw_step_fused(
+    p: Tensor,
+    grad: Tensor,
+    exp_avg: Tensor,
+    exp_avg_sq: Tensor,
+    step_t: Tensor,
+    lr_t: Tensor,
+    beta1_t: Tensor,
+    beta2_t: Tensor,
+    eps_t: Tensor,
+    wd_t: Tensor,
+) -> None:
+    """
+    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
+    All in one compiled graph to eliminate Python overhead between ops.
+    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
+    """
+    # Weight decay (decoupled, applied before the update)
+    p.mul_(1 - lr_t * wd_t)
+    # Update running averages (lerp_ is cleaner and fuses well)
+    exp_avg.lerp_(grad, 1 - beta1_t)
+    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
+    # Bias corrections
+    bias1 = 1 - beta1_t ** step_t
+    bias2 = 1 - beta2_t ** step_t
+    # Compute update and apply
+    denom = (exp_avg_sq / bias2).sqrt() + eps_t
+    step_size = lr_t / bias1
+    p.add_(exp_avg / denom, alpha=-step_size)
+
 
 class DistAdamW(torch.optim.Optimizer):
     """
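
On the docstring's last point: with dynamic=False, torch.compile guards on the values of plain Python numbers, so passing lr, betas, etc. as floats would bake each value into the compiled graph and force a re-trace every time the LR schedule changes them, while a 0-D tensor is treated as runtime data that the same graph can read. A tiny illustration, separate from the commit (function names are made up):

    import torch

    @torch.compile(dynamic=False, fullgraph=True)
    def scale_by_float(x, lr):
        # lr is a Python float: its value becomes a compile-time constant (guarded on)
        return x * lr

    @torch.compile(dynamic=False, fullgraph=True)
    def scale_by_tensor(x, lr_t):
        # lr_t is a 0-D tensor: its value is data, so one graph serves all values
        return x * lr_t

    x = torch.randn(8)
    lr_t = torch.tensor(0.0)
    for lr in (1e-3, 5e-4, 1e-4):      # e.g. a decaying schedule
        scale_by_float(x, lr)          # re-traces for each new float value
        lr_t.fill_(lr)
        scale_by_tensor(x, lr_t)       # compiled once, reused
    # run with TORCH_LOGS="recompiles" to watch the float variant re-trace per value
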
@@ -14,7 +45,26 @@ class DistAdamW(torch.optim.Optimizer):
     """
     def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        # Validate
+        if rank == 0:
+            for group in param_groups:
+                assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
+                assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
+                for p in group['params']:
+                    sliced = p.numel() >= 1024
+                    print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
+                    if sliced: # large parameter tensors will be operated on in slices
+                        assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
         super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
 
     @torch.no_grad()
     def step(self):
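
The validation block documents the expected shape of param_groups: a list of dicts, each with a 'params' list, where large tensors must have their first dim divisible by the world size. A hypothetical construction sketch, single process only and not taken from nanochat's training code (import path and group contents are assumptions):

    import os
    import torch
    import torch.distributed as dist
    # assuming the class above is importable, e.g.:
    # from nanochat.adamw import DistAdamW   # module path is an assumption

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)  # toy 1-process group

    model = torch.nn.Linear(1024, 256)
    param_groups = [
        {"params": [model.weight], "lr": 1e-3},  # 256x1024: sliced, first dim must be divisible by world size
        {"params": [model.bias], "lr": 1e-3},    # 256 elements < 1024: small, updated in full on every rank
    ]
    opt = DistAdamW(param_groups, lr=1e-3, weight_decay=0.1)
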
@@ -36,8 +86,7 @@ class DistAdamW(torch.optim.Optimizer):
                     grad_slices.append(grad)
                 else:
                     is_small.append(False)
-                    assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
-                    rank_size = grad.shape[0] // world_size
+                    rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
                     grad_slice = torch.empty_like(grad[:rank_size])
                     reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                     grad_slices.append(grad_slice)
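
For intuition on the sliced path: reduce_scatter_tensor averages the full gradient across ranks but leaves each rank holding only its own shape[0] // world_size block of rows, which is exactly why __init__ insists that the first dim divides evenly. A standalone shape sketch, not from the repo (launch command, backend and sizes are assumptions), e.g. run under torchrun --nproc_per_node=2:

    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")                    # reduce_scatter_tensor with AVG wants nccl
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))

    grad = torch.full((8, 4), float(rank + 1), device="cuda")  # pretend full gradient, first dim % world_size == 0
    rank_size = grad.shape[0] // world_size                    # rows this rank keeps
    grad_slice = torch.empty_like(grad[:rank_size])

    # Average across ranks, scatter row blocks: rank r receives rows [r*rank_size, (r+1)*rank_size)
    dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG)
    print(rank, tuple(grad_slice.shape), grad_slice[0, 0].item())  # every entry equals the cross-rank mean

    dist.destroy_process_group()
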
@@ -63,28 +112,27 @@ class DistAdamW(torch.optim.Optimizer):
 
                 # State init
                 if not state:
-                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
+                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p_slice)
                     state['exp_avg_sq'] = torch.zeros_like(p_slice)
                 exp_avg = state['exp_avg']
                 exp_avg_sq = state['exp_avg_sq']
                 state['step'] += 1
-                t = state['step']
-                # weight decay
-                if wd != 0:
-                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
-                    p_slice.mul_(1 - eff_weight_decay)
-                # update running averages
-                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
-                # bias corrections
-                bias1 = 1 - beta1 ** t
-                bias2 = 1 - beta2 ** t
-                # compute step
-                denom = (exp_avg_sq / bias2).sqrt().add_(eps)
-                step_size = lr / bias1
-                update = exp_avg.div(denom).mul_(step_size)
-                p_slice.add_(other=update, alpha=-1.0)
+
+                # Fill 0-D tensors with current values
+                eff_wd = wd * getattr(p, "wd_mul", 1.0)
+                self._step_t.fill_(state['step'])
+                self._lr_t.fill_(lr)
+                self._beta1_t.fill_(beta1)
+                self._beta2_t.fill_(beta2)
+                self._eps_t.fill_(eps)
+                self._wd_t.fill_(eff_wd)
+
+                # Fused update: weight_decay -> momentum -> bias_correction -> param_update
+                adamw_step_fused(
+                    p_slice, g_slice, exp_avg, exp_avg_sq,
+                    self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
+                )
 
                 # Only large params need all_gather
                 if not is_small[idx]:
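
A cheap way to sanity-check the fused path (a hypothetical test, not in the commit, assuming adamw_step_fused from the file above is importable): run one step of stock torch.optim.AdamW and one call of the fused kernel on copies of the same parameter. Both apply decoupled weight decay as p *= (1 - lr*wd) followed by the bias-corrected Adam update, so they should agree to float tolerance.

    import torch

    lr, betas, eps, wd = 1e-3, (0.9, 0.999), 1e-8, 0.01
    p0 = torch.randn(256, 64)
    grad = torch.randn_like(p0)

    # Reference: one step of stock AdamW
    ref = torch.nn.Parameter(p0.clone())
    opt = torch.optim.AdamW([ref], lr=lr, betas=betas, eps=eps, weight_decay=wd)
    ref.grad = grad.clone()
    opt.step()

    # Fused kernel, hyperparameters shipped as 0-D tensors (step = 1 for the first update)
    p_fused = p0.clone()
    exp_avg, exp_avg_sq = torch.zeros_like(p0), torch.zeros_like(p0)
    t = lambda v: torch.tensor(float(v))
    adamw_step_fused(p_fused, grad, exp_avg, exp_avg_sq,
                     t(1), t(lr), t(betas[0]), t(betas[1]), t(eps), t(wd))

    print(torch.allclose(ref.detach(), p_fused, atol=1e-6))  # expect True (up to float rounding)
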