nanochat/scripts/chat_sft.py

402 lines
19 KiB
Python

"""
Supervised fine-tuning (SFT) of the model.
Run as:
python -m scripts.chat_sft
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# CLI arguments
# Every knob of the SFT run is a command-line flag. The parsed values are copied
# into `user_config`, which is passed to wandb, the report and the checkpoint metadata.
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# `choices` rejects typos up front; without it any value other than "float32"
# would silently be treated as bfloat16 further down in the script.
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"], help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()  # snapshot of the CLI arguments, for logging
# -----------------------------------------------------------------------------
# Compute init
# An explicitly requested device type wins; an empty string means "autodetect".
device_type = args.device_type or autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = (ddp_rank == 0)  # rank 0 is the one that logs and saves
if args.dtype == 'float32':
    ptdtype = torch.float32
else:
    ptdtype = torch.bfloat16
# Autocast, device synchronization and memory stats only exist on CUDA;
# everywhere else they are replaced by no-op stand-ins.
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    synchronize = torch.cuda.synchronize
    get_max_memory = torch.cuda.max_memory_allocated
else:
    autocast_ctx = nullcontext()
    synchronize = lambda: None
    get_max_memory = lambda: 0
# wandb logging init: only the master process of a non-"dummy" run talks to wandb
use_dummy_wandb = not master_process or args.run == "dummy"
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
# Warn when this run asks for a larger per-device batch than the value recorded
# in the loaded checkpoint's metadata (if the metadata carries one at all).
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model  # uncompiled handle; its state_dict is what gets checkpointed
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
# Validate explicitly instead of with a bare assert: asserts vanish under `python -O`,
# and a zero or non-multiple batch size would otherwise yield zero or truncated
# gradient accumulation steps.
if (
    world_tokens_per_fwdbwd <= 0
    or args.total_batch_size <= 0
    or args.total_batch_size % world_tokens_per_fwdbwd != 0
):
    raise ValueError(
        f"--total-batch-size ({args.total_batch_size:,}) must be a positive multiple of "
        f"device_batch_size * max_seq_len * world_size = {world_tokens_per_fwdbwd:,}"
    )
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)  # passed to evaluate_bpb during validation
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(
    unembedding_lr=args.unembedding_lr,
    embedding_lr=args.embedding_lr,
    matrix_lr=args.matrix_lr,
    weight_decay=args.weight_decay,
)
# Scale every group's learning rate by --init-lr-frac and remember the result as
# "initial_lr"; the LR schedule in the training loop multiplies that stored value.
for group in optimizer.param_groups:
    scaled_lr = group["lr"] * args.init_lr_frac
    group["lr"] = scaled_lr
    group["initial_lr"] = scaled_lr
# SFT data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
# Training mixture. A task that appears twice contributes two epochs' worth of rows.
train_tasks = [
    SmolTalk(split="train"),  # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"),  # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"),  # 8K rows teaching simple math and (calculator) tool use
    GSM8K(subset="main", split="train"),  # second epoch of GSM8K
    CustomJSON(filepath=identity_conversations_filepath),  # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath),  # second epoch of the identity conversations
    SimpleSpelling(size=200000, split="train"),  # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"),  # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
# total: 460K + 100K + 16K + 2K + 200K + 80K = 858K rows
train_dataset = TaskMixture(train_tasks)
# Validation mixture; MMLU and GSM8K are capped with `stop` to match the train ratios.
val_tasks = [
    SmolTalk(split="test"),  # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200),  # 14K rows in test set, only the first 5.2K are used
    GSM8K(subset="main", split="test", stop=420),  # 1.32K rows in test set, only the first 420 are used
]
# total used: 24K + 5.2K + 0.42K ~= 29.6K rows
val_dataset = TaskMixture(val_tasks)
# The DataLoader is defined below; it emits inputs, targets: 2D tensors of shape (device_batch_size, max_seq_len).
# The final number of iterations is not known in advance, so the data generator
# reports back to the training loop through these module-level variables.
last_step = False  # flipped to True by the generator once the training data is exhausted
approx_progress = 0.0  # fraction of the run completed, goes from 0 to 1
current_epoch = 1  # epoch counter, only used for logging
def sft_data_generator_bos_bestfit(split, buffer_size=100):
    """
    BOS-aligned dataloader for SFT with bestfit-pad packing.

    Every row of a batch starts at the beginning of a conversation. A row is
    filled greedily with the longest buffered conversation that still fits
    entirely (best-fit). When nothing fits, the rest of the row is padded. If
    the row is still empty at that point (every buffered conversation is longer
    than a whole row), the shortest one is truncated to the row capacity so
    that no row consists purely of padding.
    Padding positions have targets masked with -1 (ignore_index for cross-entropy).

    Args:
        split: "train" or "val"; selects train_dataset or val_dataset.
        buffer_size: number of tokenized conversations kept as candidates for
            best-fit packing.

    Yields:
        (inputs, targets): tensors of shape (device_batch_size, max_seq_len) on
        `device`; inputs are int32, targets are int64. The generator never
        terminates on its own; it wraps around the dataset.

    Side effects (train split only): updates the module-level `last_step`,
    `approx_progress` and `current_epoch`. `last_step` becomes True once one
    pass over the dataset has been consumed or, when --num-iterations > 0, once
    enough micro-batches for that many optimizer steps have been produced.
    """
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1 # +1 for target at last position
    bos_token = tokenizer.get_bos_token_id()
    use_cuda = device_type == "cuda"  # loop-invariant: enables pinned memory + async copies
    # --num-iterations counts optimizer steps, and each optimizer step consumes
    # grad_accum_steps micro-batches from this generator.
    max_micro_batches = args.num_iterations * grad_accum_steps if args.num_iterations > 0 else 0
    # Conversation buffer: list of token lists
    conv_buffer = []
    cursor = ddp_rank # Each rank processes different conversations (for fetching)
    consumed = ddp_rank # Track actual consumption separately from buffering
    epoch = 1
    it = 0 # number of micro-batches produced so far

    def refill_buffer():
        """Tokenize conversations until the buffer holds buffer_size entries."""
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)  # second return value is not used here
            conv_buffer.append(ids)
            cursor += ddp_world_size  # ranks stride through the dataset in lockstep
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
                # Note: last_step is triggered based on consumption, not fetching

    while True:
        rows = []
        row_lengths = [] # Track actual content length (excluding padding) for each row
        for _ in range(args.device_batch_size):
            row = []
            content_len = row_capacity  # only lowered when the row ends up padded
            while len(row) < row_capacity:
                # Ensure buffer has conversations (no-op when it is already full)
                refill_buffer()
                remaining = row_capacity - len(row)
                # Find largest conversation that fits entirely (first one wins ties)
                best_idx = -1
                best_len = 0
                for i, conv in enumerate(conv_buffer):
                    conv_len = len(conv)
                    if best_len < conv_len <= remaining:
                        best_idx, best_len = i, conv_len
                if best_idx >= 0:
                    # Found a conversation that fits - use it entirely
                    row.extend(conv_buffer.pop(best_idx))
                    consumed += ddp_world_size # Track actual consumption
                    continue
                if row:
                    # No conversation fits - pad the remainder with BOS tokens
                    content_len = len(row)
                    row.extend([bos_token] * remaining)
                else:
                    # No conversation fits even on an empty row.
                    # Truncate the shortest one to avoid fully padded rows.
                    trunc_idx = -1
                    trunc_len = None
                    for i, conv in enumerate(conv_buffer):
                        conv_len = len(conv)
                        if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
                            trunc_idx, trunc_len = i, conv_len
                    row.extend(conv_buffer.pop(trunc_idx)[:row_capacity])
                    consumed += ddp_world_size
                break # Row is now full (with padding or truncation)
            row_lengths.append(content_len)
            rows.append(row[:row_capacity])
        it += 1
        # Update progress tracking (based on consumed, not cursor, to account for buffering)
        if split == "train":
            current_epoch = epoch
            if max_micro_batches > 0:
                # The training loop always holds one prefetched batch, so the batch
                # produced right now has not been trained on yet: `it - 1` have.
                approx_progress = min((it - 1) / max_micro_batches, 1.0)
                # Stopping condition to respect num_iterations: the batch after the
                # last needed one is the prefetch of a step that will never run.
                if it > max_micro_batches:
                    last_step = True
            else:
                # Clamp: consumption advances in strides and can overshoot the dataset size
                approx_progress = min(consumed / dataset_size, 1.0)
            # Trigger last_step when we've consumed enough (instead of when cursor wraps)
            if consumed >= dataset_size:
                last_step = True
        # Build tensors
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
        # Mask out padding positions in targets (set to -1 = ignore_index)
        # For each row, positions >= (content_length - 1) in targets should be masked
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, content_len-1:] = -1
        yield inputs, targets
# The training stream is created once and consumed for the whole run.
train_loader = sft_data_generator_bos_bestfit("train")


def build_val_loader():
    """Return a fresh validation data generator that starts from the beginning of the val set."""
    return sft_data_generator_bos_bestfit("val")


progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
    """
    Return the learning-rate multiplier for a given training progress.

    The multiplier is 1 for the first 80% of training and then ramps down
    linearly, reaching 0 at progress == 1.0. Progress values beyond 1.0 are
    clamped to a multiplier of 0 so an overshooting progress estimate can never
    produce a negative learning rate.

    Args:
        progress: fraction of training completed (nominally in [0, 1]).

    Returns:
        A multiplier in [0, 1].
    """
    if progress < 0.8:
        return 1
    return max(0.0, 1 - (progress - 0.8) / 0.2)
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """
    Return the Muon momentum for step `it`.

    The value is interpolated linearly from 0.85 at step 0 to 0.95 at step 300
    and stays at 0.95 afterwards.
    """
    warmup = it / 300
    if warmup > 1:
        warmup = 1  # warmup is over, hold the final value
    return (1 - warmup) * 0.85 + warmup * 0.95
# -----------------------------------------------------------------------------
# Training loop
# Each pass through the loop: (1) agree across ranks on whether this is the last
# step, (2) optionally evaluate, (3) checkpoint and exit if it is the last step,
# (4) otherwise run one optimizer step made of grad_accum_steps micro-batches.
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step
    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    # (MAX-reduce: if any rank has run out of data, every rank stops)
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())
    # once in a while: evaluate the val bpb (all ranks participate)
    # This always runs on the last step, so val_bpb is defined for the checkpoint below.
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()
    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not args.dry_run:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "sft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config, # inputs to the training script
            }
        )
    if last_step:
        break
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging (only the last micro-batch's loss survives the loop)
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress) # only increase progress monotonically
    # step the optimizer
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
    optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------
    # State
    step += 1
    # `step` optimizer steps are now complete; keep the flops counter in sync
    # with the step number it is logged against below.
    flops_so_far = num_flops_per_token * args.total_batch_size * step
    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    # debias the EMA: after `step` updates of a zero-initialized EMA the weights sum
    # to (1 - ema_beta**step). `step` was already incremented above, so no extra +1.
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step)
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })
# Final summary: peak memory, wall-clock time and the best validation score seen
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report (skipped for dry runs)
if not args.dry_run:
    from nanochat.report import get_report
    setup_stats = {
        "Number of iterations": step,
        "DDP world size": ddp_world_size,
    }
    outcome_stats = {
        "Minimum validation bpb": min_val_bpb,
    }
    report = get_report()
    # sections in order: CLI args, training setup, training outcomes
    report.log(section="SFT", data=[user_config, setup_stats, outcome_stats])
# cleanup: close the wandb run, then tear down the compute setup
wandb_run.finish()
compute_cleanup()