diff --git a/nanochat/optims.py b/nanochat/optims.py index 8bc66565..51fabf57 100644 --- a/nanochat/optims.py +++ b/nanochat/optims.py @@ -1,7 +1,3 @@ -""" -Distributed AdamW optimizer with a fused step function. -A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt. -""" import torch import torch.distributed as dist from torch import Tensor @@ -61,33 +57,6 @@ Some of the changes in nanochat implementation: - Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format) """ -""" -Muon - MomentUm Orthogonalized by Newton-schulz - -https://kellerjordan.github.io/posts/muon/ - -Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- -processing step, in which each 2D parameter's update is replaced with the nearest orthogonal -matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has -the advantage that it can be stably run in bfloat16 on the GPU. - -Some warnings: -- This optimizer should not be used for the embedding layer, the final fully connected layer, -or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). -- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - -Arguments: - lr: The learning rate used by the internal SGD. - momentum: The momentum used by the internal SGD. - ns_steps: The number of Newton-Schulz iteration steps to use. - beta2: The decay rate for the second moment (variance) estimate. Set to None to disable. - weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. 
-""" - -import torch -from torch import Tensor -import torch.distributed as dist - # Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2) # From https://arxiv.org/pdf/2505.16932 polar_express_coeffs = [ @@ -157,27 +126,43 @@ def muon_step_fused( class MuonAdamW(torch.optim.Optimizer): """ Combined optimizer: Muon for 2D matrix params, AdamW for others. - Non-distributed version for single-GPU training. - - Args: + + AdamW - Distributed AdamW optimizer with a fused step function. + + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - The Muon optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. 
+ + AdamW Arguments: adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params muon_params: List of 2D tensors to optimize with Muon - adamw_lr: Default learning rate for AdamW (default: 1e-3) adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999)) - adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8) + adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8) adamw_weight_decay: Weight decay for AdamW (default: 0.01) - muon_lr: Learning rate for Muon (default: 0.02) - muon_momentum: Momentum for Muon (default: 0.95) - muon_ns_steps: Number of Newton-Schulz iterations (default: 5) - muon_beta2: Second moment decay for Muon variance reduction (default: 0.95) - muon_weight_decay: Cautious weight decay for Muon (default: 0.0) + + Muon Arguments: + muon_lr: The learning rate used by the internal SGD. + muon_momentum: The momentum used by the internal SGD. + muon_ns_steps: The number of Newton-Schulz iteration steps to use. + muon_beta2: The decay rate for the second moment (variance) estimate. Set to None to disable. + muon_weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree. 
""" def __init__( self, adamw_groups: list[dict], muon_params, # AdamW hyperparams - adamw_lr: float = 1e-3, + adamw_lr: float = 1e-3, # can be overridden per-group adamw_betas: tuple[float, float] = (0.9, 0.999), adamw_eps: float = 1e-8, adamw_weight_decay: float = 0.01, @@ -198,7 +183,7 @@ class MuonAdamW(torch.optim.Optimizer): for group in adamw_groups: assert isinstance(group, dict) and 'params' in group params = list(group['params']) - lr = group.get('lr', adamw_lr) + lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates for p in params: print(f"AdamW: 1 param of shape {p.shape}") param_groups.append(dict( @@ -217,7 +202,7 @@ class MuonAdamW(torch.optim.Optimizer): beta2=muon_beta2, weight_decay=muon_weight_decay, )) - defaults = dict(lr=adamw_lr) + defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr super().__init__(param_groups, defaults) # 0-D CPU tensors to avoid torch.compile recompilation when values change @@ -331,21 +316,16 @@ class MuonAdamW(torch.optim.Optimizer): class DistMuonAdamW(torch.optim.Optimizer): """ Combined distributed optimizer: Muon for 2D matrix params, AdamW for others. - - Communication optimization: starts communications for largest tensors first, - then processes smallest tensors first while large tensor comms complete. - This overlaps communication with computation for better efficiency. 
- - Args: - adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params - muon_params: List of 2D tensors to optimize with Muon - adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999)) - adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8) - muon_lr: Learning rate for Muon (default: 0.02) - muon_momentum: Momentum for Muon (default: 0.95) - muon_ns_steps: Number of Newton-Schulz iterations (default: 5) - muon_beta2: Second moment decay for Muon variance reduction (default: 0.95) - muon_weight_decay: Cautious weight decay for Muon (default: 0.0) + + (See MuonAdamW for algorithmic details.) + + AdamW Communication: + In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction. + A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt. + + Muon Communication: + Parameters are grouped by shape, then stacked into single Tensors for efficient communication. + We launch comms largest-first, then process smallest-first so large comms finish in time. 
""" def __init__( self, @@ -374,11 +354,11 @@ class DistMuonAdamW(torch.optim.Optimizer): rank = dist.get_rank() world_size = dist.get_world_size() - # Validate - if rank == 0: - - # AdamW groups: each input group becomes one param_group - for group in adamw_groups: + + # AdamW groups: each input group becomes one param_group + for group in adamw_groups: + # Validate + if rank == 0: assert isinstance(group, dict), "expecting param_groups to be a list of dicts" assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors" for p in group['params']: @@ -386,18 +366,16 @@ class DistMuonAdamW(torch.optim.Optimizer): print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}") if sliced: # large parameter tensors will be operated on in slices assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}" - - # AdamW groups: each input group becomes one param_group - for group in adamw_groups: + # Add to param_groups params = list(group['params']) - lr = group.get('lr', adamw_lr) + lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates param_groups.append(dict( params=params, lr=lr, kind='adamw', betas=adamw_betas, eps=adamw_eps, weight_decay=adamw_weight_decay, )) # Muon groups: group by shape for stacking, with all Muon hyperparams in the group - muon_shapes = sorted({p.shape for p in muon_params}) + muon_shapes = sorted({p.shape for p in muon_params}) # sort for deterministic ordering across ranks for shape in muon_shapes: group_params = [p for p in muon_params if p.shape == shape] device, dtype = group_params[0].device, group_params[0].dtype @@ -413,9 +391,9 @@ class DistMuonAdamW(torch.optim.Optimizer): beta2=muon_beta2, weight_decay=muon_weight_decay, )) - defaults = dict(lr=adamw_lr) + defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr super().__init__(param_groups, defaults) - + # 0-D CPU tensors to avoid torch.compile 
recompilation when values change # AdamW tensors self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") @@ -429,25 +407,6 @@ class DistMuonAdamW(torch.optim.Optimizer): self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - - # Precompute group order sorted by communication size for optimal overlap - # We launch comms largest-first, then process smallest-first so large comms finish in time - group_comm_sizes = [] - for group_idx, group in enumerate(self.param_groups): - if group['kind'] == 'adamw': - # AdamW group comm size = size of first param (all params in group have same shape) - comm_size = group['params'][0].numel() - else: # muon - # Muon group comm size = padded stacked tensor size - chunk_size = group['chunk_size'] - comm_size = chunk_size * world_size * group['params'][0].numel() - group_comm_sizes.append((comm_size, group_idx)) - - # Sort: largest first for comms, we'll reverse for compute - group_comm_sizes.sort(key=lambda x: x[0], reverse=True) - self._group_order = [group_idx for _, group_idx in group_comm_sizes] - if rank == 0: - print(f"DistMuonAdamW: {len(self._group_order)} groups, comm order by size (largest first)") @torch.no_grad() def step(self): @@ -457,13 +416,11 @@ class DistMuonAdamW(torch.optim.Optimizer): # Ensure all grads exist assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" - # First pass: launch all async communications (largest groups first) + # First pass: launch all async communications adamw_infos: dict[Tensor, dict] = {} # param -> {reduce_future, grad_slice, is_small} muon_infos: dict[int, dict] = {} # group_idx -> {reduce_future, grad_chunk, stacked_grads} - for group_idx in self._group_order: - group = self.param_groups[group_idx] - + for group_idx, group in 
enumerate(self.param_groups): if group['kind'] == 'adamw': params: list[Tensor] = group['params'] for p in params: @@ -474,7 +431,6 @@ class DistMuonAdamW(torch.optim.Optimizer): adamw_infos[p] = dict(reduce_future=reduce_future, grad_slice=grad, is_small=True) # Large param: reduce_scatter else: - rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__ grad_slice = torch.empty_like(grad[:rank_size]) reduce_future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future() @@ -508,13 +464,10 @@ class DistMuonAdamW(torch.optim.Optimizer): stacked_grads=stacked_grads, # reuse for all_gather output ) - # Second pass: process groups (smallest first, so large comms finish in time) wait for reduce, compute batched updates, kick off all_gather - gather_futures: list[torch.Future] = [] - muon_gather_infos: list[dict] = [] - - for group_idx in reversed(self._group_order): - group = self.param_groups[group_idx] + # Second pass: wait for reduce, compute updates, kick off all_gather + gather_infos: list[dict] = [] # unified list for both AdamW and Muon gathers + for group_idx, group in enumerate(self.param_groups): if group['kind'] == 'adamw': beta1, beta2 = group['betas'] eps = group['eps'] @@ -561,7 +514,8 @@ class DistMuonAdamW(torch.optim.Optimizer): # Only large params need all_gather if not info['is_small']: - gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + gather_future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future() + gather_infos.append(dict(gather_future=gather_future, params=None)) else: # muon info = muon_infos[group_idx] @@ -642,19 +596,18 @@ class DistMuonAdamW(torch.optim.Optimizer): stacked_params, updated_params, async_op=True ).get_future() - muon_gather_infos.append(dict( + gather_infos.append(dict( gather_future=gather_future, stacked_params=stacked_params, params=params, )) - # Final pass: wait for all_gather and 
copy back to params - if gather_futures: - torch.futures.collect_all(gather_futures).wait() - - for info in muon_gather_infos: + # Final pass: wait for all_gather and copy back to params (Muon only) + for info in gather_infos: info["gather_future"].wait() - stacked_params = info["stacked_params"] - params = info["params"] - # Batched copy back (single kernel instead of N individual copies) - torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0))) \ No newline at end of file + # Muon params need to be copied back from stacked buffer + if info["params"] is not None: + stacked_params = info["stacked_params"] + params = info["params"] + # Batched copy back (single kernel instead of N individual copies) + torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0))) \ No newline at end of file diff --git a/speedrun.sh b/speedrun.sh new file mode 100644 index 00000000..9b3dff52 --- /dev/null +++ b/speedrun.sh @@ -0,0 +1,153 @@ +#!/bin/bash + +# This script is the "Best ChatGPT clone that $100 can buy", +# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour. + +# 1) Example launch (simplest): +# bash speedrun.sh +# 2) Example launch in a screen session (because the run takes ~4 hours): +# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh +# 3) Example launch with wandb logging, but see below for setting up wandb first: +# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh + +# Default intermediate artifacts directory is in ~/.cache/nanochat +export OMP_NUM_THREADS=1 +export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" +mkdir -p $NANOCHAT_BASE_DIR + +# ----------------------------------------------------------------------------- +# Python venv setup with uv + +# install uv (if not already installed) +if ! 
command -v uv &> /dev/null; then + curl -LsSf https://astral.sh/uv/install.sh | sh +fi +# Add uv to PATH (it installs to ~/.local/bin) +export PATH="$HOME/.local/bin:$PATH" +# create a .venv local virtual environment (if it doesn't exist) +[ -d ".venv" ] || uv venv +# install the repo dependencies +uv sync --extra gpu +# activate venv so that `python` uses the project's venv instead of system python +source .venv/bin/activate +# Ensure we're using the venv Python and torchrun +PYTHON=".venv/bin/python" +TORCHRUN=".venv/bin/torchrun" + +# Install flash_attn if the wheel exists (for A100 compatibility) +if [ -f "flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl" ]; then + uv pip install flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl +fi + +# ----------------------------------------------------------------------------- +# wandb setup +# Using wandb for logging is optional (it's nice, and recommended). +# You can authenticate in one of two ways: +# 1) Set WANDB_API_KEY environment variable before running: +# `export WANDB_API_KEY=your_api_key_here` +# `bash speedrun.sh` +# 2) Or run `wandb login` after the venv is set up (the venv will be active) +# The script will automatically use wandb if WANDB_API_KEY is set or if you've logged in. +# Set the WANDB_RUN environment variable when running this script, e.g.: +# `WANDB_RUN=d26 bash speedrun.sh` +if [ -z "$WANDB_RUN" ]; then + # by default use "dummy" : it's handled as a special case, skips logging to wandb + WANDB_RUN=dummy +fi + +# If WANDB_API_KEY is set, export it so wandb can use it automatically +if [ -n "$WANDB_API_KEY" ]; then + export WANDB_API_KEY + echo "Using WANDB_API_KEY from environment for wandb authentication" +fi + +# ----------------------------------------------------------------------------- +# During the course of the run, we will be writing markdown reports to the report/ +# directory in the base dir. 
This command clears it out and writes a header section +# with a bunch of system info and a timestamp that marks the start of the run. +$PYTHON -m nanochat.report reset + +# ----------------------------------------------------------------------------- +# Tokenizer + +# Download the first ~2B characters of the pretraining dataset +# look at dev/repackage_data_reference.py for details on how this data was prepared +# each data shard is ~250M chars +# so we download 2e9 / 250e6 = 8 data shards at this point +# each shard is ~100MB of text (compressed), so this is about 800MB of data on disk +$PYTHON -m nanochat.dataset -n 8 +# Immediately also kick off downloading more shards in the background while tokenizer trains +# See comment below for why 370 is the right number here +$PYTHON -m nanochat.dataset -n 370 & +DATASET_DOWNLOAD_PID=$! +# train the tokenizer with vocab size 50304 on at most 20M characters of data +$PYTHON -m scripts.tok_train --max-chars=20000000 --vocab-size=50304 +# evaluate the tokenizer (report compression ratio etc.) +$PYTHON -m scripts.tok_eval + +# ----------------------------------------------------------------------------- +# Base model (pretraining) + +# Shard sizing below assumes the d20 model (561M parameters); the train command further down uses --depth=11. +# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. +# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. +# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining. +# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping +# so 240 / (1 - 0.35) ~= 370 shards are needed. +# At ~100MB/shard, this downloads ~37GB of data to disk. +# (The total number of shards available in the entire dataset is 1822.) +echo "Waiting for dataset download to complete..." +wait $DATASET_DOWNLOAD_PID + +# Number of processes/GPUs to use +NPROC_PER_NODE=8 +# Per-device batch size (reduce this if you hit OOM - gradient accumulation will automatically increase) Default is 32. 
+# To match modded-nanogpt initial batch: 8 seqs * 2048 seq_len * 8 GPUs = 131,072 tokens +DEVICE_BATCH_SIZE=8 +TOTAL_BATCH_SIZE=131072 + +# pretrain the base model (the active command below uses --depth=11) +#$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-mine -- --depth=12 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN +$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-main-profiled -- --depth=11 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --total-batch-size=$TOTAL_BATCH_SIZE --run=$WANDB_RUN +# # evaluate the model on a larger chunk of train/val data and draw some samples +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss +# # evaluate the model on CORE tasks +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval + +# # ----------------------------------------------------------------------------- +# # Midtraining (teach the model conversation special tokens, tool use, multiple choice) + +# # download 2.3MB of synthetic identity conversations to impart a personality to nanochat +# # see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it +# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl + +# # run midtraining and eval the model +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid + +# # ----------------------------------------------------------------------------- +# # Supervised Finetuning (domain adaptation to each sequence all by itself per row) + +# # train sft and re-eval right away (should see a small bump) +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- 
--run=$WANDB_RUN +# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft + +# # chat with the model over CLI! Leave out the -p to chat interactively +# # python -m scripts.chat_cli -p "Why is the sky blue?" + +# # even better, chat with your model over a pretty WebUI ChatGPT style +# # python -m scripts.chat_web + +# # ----------------------------------------------------------------------------- +# # Reinforcement Learning. Optional, and currently only on GSM8K +# # (optional) + +# # run reinforcement learning +# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN +# # eval the RL model only on GSM8K +# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K + +# # ----------------------------------------------------------------------------- +# # Generate the full report by putting together all the sections +# # report.md is the output and will be copied to current directory for convenience +# $PYTHON -m nanochat.report generate