Remove ordering

- Removed size-based ordering for comms and work.
- Moved the comments around.
- Consolidated the gather futures into one list.
This commit is contained in:
Chris McCormick 2026-01-26 12:14:19 -08:00
parent d1595fb2d1
commit 7aceb020dd
2 changed files with 218 additions and 112 deletions

View File

@ -1,7 +1,3 @@
"""
Distributed AdamW optimizer with a fused step function.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
"""
import torch
import torch.distributed as dist
from torch import Tensor
@ -61,33 +57,6 @@ Some of the changes in nanochat implementation:
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
@ -157,27 +126,43 @@ def muon_step_fused(
class MuonAdamW(torch.optim.Optimizer):
"""
Combined optimizer: Muon for 2D matrix params, AdamW for others.
Non-distributed version for single-GPU training.
Args:
AdamW - Distributed AdamW optimizer with a fused step function.
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
AdamW Arguments:
adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params
muon_params: List of 2D tensors to optimize with Muon
adamw_lr: Default learning rate for AdamW (default: 1e-3)
adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999))
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
adamw_weight_decay: Weight decay for AdamW (default: 0.01)
muon_lr: Learning rate for Muon (default: 0.02)
muon_momentum: Momentum for Muon (default: 0.95)
muon_ns_steps: Number of Newton-Schulz iterations (default: 5)
muon_beta2: Second moment decay for Muon variance reduction (default: 0.95)
muon_weight_decay: Cautious weight decay for Muon (default: 0.0)
Muon Arguments:
muon_lr: The learning rate used by the internal SGD.
muon_momentum: The momentum used by the internal SGD.
muon_ns_steps: The number of Newton-Schulz iteration steps to use.
muon_beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
muon_weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(
self,
adamw_groups: list[dict],
muon_params,
# AdamW hyperparams
adamw_lr: float = 1e-3,
adamw_lr: float = 1e-3, # can be overridden per-group
adamw_betas: tuple[float, float] = (0.9, 0.999),
adamw_eps: float = 1e-8,
adamw_weight_decay: float = 0.01,
@ -198,7 +183,7 @@ class MuonAdamW(torch.optim.Optimizer):
for group in adamw_groups:
assert isinstance(group, dict) and 'params' in group
params = list(group['params'])
lr = group.get('lr', adamw_lr)
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
for p in params:
print(f"AdamW: 1 param of shape {p.shape}")
param_groups.append(dict(
@ -217,7 +202,7 @@ class MuonAdamW(torch.optim.Optimizer):
beta2=muon_beta2, weight_decay=muon_weight_decay,
))
defaults = dict(lr=adamw_lr)
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
@ -331,21 +316,16 @@ class MuonAdamW(torch.optim.Optimizer):
class DistMuonAdamW(torch.optim.Optimizer):
"""
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
Communication optimization: starts communications for largest tensors first,
then processes smallest tensors first while large tensor comms complete.
This overlaps communication with computation for better efficiency.
Args:
adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params
muon_params: List of 2D tensors to optimize with Muon
adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999))
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
muon_lr: Learning rate for Muon (default: 0.02)
muon_momentum: Momentum for Muon (default: 0.95)
muon_ns_steps: Number of Newton-Schulz iterations (default: 5)
muon_beta2: Second moment decay for Muon variance reduction (default: 0.95)
muon_weight_decay: Cautious weight decay for Muon (default: 0.0)
(See MuonAdamW for algorithmic details.)
AdamW Communication:
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
Muon Communication:
Parameters are grouped by shape, then stacked into single Tensors for efficient communication.
We launch comms largest-first, then process smallest-first so large comms finish in time.
"""
def __init__(
self,
@ -374,11 +354,11 @@ class DistMuonAdamW(torch.optim.Optimizer):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Validate
if rank == 0:
# AdamW groups: each input group becomes one param_group
for group in adamw_groups:
# AdamW groups: each input group becomes one param_group
for group in adamw_groups:
# Validate
if rank == 0:
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
for p in group['params']:
@ -386,18 +366,16 @@ class DistMuonAdamW(torch.optim.Optimizer):
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
if sliced: # large parameter tensors will be operated on in slices
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
# AdamW groups: each input group becomes one param_group
for group in adamw_groups:
# Add to param_groups
params = list(group['params'])
lr = group.get('lr', adamw_lr)
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
param_groups.append(dict(
params=params, lr=lr, kind='adamw',
betas=adamw_betas, eps=adamw_eps, weight_decay=adamw_weight_decay,
))
# Muon groups: group by shape for stacking, with all Muon hyperparams in the group
muon_shapes = sorted({p.shape for p in muon_params})
muon_shapes = sorted({p.shape for p in muon_params}) # sort for deterministic ordering across ranks
for shape in muon_shapes:
group_params = [p for p in muon_params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
@ -413,9 +391,9 @@ class DistMuonAdamW(torch.optim.Optimizer):
beta2=muon_beta2, weight_decay=muon_weight_decay,
))
defaults = dict(lr=adamw_lr)
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
# AdamW tensors
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@ -429,25 +407,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Precompute group order sorted by communication size for optimal overlap
# We launch comms largest-first, then process smallest-first so large comms finish in time
group_comm_sizes = []
for group_idx, group in enumerate(self.param_groups):
if group['kind'] == 'adamw':
# AdamW group comm size = size of first param (all params in group have same shape)
comm_size = group['params'][0].numel()
else: # muon
# Muon group comm size = padded stacked tensor size
chunk_size = group['chunk_size']
comm_size = chunk_size * world_size * group['params'][0].numel()
group_comm_sizes.append((comm_size, group_idx))
# Sort: largest first for comms, we'll reverse for compute
group_comm_sizes.sort(key=lambda x: x[0], reverse=True)
self._group_order = [group_idx for _, group_idx in group_comm_sizes]
if rank == 0:
print(f"DistMuonAdamW: {len(self._group_order)} groups, comm order by size (largest first)")
@torch.no_grad()
def step(self):
@ -457,13 +416,11 @@ class DistMuonAdamW(torch.optim.Optimizer):
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# First pass: launch all async communications (largest groups first)
# First pass: launch all async communications
adamw_infos: dict[Tensor, dict] = {} # param -> {reduce_future, grad_slice, is_small}
muon_infos: dict[int, dict] = {} # group_idx -> {reduce_future, grad_chunk, stacked_grads}
for group_idx in self._group_order:
group = self.param_groups[group_idx]
for group_idx, group in enumerate(self.param_groups):
if group['kind'] == 'adamw':
params: list[Tensor] = group['params']
for p in params:
@ -474,7 +431,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
adamw_infos[p] = dict(reduce_future=reduce_future, grad_slice=grad, is_small=True)
# Large param: reduce_scatter
else:
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
grad_slice = torch.empty_like(grad[:rank_size])
reduce_future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
@ -508,13 +464,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
stacked_grads=stacked_grads, # reuse for all_gather output
)
# Second pass: process groups (smallest first, so large comms finish in time) wait for reduce, compute batched updates, kick off all_gather
gather_futures: list[torch.Future] = []
muon_gather_infos: list[dict] = []
for group_idx in reversed(self._group_order):
group = self.param_groups[group_idx]
# Second pass: wait for reduce, compute updates, kick off all_gather
gather_infos: list[dict] = [] # unified list for both AdamW and Muon gathers
for group_idx, group in enumerate(self.param_groups):
if group['kind'] == 'adamw':
beta1, beta2 = group['betas']
eps = group['eps']
@ -561,7 +514,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
# Only large params need all_gather
if not info['is_small']:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
gather_future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
gather_infos.append(dict(gather_future=gather_future, params=None))
else: # muon
info = muon_infos[group_idx]
@ -642,19 +596,18 @@ class DistMuonAdamW(torch.optim.Optimizer):
stacked_params, updated_params, async_op=True
).get_future()
muon_gather_infos.append(dict(
gather_infos.append(dict(
gather_future=gather_future,
stacked_params=stacked_params,
params=params,
))
# Final pass: wait for all_gather and copy back to params
if gather_futures:
torch.futures.collect_all(gather_futures).wait()
for info in muon_gather_infos:
# Final pass: wait for all_gather and copy back to params (Muon only)
for info in gather_infos:
info["gather_future"].wait()
stacked_params = info["stacked_params"]
params = info["params"]
# Batched copy back (single kernel instead of N individual copies)
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
# Muon params need to be copied back from stacked buffer
if info["params"] is not None:
stacked_params = info["stacked_params"]
params = info["params"]
# Batched copy back (single kernel instead of N individual copies)
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))

153
speedrun.sh Normal file
View File

@ -0,0 +1,153 @@
#!/bin/bash
# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
# -----------------------------------------------------------------------------
# Python venv setup with uv
# install uv (if not already installed)
if ! command -v uv &> /dev/null; then
curl -LsSf https://astral.sh/uv/install.sh | sh
fi
# Add uv to PATH (it installs to ~/.local/bin)
export PATH="$HOME/.local/bin:$PATH"
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate
# Ensure we're using the venv Python and torchrun
PYTHON=".venv/bin/python"
TORCHRUN=".venv/bin/torchrun"
# Install flash_attn if the wheel exists (for A100 compatibility)
if [ -f "flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl" ]; then
uv pip install flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl
fi
# -----------------------------------------------------------------------------
# wandb setup
# If you wish to use wandb for logging (it's nice!, recommended).
# You can authenticate in one of two ways:
# 1) Set WANDB_API_KEY environment variable before running:
# `export WANDB_API_KEY=your_api_key_here`
# `bash runs/speedrun.sh`
# 2) Or run `wandb login` after the venv is set up (the venv will be active)
# The script will automatically use wandb if WANDB_API_KEY is set or if you've logged in.
# Set the WANDB_RUN environment variable when running this script, e.g.:
# `WANDB_RUN=d26 bash runs/speedrun.sh`
if [ -z "$WANDB_RUN" ]; then
# by default use "dummy" : it's handled as a special case, skips logging to wandb
WANDB_RUN=dummy
fi
# If WANDB_API_KEY is set, export it so wandb can use it automatically
if [ -n "$WANDB_API_KEY" ]; then
export WANDB_API_KEY
echo "Using WANDB_API_KEY from environment for wandb authentication"
fi
# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
$PYTHON -m nanochat.report reset
# -----------------------------------------------------------------------------
# Tokenizer
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
$PYTHON -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 370 is the right number here
$PYTHON -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
$PYTHON -m scripts.tok_train --max-chars=20000000 --vocab-size=50304
# evaluate the tokenizer (report compression ratio etc.)
$PYTHON -m scripts.tok_eval
# -----------------------------------------------------------------------------
# Base model (pretraining)
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
# so 240 / (1 - 0.35) = 370 shards are needed.
# At ~100MB/shard, this downloads ~37GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
NPROC_PER_NODE=8
# Per-device batch size (reduce this if you hit OOM - gradient accumulation will automatically increase) Default is 32.
# To match modded-nanogpt initial batch: 8 seqs * 2048 seq_len * 8 GPUs = 131,072 tokens
DEVICE_BATCH_SIZE=8
TOTAL_BATCH_SIZE=131072
# pretrain the d20 model
#$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-mine -- --depth=12 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-main-profiled -- --depth=11 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --total-batch-size=$TOTAL_BATCH_SIZE --run=$WANDB_RUN
# # evaluate the model on a larger chunk of train/val data and draw some samples
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# # evaluate the model on CORE tasks
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# # -----------------------------------------------------------------------------
# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# # download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# # see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# # run midtraining and eval the model
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# # -----------------------------------------------------------------------------
# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# # train sft and re-eval right away (should see a small bump)
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# # chat with the model over CLI! Leave out the -p to chat interactively
# # python -m scripts.chat_cli -p "Why is the sky blue?"
# # even better, chat with your model over a pretty WebUI ChatGPT style
# # python -m scripts.chat_web
# # -----------------------------------------------------------------------------
# # Reinforcement Learning. Optional, and currently only on GSM8K
# # (optional)
# # run reinforcement learning
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# # eval the RL model only on GSM8K
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# # -----------------------------------------------------------------------------
# # Generate the full report by putting together all the sections
# # report.md is the output and will be copied to current directory for convenience
# $PYTHON -m nanochat.report generate