mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-19 20:49:09 +00:00
Merge 3a5a8b3fda into 8630d32be4
This commit is contained in:
commit
f505140920
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed AdamW optimizer.
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
|
||||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
|
@ -20,8 +20,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
from nanochat.optims import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
|
@ -329,7 +328,7 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
return nparams
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into groups
|
||||
|
|
@ -340,7 +339,6 @@ class GPT(nn.Module):
|
|||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
|
|
@ -351,19 +349,22 @@ class GPT(nn.Module):
|
|||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
|
||||
]
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine them the two optimizers into one list
|
||||
optimizers = [adamw_optimizer, muon_optimizer]
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
|
||||
# MuonAdamW for single-GPU, DistMuonAdamW for multi-GPU
|
||||
OptimizerClass = DistMuonAdamW if (ddp and world_size > 1) else MuonAdamW
|
||||
optimizer = OptimizerClass(
|
||||
adamw_groups=adam_groups,
|
||||
muon_params=matrix_params,
|
||||
adamw_betas=adam_betas,
|
||||
adamw_eps=1e-10,
|
||||
muon_lr=matrix_lr,
|
||||
muon_momentum=0.95,
|
||||
muon_weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
|
|
|||
352
nanochat/muon.py
352
nanochat/muon.py
|
|
@ -1,352 +0,0 @@
|
|||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
613
nanochat/optims.py
Normal file
613
nanochat/optims.py
Normal file
|
|
@ -0,0 +1,613 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
class MuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
|
||||
AdamW - Distributed AdamW optimizer with a fused step function.
|
||||
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
AdamW Arguments:
|
||||
adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params
|
||||
muon_params: List of 2D tensors to optimize with Muon
|
||||
adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999))
|
||||
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
|
||||
adamw_weight_decay: Weight decay for AdamW (default: 0.01)
|
||||
|
||||
Muon Arguments:
|
||||
muon_lr: The learning rate used by the internal SGD.
|
||||
muon_momentum: The momentum used by the internal SGD.
|
||||
muon_ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
muon_beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
muon_weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
adamw_groups: list[dict],
|
||||
muon_params,
|
||||
# AdamW hyperparams
|
||||
adamw_lr: float = 1e-3, # can be overridden per-group
|
||||
adamw_betas: tuple[float, float] = (0.9, 0.999),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_weight_decay: float = 0.01,
|
||||
# Muon hyperparams
|
||||
muon_lr: float = 0.02,
|
||||
muon_momentum: float = 0.95,
|
||||
muon_ns_steps: int = 5,
|
||||
muon_beta2: float = 0.95,
|
||||
muon_weight_decay: float = 0.0,
|
||||
):
|
||||
muon_params = list(muon_params)
|
||||
assert all(p.ndim == 2 for p in muon_params), "Muon expects 2D parameters only"
|
||||
|
||||
# Build unified param_groups for the base Optimizer class
|
||||
param_groups = []
|
||||
|
||||
# AdamW groups: each input group becomes one param_group
|
||||
for group in adamw_groups:
|
||||
assert isinstance(group, dict) and 'params' in group
|
||||
params = list(group['params'])
|
||||
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
|
||||
for p in params:
|
||||
print(f"AdamW: 1 param of shape {p.shape}")
|
||||
param_groups.append(dict(
|
||||
params=params, lr=lr, kind='adamw',
|
||||
betas=adamw_betas, eps=adamw_eps, weight_decay=adamw_weight_decay,
|
||||
))
|
||||
|
||||
# Muon groups: group by shape for stacking, with all Muon hyperparams in the group
|
||||
muon_shapes = sorted({p.shape for p in muon_params})
|
||||
for shape in muon_shapes:
|
||||
group_params = [p for p in muon_params if p.shape == shape]
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}")
|
||||
param_groups.append(dict(
|
||||
params=group_params, lr=muon_lr, kind='muon',
|
||||
momentum=muon_momentum, ns_steps=muon_ns_steps,
|
||||
beta2=muon_beta2, weight_decay=muon_weight_decay,
|
||||
))
|
||||
|
||||
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Muon tensors
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
kind = group['kind']
|
||||
|
||||
if kind == 'adamw':
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
# AdamW update for each param individually
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
lr = group['lr']
|
||||
state = self.state[p]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(lr)
|
||||
self._adamw_beta1_t.fill_(beta1)
|
||||
self._adamw_beta2_t.fill_(beta2)
|
||||
self._adamw_eps_t.fill_(eps)
|
||||
self._adamw_wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p, grad, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
else: # muon
|
||||
params: list[Tensor] = group['params']
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
class DistMuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
|
||||
(See MuonAdamW for algorithmic details.)
|
||||
|
||||
AdamW Communication:
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
|
||||
Muon Communication:
|
||||
Parameters are grouped by shape, then stacked into single Tensors for efficient communication.
|
||||
We launch comms largest-first, then process smallest-first so large comms finish in time.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
adamw_groups: list[dict],
|
||||
muon_params,
|
||||
# AdamW hyperparams
|
||||
adamw_lr: float = 1e-3,
|
||||
adamw_betas: tuple[float, float] = (0.9, 0.999),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_weight_decay: float = 0.01,
|
||||
# Muon hyperparams
|
||||
muon_lr: float = 0.02,
|
||||
muon_momentum: float = 0.95,
|
||||
muon_ns_steps: int = 5,
|
||||
muon_beta2: float = 0.95,
|
||||
muon_weight_decay: float = 0.0,
|
||||
):
|
||||
assert all(p.ndim == 2 for p in muon_params), "Muon expects 2D parameters only"
|
||||
muon_params = list(muon_params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
# Build unified param_groups for the base Optimizer class
|
||||
# Each group is tagged with 'kind' = 'adamw' or 'muon'
|
||||
param_groups = []
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# AdamW groups: each input group becomes one param_group
|
||||
for group in adamw_groups:
|
||||
# Validate
|
||||
if rank == 0:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
# Add to param_groups
|
||||
params = list(group['params'])
|
||||
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
|
||||
param_groups.append(dict(
|
||||
params=params, lr=lr, kind='adamw',
|
||||
betas=adamw_betas, eps=adamw_eps, weight_decay=adamw_weight_decay,
|
||||
))
|
||||
|
||||
# Muon groups: group by shape for stacking, with all Muon hyperparams in the group
|
||||
muon_shapes = sorted({p.shape for p in muon_params}) # sort for deterministic ordering across ranks
|
||||
for shape in muon_shapes:
|
||||
group_params = [p for p in muon_params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(
|
||||
params=group_params, lr=muon_lr, kind='muon', chunk_size=chunk_size,
|
||||
momentum=muon_momentum, ns_steps=muon_ns_steps,
|
||||
beta2=muon_beta2, weight_decay=muon_weight_decay,
|
||||
))
|
||||
|
||||
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# Muon tensors
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# First pass: launch all async communications
|
||||
adamw_infos: dict[Tensor, dict] = {} # param -> {reduce_future, grad_slice, is_small}
|
||||
muon_infos: dict[int, dict] = {} # group_idx -> {reduce_future, grad_chunk, stacked_grads}
|
||||
|
||||
for group_idx, group in enumerate(self.param_groups):
|
||||
if group['kind'] == 'adamw':
|
||||
params: list[Tensor] = group['params']
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
reduce_future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
adamw_infos[p] = dict(reduce_future=reduce_future, grad_slice=grad, is_small=True)
|
||||
# Large param: reduce_scatter
|
||||
else:
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
adamw_infos[p] = dict(reduce_future=reduce_future, grad_slice=grad_slice, is_small=False)
|
||||
|
||||
else: # muon
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
# Async reduce_scatter on the stacked tensor to get the grad_chunk for this rank
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
muon_infos[group_idx] = dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
)
|
||||
|
||||
# Second pass: wait for reduce, compute updates, kick off all_gather
|
||||
gather_infos: list[dict] = [] # unified list for both AdamW and Muon gathers
|
||||
|
||||
for group_idx, group in enumerate(self.param_groups):
|
||||
if group['kind'] == 'adamw':
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
info = adamw_infos[p]
|
||||
info['reduce_future'].wait()
|
||||
g_slice = info['grad_slice']
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if info['is_small']:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._adamw_step_t.fill_(state['step'])
|
||||
self._adamw_lr_t.fill_(lr)
|
||||
self._adamw_beta1_t.fill_(beta1)
|
||||
self._adamw_beta2_t.fill_(beta2)
|
||||
self._adamw_eps_t.fill_(eps)
|
||||
self._adamw_wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
||||
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not info['is_small']:
|
||||
gather_future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
||||
gather_infos.append(dict(gather_future=gather_future, params=None))
|
||||
|
||||
else: # muon
|
||||
info = muon_infos[group_idx]
|
||||
info['reduce_future'].wait()
|
||||
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._muon_momentum_t.fill_(group["momentum"])
|
||||
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._muon_wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._muon_momentum_t,
|
||||
self._muon_lr_t,
|
||||
self._muon_wd_t,
|
||||
self._muon_beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
gather_infos.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params (Muon only)
|
||||
for info in gather_infos:
|
||||
info["gather_future"].wait()
|
||||
# Muon params need to be copied back from stacked buffer
|
||||
if info["params"] is not None:
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
605
nanochat/optims_diff.py
Normal file
605
nanochat/optims_diff.py
Normal file
|
|
@ -0,0 +1,605 @@
|
|||
|
||||
"""
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
"""
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed AdamW optimizer.
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
|
||||
"""
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
|
||||
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
if not params:
|
||||
continue
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed AdamW optimizer.
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
|
||||
"""
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
def __init__(
|
||||
self,
|
||||
param_groups,
|
||||
lr: float = 1e-3,
|
||||
betas: tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
weight_decay: float = 0.01
|
||||
lr: float = 0.02,
|
||||
momentum: float = 0.95,
|
||||
ns_steps: int = 5,
|
||||
beta2: float = 0.95,
|
||||
weight_decay: float = 0.0
|
||||
):
|
||||
):
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
param_groups = []
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
|
|
@ -206,9 +206,9 @@ print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_sc
|
|||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
# Initialize the Optimizer (Muon for matrix params, AdamW for embeddings and scalars)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
|
|
@ -216,11 +216,9 @@ optimizers = model.setup_optimizers(
|
|||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data # free up the memory
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -339,7 +337,7 @@ while True:
|
|||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
optimizer.state_dict(), # optimizer state
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
|
|
@ -373,18 +371,19 @@ while True:
|
|||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
|
||||
# Update all hyperparams via param_groups
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group["kind"] == "muon":
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
# Training loop
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
|
|
@ -209,10 +209,9 @@ optimizers = model.setup_optimizers(
|
|||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay easily later
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
def get_lr_multiplier(it):
|
||||
|
|
@ -305,11 +304,9 @@ for step in range(num_steps):
|
|||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for opt in optimizers: # then step the optimizers
|
||||
opt.step()
|
||||
for group in optimizer.param_groups: # set the learning rate before stepping
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
|
|||
|
|
@ -150,17 +150,16 @@ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_bat
|
|||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
|
@ -230,13 +229,11 @@ for step in range(num_iterations):
|
|||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
||||
# step the optimizers
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
# step the optimizer
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# logging
|
||||
|
|
|
|||
|
|
@ -93,14 +93,12 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
|||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Initialize the Optimizer (Muon for matrix params, AdamW for embeddings and scalars)
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
|
|
@ -274,7 +272,7 @@ while True:
|
|||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
optimizer.state_dict(), # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
|
|
@ -306,16 +304,13 @@ while True:
|
|||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group["kind"] == "muon":
|
||||
group["momentum"] = get_muon_momentum(step)
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
|
|
|
|||
153
speedrun.sh
Normal file
153
speedrun.sh
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy",
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
|
||||
# 1) Example launch (simplest):
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python venv setup with uv
|
||||
|
||||
# install uv (if not already installed)
|
||||
if ! command -v uv &> /dev/null; then
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
fi
|
||||
# Add uv to PATH (it installs to ~/.local/bin)
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync --extra gpu
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
# Ensure we're using the venv Python and torchrun
|
||||
PYTHON=".venv/bin/python"
|
||||
TORCHRUN=".venv/bin/torchrun"
|
||||
|
||||
# Install flash_attn if the wheel exists (for A100 compatibility)
|
||||
if [ -f "flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl" ]; then
|
||||
uv pip install flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wandb setup
|
||||
# If you wish to use wandb for logging (it's nice!, recommended).
|
||||
# You can authenticate in one of two ways:
|
||||
# 1) Set WANDB_API_KEY environment variable before running:
|
||||
# `export WANDB_API_KEY=your_api_key_here`
|
||||
# `bash runs/speedrun.sh`
|
||||
# 2) Or run `wandb login` after the venv is set up (the venv will be active)
|
||||
# The script will automatically use wandb if WANDB_API_KEY is set or if you've logged in.
|
||||
# Set the WANDB_RUN environment variable when running this script, e.g.:
|
||||
# `WANDB_RUN=d26 bash runs/speedrun.sh`
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
# by default use "dummy" : it's handled as a special case, skips logging to wandb
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# If WANDB_API_KEY is set, export it so wandb can use it automatically
|
||||
if [ -n "$WANDB_API_KEY" ]; then
|
||||
export WANDB_API_KEY
|
||||
echo "Using WANDB_API_KEY from environment for wandb authentication"
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
# directory in the base dir. This command clears it out and writes a header section
|
||||
# with a bunch of system info and a timestamp that marks the start of the run.
|
||||
$PYTHON -m nanochat.report reset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
$PYTHON -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 370 is the right number here
|
||||
$PYTHON -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
$PYTHON -m scripts.tok_train --max-chars=20000000 --vocab-size=50304
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
$PYTHON -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
# Per-device batch size (reduce this if you hit OOM - gradient accumulation will automatically increase) Default is 32.
|
||||
# To match modded-nanogpt initial batch: 8 seqs * 2048 seq_len * 8 GPUs = 131,072 tokens
|
||||
DEVICE_BATCH_SIZE=8
|
||||
TOTAL_BATCH_SIZE=131072
|
||||
|
||||
# pretrain the d20 model
|
||||
#$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-mine -- --depth=12 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
|
||||
$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-main-profiled -- --depth=11 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --total-batch-size=$TOTAL_BATCH_SIZE --run=$WANDB_RUN
|
||||
# # evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# # evaluate the model on CORE tasks
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# # download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
||||
# # see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
||||
# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# # run midtraining and eval the model
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# # train sft and re-eval right away (should see a small bump)
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# # chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# # python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# # even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# # python -m scripts.chat_web
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# # (optional)
|
||||
|
||||
# # run reinforcement learning
|
||||
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# # eval the RL model only on GSM8K
|
||||
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Generate the full report by putting together all the sections
|
||||
# # report.md is the output and will be copied to current directory for convenience
|
||||
# $PYTHON -m nanochat.report generate
|
||||
Loading…
Reference in New Issue
Block a user