# nanochat/scripts/base_train.py

"""
Train model. Run as:
python base_train.py
or distributed as:
torchrun --nproc_per_node=8 base_train.py
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_iters=10 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
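# expandable_segments lets the CUDA caching allocator grow existing memory segments instead of always
# carving out new fixed-size ones, which reduces fragmentation-related OOMs when allocation sizes vary
# (e.g. during the CORE-metric eval below, whose input shapes change).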
import time
import math
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
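# With suppress_errors=True, TorchDynamo falls back to eager execution when compiling a region fails,
# instead of raising, so torch.compile problems degrade to slower-but-correct runs.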
import wandb
# Import from nanoMoE model (keeping train.py's original model)
import sys
from nanochat.gpt import GPTConfig, GPT
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.engine import Engine
from nanochat.dataloader import tokenizing_distributed_data_loader_with_state, tokenizing_distributed_data_loader
from nanochat.loss_eval import evaluate_bpb
from scripts.base_eval import evaluate_model
print_banner()
# Allow env overrides for common hyperparameter knobs (LR, depth, MoE settings) used in cluster runs.
def _get_env_float(name, default):
    val = os.getenv(name)
    if val is None or val == "":
        return default
    try:
        return float(val)
    except ValueError as exc:
        raise ValueError(f"Invalid {name} env value: {val}") from exc

def _get_env_int(name, default):
    val = os.getenv(name)
    if val is None or val == "":
        return default
    try:
        return int(val)
    except ValueError as exc:
        raise ValueError(f"Invalid {name} env value: {val}") from exc
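# Example usage (hypothetical values), overriding knobs without touching the script:
#   LEARNING_RATE=3e-4 MIN_LR=3e-5 DEPTH=8 torchrun --nproc_per_node=8 -m scripts.base_train
# Unset or empty env vars fall back to the defaults below.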
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = _get_env_int("DEPTH", 6) # the depth of the Transformer model to train (matches nanoMoE n_layer=6), rest of the kwargs are derived
depth = _get_env_int("N_LAYER", depth)
max_seq_len = 1024 # max context length (matches nanoMoE block_size=1024)
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ (matches nanoMoE)
bias = False # do we use bias inside LayerNorm and Linear layers? (matches nanoMoE)
# MoE settings (matching nanoMoE config/train_nano_moe.py)
n_exp = _get_env_int("N_EXP", 8) # number of experts (matches train_nano_moe.py)
top_k = _get_env_int("TOP_K", 2) # number of active experts (matches train_nano_moe.py)
use_aux_loss = True # apply auxiliary loss (from Switch Transformer) (matches train_nano_moe.py)
use_router_z_loss = True # apply router z loss (from ST-MoE) (matches train_nano_moe.py)
use_noisy_top_k = False # use noisy top-k routing (matches train_nano_moe.py)
aux_loss_weight = 0.01 # auxiliary loss weight (matches train_nano_moe.py)
router_z_loss_weight = 0.001 # router z loss weight (matches train_nano_moe.py)
train_capacity = 1.25 # training capacity factor (matches train_nano_moe.py)
eval_capacity = 2.0 # evaluation capacity factor (matches train_nano_moe.py)
min_capacity = 4 # minimum batch size per expert (default from ST-MoE)
stride = 2 # one in every stride layers uses MoE (matches train_nano_moe.py)
use_switch_tfm_init = True # use weight init scheme from Switch Transformer (matches train_nano_moe.py)
switch_tfm_init_scale = 1.0 # scale for switch transformer init (matches train_nano_moe.py)
router_use_full_prec = True # use float32 in router (matches train_nano_moe.py)
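# Rough capacity arithmetic with the defaults above, assuming the usual Switch-Transformer-style
# expert capacity = capacity_factor * tokens_in_batch / n_exp (see nanochat/gpt.py for the exact formula):
# a 12 x 1024 = 12,288-token micro-batch split evenly over 8 experts is 1,536 tokens per expert, so
# train_capacity=1.25 budgets ~1,920 slots/expert and eval_capacity=2.0 about ~3,072; overflow tokens
# are typically dropped by the MoE layer. Illustrative estimate only.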
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = 50000 # explicit number of steps (matches nanoMoE max_iters=50000, makes total tokens ~25B)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = -1 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 12 # per-device batch size (matches nanoMoE batch_size=12)
total_batch_size = 491520 # total desired batch size in #tokens (matches nanoMoE: 12 * 1024 * 40 = 491,520 for 8 GPUs)
embedding_lr = 0.0006 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.0006 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.1 # weight decay (matches nanoMoE weight_decay=1e-1)
matrix_lr = 0.0006 # learning rate for the matrix parameters (Muon)
learning_rate = _get_env_float("LEARNING_RATE", 6e-4) # learning rate for AdamW optimizer (matches nanoMoE: 6e-4)
betas = (0.9, 0.95) # betas for AdamW optimizer (matches nanoMoE: beta1=0.9, beta2=0.95)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
decay_lr = True # whether to decay the learning rate (matches train_nano_moe.py)
# Learning rate decay parameters (matching train.py and train_nano_moe.py)
warmup_iters = 2000 # how many steps to warm up for (matches train.py default)
lr_decay_iters = 50000 # learning rate decay iterations (matches train_nano_moe.py)
min_lr = _get_env_float("MIN_LR", 6e-5) # minimum learning rate (matches train.py default, which equals 6e-4 * 0.1)
final_lr_frac = 0.1 # final learning rate as fraction of initial learning rate (for compatibility)
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
# Evaluation
eval_every = 500000000 # every how many steps to evaluate the model for val bpb (set very large here to effectively disable periodic eval; nanoMoE's eval_interval is 500)
eval_iters = 200 # number of iterations to evaluate val loss on (matches nanoMoE eval_iters=200)
log_interval = 10 # every how many steps to log training metrics (matches nanoMoE log_interval=10)
core_metric_every = -1 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = -1 # examples per task in estimating the core metric
sample_every = 200000000 # every how many steps to sample from the model (set very large to effectively disable)
save_every = 10000 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# System
compile = True # use PyTorch 2.0 to compile the model to be faster (matches nanoMoE)
# Output
model_tag = os.getenv("MODEL_TAG", "") # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
if model_tag == "":
    model_tag = f"d{depth}_min_lr{min_lr}_max_lr{learning_rate}"
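# With the defaults above this produces model_tag = "d6_min_lr6e-05_max_lr0.0006", which is also
# used below as the checkpoint directory name under base_checkpoints/.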
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
# Set random seed (matching nanoMoE/train.py)
seed_offset = ddp_rank if ddp else 0 # each process gets a different seed in DDP mode
torch.manual_seed(1337 + seed_offset)
# Set tf32 precision (matching nanoMoE/train.py)
if device_type == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
# For nanoMoE, we use n_layer, n_head, n_embd directly
n_layer = depth
model_dim = 384 # matches train_nano_moe.py
num_heads = 6 # matches train_nano_moe.py
n_head = num_heads
n_embd = model_dim
num_kv_heads = num_heads
print0(f"num_layers: {n_layer}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Get base directory for data and checkpoints
base_dir = get_base_dir()
# Use vocab_size from tokenizer (already obtained above)
# This ensures the model vocab_size matches the tokenizer vocab_size
model_config_kwargs = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=max_seq_len,
    vocab_size=vocab_size, # Use vocab_size from tokenizer, not hardcoded
    dropout=dropout,
    bias=bias,
    # MoE parameters (matching train_nano_moe.py)
    n_exp=n_exp,
    top_k=top_k,
    use_aux_loss=use_aux_loss,
    use_router_z_loss=use_router_z_loss,
    use_noisy_top_k=use_noisy_top_k,
    aux_loss_weight=aux_loss_weight,
    router_z_loss_weight=router_z_loss_weight,
    train_capacity=train_capacity,
    eval_capacity=eval_capacity,
    min_capacity=min_capacity,
    stride=stride,
    use_switch_tfm_init=use_switch_tfm_init,
    switch_tfm_init_scale=switch_tfm_init_scale,
    router_use_full_prec=router_use_full_prec,
)
gptconf = GPTConfig(**model_config_kwargs)
model = GPT(gptconf)
model.to(device)
# If we are resuming, overwrite the model parameters with those of the checkpoint
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d6_min_lr6e-05_max_lr0.0006 (model_tag is always set above)
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = False
# if resuming:
#     print0(f"Resuming optimization from step {resume_from_step}")
#     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
#     model.load_state_dict(model_data, strict=True, assign=True)
#     del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving the raw state_dict and for inference/evaluation (because the input shapes may change there)
# Calculate FLOPs per token manually (based on PaLM paper Appendix B) before compilation
nparams_embedding = orig_model.transformer.wte.weight.numel()
num_params = sum(p.numel() for p in orig_model.parameters())
l, h, q, t = model_config_kwargs['n_layer'], model_config_kwargs['n_head'], model_config_kwargs['n_embd'] // model_config_kwargs['n_head'], model_config_kwargs['block_size']
num_flops_per_token = 6 * (num_params - nparams_embedding) + 12 * l * h * q * t
print0(f"Number of parameters: {num_params:,}")
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Initialize GradScaler (matching nanoMoE train.py - before optimizer)
# note: float16 data type will automatically use a GradScaler
dtype_actual = 'bfloat16' if device_type == 'cuda' and torch.cuda.is_bf16_supported() else 'float16'
scaler = torch.cuda.amp.GradScaler(enabled=(device_type == 'cuda' and dtype_actual == 'float16')) # only meaningful for fp16-on-CUDA; a no-op otherwise
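# Note: with bf16 available (any recent CUDA GPU) the scaler is constructed disabled, in which case
# scaler.scale() returns the loss unchanged, unscale_() is a no-op, and step()/update() reduce to a
# plain optimizer.step(), so the fp16 machinery in the training loop below is a pass-through.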
# Initialize the Optimizer (AdamW for all parameters) - BEFORE DDP wrapping (matching nanoMoE)
optimizer = model.configure_optimizers(weight_decay=weight_decay, learning_rate=learning_rate, betas=betas, device_type=device_type)
adamw_optimizer = optimizer
# Compile the model (matching nanoMoE)
if compile:
    if master_process:
        print0("compiling the model... (takes a ~minute)")
    model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# Wrap model into DDP container (matching nanoMoE train.py)
from torch.nn.parallel import DistributedDataParallel as DDP
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank] if device_type == "cuda" else None)
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# if resuming:
#     for opt, dat in zip(optimizer, optimizer_data):
#         if opt is not None and dat is not None:
#             opt.load_state_dict(dat)
#     del optimizer_data # free up the memory
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val (like nanochat-run)
dataloader_resume_state_dict = None if not resuming else meta_data.get("dataloader_state_dict")
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler (cosine decay with warmup) - matching nanoMoE/train.py exactly
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)
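# For reference, with the defaults (learning_rate=6e-4, min_lr=6e-5, warmup_iters=2000, lr_decay_iters=50000):
#   get_lr(0)      ~= 3.0e-7  (warmup starts near zero)
#   get_lr(2000)    = 6.0e-4  (peak, end of warmup)
#   get_lr(26000)   = 3.3e-4  (cosine midpoint: min_lr + 0.5 * (learning_rate - min_lr))
#   get_lr(50000+)  = 6.0e-5  (floor)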
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
if not resuming:
    step = 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
    val_bpb = None # Will be set during evaluation
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]
    val_bpb = None # Will be set during evaluation
# -----------------------------------------------------------------------------
# Training loop
while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * total_batch_size * step
    # determine and set the learning rate for this iteration (matching nanoMoE/train.py)
    lr = get_lr(step) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_iters # use eval_iters as number of evaluation steps
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()
    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()
    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(), # model parameters
            optimizer.state_dict(), # optimizer states
            { # metadata saved as json
                "step": step,
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "total_training_time": total_training_time,
                },
                "dataloader_state_dict": dataloader_state_dict, # for resuming data loading
            },
            rank=ddp_rank,
        )
    # termination conditions (TODO: possibly also add loss explosions etc.)
    if last_step:
        break
    # -------------------------------------------------------------------------
    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16 (matching nanoMoE train.py exactly)
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with autocast_ctx:
            _, loss = model(x, y) # nanoMoE model returns (logits, loss)
            loss = loss / grad_accum_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    grad_clip_enabled = grad_clip > 0.0
    grad_norm = None
    if grad_clip_enabled:
        scaler.unscale_(optimizer)
        # clip_grad_norm_ returns the gradient norm before clipping
        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    train_loss = loss.detach() # for logging (after scaling)
    # -------------------------------------------------------------------------
    # logging (base_train.py style - keeping all the detailed logging)
    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
    # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
    lossf = loss.item() * grad_accum_steps
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * lossf # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
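    # Bias note (illustrative): smooth_train_loss starts at 0, so at step 0 the raw EMA is only
    # (1 - 0.9) * lossf = 0.1 * lossf; dividing by (1 - 0.9**(step + 1)) = 0.1 recovers lossf exactly,
    # and the correction factor approaches 1 as training proceeds.
    # The MFU printed below divides achieved FLOPs/s by 989 TFLOP/s of bf16 peak per H100 times
    # ddp_world_size, so sustaining half of peak would print as mfu ~= 50.00.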
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled and grad_norm is not None else ""
    lr_str = f"lr: {lr:.2e} |" if decay_lr else ""
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} {lr_str}dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 100 == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        }
        if decay_lr:
            log_data["lr"] = lr
        if grad_clip_enabled:
            log_data["train/grad_norm"] = grad_norm
        wandb_run.log(log_data)
    # state update
    step += 1
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
        "DDP world size": ddp_world_size,
        "final_lr_frac": final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()