L3 (Large Lookup Layers) generalizes token embeddings by placing per-token lookup tables inside the decoder stack. Unlike MoE, routing is static (determined by token ID), which eliminates router training and load-balancing losses.

Implementation:
- nanochat/l3.py: LZW allocation algorithm and L3Layer module with a vectorized gather+pad+mask forward pass and tied/untied KV support
- GPT integration: L3 layers sit between decoder blocks and are applied residually (x = x + l3_layer(x, token_ids))
- CLI: --l3-after-layers, --l3-n-emb, --l3-d-up, --l3-k-max flags, with LZW precomputation from a sample of the training data
- 17 tests covering allocation, the layer, and GPT integration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
"""
|
|
Train model. From root directory of the project, run as:
|
|
|
|
python -m scripts.base_train
|
|
|
|
or distributed as:
|
|
|
|
torchrun --nproc_per_node=8 -m scripts.base_train
|
|
|
|
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
|
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
|
"""
|
|
|
|

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import gc
import json
import time
import math
import argparse
from dataclasses import asdict
from contextlib import nullcontext, contextmanager

import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_core

print_banner()

# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# L3 (Large Lookup Layers)
parser.add_argument("--l3-after-layers", type=str, default="", help="comma-separated layer indices for L3 (empty = disabled)")
parser.add_argument("--l3-n-emb", type=int, default=0, help="total L3 embeddings (0 = auto-derive from model size)")
parser.add_argument("--l3-d-up", type=int, default=0, help="L3 up-projection dim (0 = 4*n_embd)")
parser.add_argument("--l3-k-max", type=int, default=512, help="max embeddings per token for L3")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging

# -----------------------------------------------------------------------------
# Compute init and wandb logging

device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS

# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)

# Flash Attention status
if HAS_FA3:
    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
    print0("!" * 80)
    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
    print0("WARNING: Training will be less efficient without FA3")
    if args.window_pattern != "L":
        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
    print0("!" * 80)

# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# -----------------------------------------------------------------------------
# Initialize the Model

def build_model_meta(depth, l3_after_layers="", l3_n_emb=0):
    """Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
    # Model dim is nudged up to nearest multiple of head_dim for clean division
    # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
    base_dim = depth * args.aspect_ratio
    model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
    num_heads = model_dim // args.head_dim
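    # e.g. with the default aspect_ratio=64 and head_dim=128: depth=21 gives base_dim=1344,
    # which is nudged up to model_dim=1408 (the next multiple of 128), so num_heads=11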
    config = GPTConfig(
        sequence_len=args.max_seq_len, vocab_size=vocab_size,
        n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
        window_pattern=args.window_pattern,
        l3_after_layers=l3_after_layers,
        l3_n_emb=l3_n_emb,
        l3_d_up=args.l3_d_up,
        l3_k_max=args.l3_k_max,
    )
    with torch.device("meta"):
        model_meta = GPT(config)
    return model_meta

# L3 precomputation: compute LZW allocation from training data sample
l3_n_emb = args.l3_n_emb
l3_bounds = None
if args.l3_after_layers:
    from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds
    # Auto-derive n_emb if not specified: scale proportional to model size
    if l3_n_emb == 0:
        # Build a temporary model to get param count for auto-derivation
        tmp_model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=vocab_size)
        tmp_params = sum(p.numel() for p in tmp_model.parameters())
        l3_n_emb = max(vocab_size, tmp_params // 1000)
        del tmp_model
        print0(f"Auto-derived L3 n_emb: {l3_n_emb:,}")
    # Read a sample of training data for LZW allocation
    sample_loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, 1, args.max_seq_len, split="train", device=device)
    sample_sequences = []
    for _ in range(100): # 100 batches should be enough
        x_sample, _ = next(sample_loader)
        sample_sequences.append(x_sample[0].tolist())
    del sample_loader
    l3_alloc = compute_lzw_allocation(sample_sequences, vocab_size, l3_n_emb, args.l3_k_max)
    l3_bounds = allocation_to_bounds(l3_alloc).to(device)
    print0(f"L3 allocation: {l3_n_emb:,} total embeddings, k_max={args.l3_k_max}, avg={l3_n_emb/vocab_size:.1f}/token")

# Build the model, move to device, init the weights
model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=l3_n_emb) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config)
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized

# Set L3 bounds after model creation
if l3_bounds is not None:
    for l3_layer in model.l3_layers.values():
        l3_layer.set_bounds(l3_bounds)
    print0(f"L3 bounds set for {len(model.l3_layers)} layer(s)")

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = args.resume_from_step != -1
if resuming:
    print0(f"Resuming optimization from step {args.resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # free up this memory after the copy

# -----------------------------------------------------------------------------
# FP8 training initialization and management (this has to be done before torch.compile)

# Convert Linear layers to Float8Linear if --fp8 is set
if args.fp8:
    if device_type != "cuda":
        print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
    else:
        # our custom fp8 is simpler than torchao, written for exact API compatibility
        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        import torch.nn as nn

        # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough to be worth converting
        def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
            if not isinstance(mod, nn.Linear):
                return False
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
            if min(mod.in_features, mod.out_features) < 128:
                return False
            return True

        fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
        num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
        num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
        num_skipped = num_linear - num_fp8
        print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")

# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.

    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
    we swap out Float8Linear modules entirely and restore them after.
    """
    import torch.nn as nn

    # Find all Float8Linear modules and their locations
    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
    for name, module in model.named_modules():
        if 'Float8' in type(module).__name__:
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent = model
                attr_name = name
            fp8_locations.append((parent, attr_name, module))

    if not fp8_locations:
        yield # No FP8 modules, nothing to do
        return

    # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
    for parent, attr_name, fp8_module in fp8_locations:
        linear = nn.Linear(
            fp8_module.in_features,
            fp8_module.out_features,
            bias=fp8_module.bias is not None,
            device=fp8_module.weight.device,
            dtype=fp8_module.weight.dtype,
        )
        linear.weight = fp8_module.weight # share, don't copy
        if fp8_module.bias is not None:
            linear.bias = fp8_module.bias
        setattr(parent, attr_name, linear)

    try:
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)

# -----------------------------------------------------------------------------
# Compile the model

orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the inputs may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe

# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

# Get the parameter counts of our model
param_counts = model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
    print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# 1) Use scaling laws to determine the optimal training horizon in tokens
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
def get_scaling_params(m):
    # As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
    params_counts = m.num_scaling_params()
    scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
    return scaling_params
num_scaling_params = get_scaling_params(model)
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train

# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)

# 2) Now that we have the token horizon, we can calculate the optimal batch size
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
total_batch_size = args.total_batch_size # user-provided override is possible
if total_batch_size == -1:
    batch_size_ratio = target_tokens / D_REF
    predicted_batch_size = B_REF * batch_size_ratio ** 0.383
    total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")

# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
batch_lr_scale = 1.0
batch_ratio = total_batch_size / B_REF # B/B_ref
if batch_ratio != 1.0:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
    # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
    batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")

# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
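# Sketch of that math: holding T_epoch = B/(η·λ·D) fixed gives λ = B/(η·T_epoch·D), so
# λ/λ_ref = (B/B_ref)·(η_ref/η)·(D_ref/D), and substituting η = η_ref·√(B/B_ref) collapses
# this to λ/λ_ref = √(B/B_ref)·(D_ref/D), the formula implemented below.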
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
if weight_decay_scaled != args.weight_decay:
    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")

# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(
    # AdamW hyperparameters
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    scalar_lr=args.scalar_lr * batch_lr_scale,
    adam_betas=(args.adam_beta1, args.adam_beta2),
    # Muon hyperparameters
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=weight_decay_scaled,
)

if resuming:
    optimizer.load_state_dict(optimizer_data)
    del optimizer_data

# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Calculate the number of iterations we will train for and set up the various schedulers

# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
    # Override num_iterations to a specific value if given
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # Calculate the number of iterations from the target param data ratio (the most common use case)
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
    warmup_iters = round(args.warmup_ratio * num_iterations)
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * args.final_lr_frac
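# With the defaults (warmup_ratio=0.0, warmdown_ratio=0.5, final_lr_frac=0.0) this yields
# lrm = 1.0 for the first half of training followed by a linear decay to 0.0 at the end.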

# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
def get_weight_decay(it):
    return weight_decay_scaled * (1 - it / num_iterations)

# -----------------------------------------------------------------------------
# Training loop

# Loop state (variables updated by the training loop)
if not resuming:
    step = 0
    val_bpb = None # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]

# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
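# e.g. with the defaults on 8 GPUs: 32 * 2048 = 65,536 tokens per rank per micro-step,
# 524,288 tokens across all ranks, so a total batch size of 524,288 gives grad_accum_steps = 1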
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

# Go!
while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with disable_fp8(model), autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    # disable FP8 for evaluation to use BF16 for more consistent/accurate results
    results = {}
    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
        model.eval()
        with disable_fp8(orig_model), autocast_ctx:
            results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with disable_fp8(orig_model), autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(), # model parameters
            optimizer.state_dict(), # optimizer state
            { # metadata saved as json
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": args.device_batch_size,
                "max_seq_len": args.max_seq_len,
                "total_batch_size": total_batch_size,
                "dataloader_state_dict": dataloader_state_dict,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "total_training_time": total_training_time,
                },
            },
            rank=ddp_rank,
        )

    # termination conditions (TODO: possibly also add loss explosions etc.)
    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
    # step the optimizer
    lrm = get_lr_multiplier(step)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)
    train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # logging (CPU action only)
    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
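    # (the same bias correction as in Adam: the EMA starts at 0, so early readings are scaled back up)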
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    # Calculate ETA based on average time per step (excluding first 10 steps)
    steps_done = step - 10
    if steps_done > 0:
        avg_time_per_step = total_training_time / steps_done
        remaining_steps = num_iterations - step
        eta_seconds = remaining_steps * avg_time_per_step
        eta_str = f" | eta: {eta_seconds/60:.1f}m"
    else:
        eta_str = ""
    epoch = dataloader_state_dict["epoch"]
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
    if step % 100 == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": epoch,
        }
        wandb_run.log(log_data)

    # state update
    first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
    step += 1

    # The garbage collector is sadly a little bit overactive and for some poorly understood reason,
    # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
    # So we manually manage and help it out here
    if first_step_of_run:
        gc.collect() # manually collect a lot of garbage from setup
        gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
        gc.disable() # nuclear intervention here: disable GC entirely except:
    elif step % 5000 == 0: # every 5000 steps...
        gc.collect() # manually collect, just to be safe for very, very long runs

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
    print0(f"Minimum validation bpb: {min_val_bpb:.6f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": args.warmup_ratio,
        "warmdown_ratio": args.warmdown_ratio,
        "final_lr_frac": args.final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()