nanochat/scripts/base_train.py
Commit 25d2573f47 by William Thurston, 2025-11-11 19:58 -08:00
Add MoE configuration and implementation in training scripts and model architecture
- Introduced parameters for Mixture of Experts (MoE) in `runmps.sh`, `base_train.py`, and `gpt.py`, allowing dynamic configuration of experts during training.
- Added new classes `MoEFeedForward` and `ExpertFFN` in `gpt.py` to implement MoE in the model architecture.
- Updated `configurator.py` to handle type conversions for the new MoE parameters.
- Extended logging in `base_train.py` to include MoE-related metrics and configuration during training.
- Added assertions and derived defaults for the MoE parameters to guarantee valid configurations.
- Implemented methods to estimate and log FLOPs for both the dense and MoE-active configurations during training.
- Adjusted gradient handling in `muon.py` to tolerate missing gradients for unused experts.

"""
Train model. Run as:
python base_train.py
or distributed as:
torchrun --nproc_per_node=8 base_train.py
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
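To try Mixture of Experts, pass the MoE flags defined in the settings below (values here are illustrative, not tuned):
python -m scripts.base_train --depth=12 --moe_num_experts=16 --moe_num_shared_experts=1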
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
kv_head_mult = 1 # number of query heads that share a single key/value head (1 disables GQA)
moe_num_experts = 0 # routed experts per MoE layer (0 disables MoE)
moe_num_shared_experts = -1 # -1 => derive (defaults to 1 shared expert)
moe_experts_per_token = -1 # -1 => derive using Ling-style sparsity (≈1/32 active)
moe_expert_ffn_mult = -1.0 # -1 => derive from granularity target (defaults to 12)
dense_layers_before_moe = -1 # -1 => derive (≈10% of layers, min 1) before switching to MoE
moe_granularity_target = 12.0 # Ling guidance: target granularity per layer (2*d_model/d_expert)
moe_activation_denominator = 32 # derive top-k as num_experts / denominator (~3% activation)
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
if use_dummy_wandb:
wandb_run = DummyWandb()
else:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
"name": run,
"config": user_config,
"reinit": True,
}
wandb_id = os.environ.get("WANDB_RUN_ID")
if wandb_id:
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
wandb_run = wandb.init(**wandb_kwargs)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
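# Worked example of the two derivations above, assuming the default depth=20:
# model_dim = 20 * 64 = 1280 and num_heads = (1280 + 127) // 128 = 10, i.e. head_dim = 128.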
assert kv_head_mult >= 1, "kv_head_mult must be >= 1"
assert num_heads % kv_head_mult == 0, f"num_heads ({num_heads}) must be divisible by kv_head_mult ({kv_head_mult})"
num_kv_heads = max(1, num_heads // kv_head_mult)
activation_denom = max(1, int(round(moe_activation_denominator)))
granularity_target = moe_granularity_target if moe_granularity_target > 0 else 12.0
auto_dense_layers = max(1, num_layers // 10) if moe_num_experts > 0 else num_layers
if dense_layers_before_moe < 0:
dense_layers_before_moe = auto_dense_layers
dense_layers_before_moe = max(0, min(dense_layers_before_moe, num_layers))
if moe_num_experts <= 0:
dense_layers_before_moe = num_layers
if moe_num_experts > 0:
derived_top_k = max(1, round(moe_num_experts / activation_denom))
moe_experts_per_token = moe_experts_per_token if moe_experts_per_token > 0 else derived_top_k
moe_num_shared_experts = moe_num_shared_experts if moe_num_shared_experts >= 0 else 1
if moe_expert_ffn_mult <= 0:
moe_expert_ffn_mult = max(1e-6, 2.0 / granularity_target)
assert moe_experts_per_token > 0, "moe_experts_per_token must be > 0 when MoE is enabled"
assert moe_num_experts >= moe_experts_per_token, "moe_num_experts must be >= moe_experts_per_token"
assert moe_num_shared_experts >= 0, "moe_num_shared_experts must be >= 0"
assert moe_expert_ffn_mult > 0, "moe_expert_ffn_mult must be > 0"
moe_activation_ratio = moe_experts_per_token / (moe_num_experts + moe_num_shared_experts)
moe_granularity_actual = 2.0 / moe_expert_ffn_mult
else:
moe_num_shared_experts = 0
moe_experts_per_token = 0
moe_expert_ffn_mult = 4.0 if moe_expert_ffn_mult <= 0 else moe_expert_ffn_mult
moe_activation_ratio = 0.0
moe_granularity_actual = 0.0
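# Worked example of the derived MoE defaults above (hypothetical run: depth=20 with
# --moe_num_experts=64 and the defaults moe_activation_denominator=32, granularity target=12):
# top-k = round(64 / 32) = 2, shared experts = 1, moe_expert_ffn_mult = 2 / 12 ~= 0.167,
# activation ratio = 2 / (64 + 1) ~= 3.1%, dense preface = max(1, 20 // 10) = 2 layers.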
def _resolve_checkpoint_tag(tag, run_name, depth_value):
if tag:
return tag
run_name = run_name or ""
if run_name and run_name != "dummy":
return run_name
return f"d{depth_value}"
model_tag = _resolve_checkpoint_tag(model_tag, run, depth)
user_config["model_tag"] = model_tag
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"kv_head_mult: {kv_head_mult}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
if moe_num_experts > 0:
print0(
"MoE config: experts=%d shared=%d topk=%d granularity=%.1f (mult=%.3f) sparsity=%.2f%% dense_preface=%d" % (
moe_num_experts,
moe_num_shared_experts,
moe_experts_per_token,
moe_granularity_actual,
moe_expert_ffn_mult,
moe_activation_ratio * 100,
dense_layers_before_moe,
)
)
print0(f"Checkpoint tag: {model_tag}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size:,}) must be divisible by tokens per micro-batch across all ranks ({world_tokens_per_fwdbwd:,})"
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(
sequence_len=max_seq_len,
vocab_size=vocab_size,
n_layer=num_layers,
n_head=num_heads,
n_kv_head=num_kv_heads,
n_embd=model_dim,
moe_num_experts=moe_num_experts,
moe_num_shared_experts=moe_num_shared_experts,
moe_experts_per_token=moe_experts_per_token,
moe_expert_ffn_mult=moe_expert_ffn_mult,
dense_layers_before_moe=dense_layers_before_moe,
moe_granularity_target=granularity_target,
moe_activation_denominator=activation_denom,
)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
dense_like_flops = model.estimate_flops()
active_flops_per_token, dense_ref_flops = model.estimate_moe_active_flops()
num_flops_per_token = active_flops_per_token
model = torch.compile(model, dynamic=False) # TODO: think through dynamic=True vs. dynamic=False
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
print0(f"Estimated FLOPs per token (dense-like): {dense_like_flops:e}")
if active_flops_per_token != dense_like_flops:
print0(f"Estimated FLOPs per token (MoE active): {active_flops_per_token:e}")
print0(f"Estimated FLOPs per token (dense reference): {dense_ref_flops:e}")
user_config.update({
"moe_num_experts": moe_num_experts,
"moe_num_shared_experts": moe_num_shared_experts,
"moe_experts_per_token": moe_experts_per_token,
"moe_expert_ffn_mult": moe_expert_ffn_mult,
"moe_dense_layers": dense_layers_before_moe,
"moe_activation_ratio": moe_activation_ratio,
"moe_granularity_actual": moe_granularity_actual,
"flops_per_token_dense_like": dense_like_flops,
"flops_per_token_moe_active": active_flops_per_token,
"flops_per_token_dense_reference": dense_ref_flops,
})
if not use_dummy_wandb:
wandb_run.config.update({
"moe_num_experts": moe_num_experts,
"moe_num_shared_experts": moe_num_shared_experts,
"moe_experts_per_token": moe_experts_per_token,
"moe_expert_ffn_mult": moe_expert_ffn_mult,
"moe_dense_layers": dense_layers_before_moe,
"moe_activation_ratio": moe_activation_ratio,
"moe_granularity_actual": moe_granularity_actual,
"flops_per_token_dense_like": dense_like_flops,
"flops_per_token_moe_active": active_flops_per_token,
"flops_per_token_dense_reference": dense_ref_flops,
}, allow_val_change=True)
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
# calculate the number of iterations from the target flops
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
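# Worked example of the Chinchilla-style sizing (hypothetical parameter count): a ~560M-param
# model at target_param_data_ratio=20 wants ~1.12e10 tokens, i.e. 1.12e10 // 524,288 ~= 21,362
# iterations at the default total_batch_size.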
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
if eval_every <= 0:
eval_every = max(1, num_iterations // 100)
print0(f"Auto-setting eval_every to {eval_every} (~1% of training)")
sequences_per_step = max(1, total_batch_size // max_seq_len)
checkpoint_every_steps = int(checkpoint_every_steps)
checkpoint_enabled = checkpoint_every_steps > 0
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# Checkpoint output location
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", model_tag)
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
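# Example, assuming the defaults warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0 and
# num_iterations=10,000: lrm stays at 1.0 through step 8,000 and then decays linearly to 0,
# e.g. lrm = 0.5 at step 9,000.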
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
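# Example: momentum ramps linearly from 0.85 at step 0 to 0.95 at step 300 (0.90 at step 150)
# and is held at 0.95 thereafter.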
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# Keep track of total tokens/sequences processed for logging
tokens_per_step = total_batch_size
total_tokens_seen = 0
total_sequences_seen = 0
last_val_bpb = None
def save_base_checkpoint(step_idx):
optimizer_state = [opt.state_dict() for opt in optimizers]
meta = {
"step": step_idx,
"val_bpb": last_val_bpb,
"model_config": model_config_kwargs,
"user_config": user_config,
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"total_tokens_seen": total_tokens_seen,
"total_sequences_seen": total_sequences_seen,
}
save_checkpoint(
checkpoint_dir,
step_idx,
orig_model.state_dict(),
optimizer_state,
meta,
)
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
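        # e.g. with the defaults on 8 GPUs: 20*524,288 eval tokens / (32 * 2048 * 8) = 20 eval steps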
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and sample_every > 0 and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
save_base_checkpoint(step)
if last_step:
break
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
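    # e.g. with grad_accum_steps=8, each micro-batch contributes loss/8 to the gradient sum, so
    # the accumulated gradient matches what a single 8x-larger batch would have produced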
    # gradient clipping (TODO: possibly experiment with this)
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
total_tokens_seen += tokens_per_step
total_sequences_seen += sequences_per_step
current_step = step + 1
    if master_process and checkpoint_enabled and not last_step and current_step % checkpoint_every_steps == 0:
save_base_checkpoint(current_step)
dt = t1 - t0
# -------------------------------------------------------------------------
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
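    # e.g. at step 0 the EMA holds 0.1 * loss and the divisor is 1 - 0.9^1 = 0.1, so the debiased
    # value recovers the raw loss exactly; the correction shrinks to a no-op as steps accumulate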
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
global_tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
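    # e.g. (hypothetical): a run sustaining 400e12 FLOPs/s per H100 against the 989e12 bf16 peak
    # reports mfu = 100 * 400 / 989 ~= 40.4%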
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec (micro): {tok_per_sec:,} | tok/sec (global): {global_tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0:
log_payload = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/tok_per_sec_global": global_tok_per_sec,
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
}
if hasattr(orig_model, "get_moe_stats"):
log_payload.update(orig_model.get_moe_stats())
wandb_run.log(log_payload)
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"FLOPs per token (MoE active)": f"{num_flops_per_token:e}",
"FLOPs per token (dense-like)": f"{dense_like_flops:e}",
"FLOPs per token (dense reference)": f"{dense_ref_flops:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": last_val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
"Total tokens processed": total_tokens_seen,
"Total sequences processed": total_sequences_seen,
}
])
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()