"""
|
|
Supervised fine-tuning (SFT) the model.
|
|
Run as:
|
|
|
|
python -m scripts.chat_sft
|
|
|
|
Or torchrun for training:
|
|
|
|
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
|
|
"""
|
|
|
|
import gc
import math
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA
from nanochat.dataloader import sft_data_loader_varlen
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval

from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) of the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes (default: inherit from pretrained checkpoint)
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
# Optimization (default: inherit from pretrained checkpoint)
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    gpu_peak_flops = float('inf')  # MFU not meaningful for CPU/MPS

# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)

# Flash Attention status
if not HAS_FA:
    print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")

# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)

# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
pretrain_user_config = meta.get("user_config", {})
for name, fallback, source in [
    ("max_seq_len", 2048, meta),
    ("device_batch_size", 32, meta),
    ("total_batch_size", 524288, meta),
    ("embedding_lr", 0.3, pretrain_user_config),
    ("unembedding_lr", 0.004, pretrain_user_config),
    ("matrix_lr", 0.02, pretrain_user_config),
]:
    arg_val = getattr(args, name)
    pretrain_val = source.get(name)
    if arg_val is None:
        resolved = pretrain_val if pretrain_val is not None else fallback
        setattr(args, name, resolved)
        print0(f"Inherited {name}={resolved} from pretrained checkpoint")
    elif pretrain_val is not None and arg_val != pretrain_val:
        print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
    else:
        print0(f"Using {name}={arg_val}")

orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len  # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # total tokens per iteration for all ranks
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
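# Worked example of the batch accounting above, assuming the fallback defaults
# (device_batch_size=32, max_seq_len=2048, total_batch_size=524288) on 8 GPUs:
# tokens_per_fwdbwd = 32 * 2048 = 65,536 per rank, i.e. 524,288 across ranks,
# so grad_accum_steps = 524288 // 524288 = 1 (one micro-batch per step).
# The same settings on a single GPU give 524288 // 65536 = 8 accumulation steps.
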
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
# Note: pretraining ramps weight_decay to zero by its end, so SFT continues with zero
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)

# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
# restore our fresh SFT LRs after loading.
base_dir = get_base_dir()
if args.load_optimizer:
    optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
    if optimizer_data is not None:
        base_lrs = [group["lr"] for group in optimizer.param_groups]
        optimizer.load_state_dict(optimizer_data)
        del optimizer_data
        for group, base_lr in zip(optimizer.param_groups, base_lrs):
            group["lr"] = base_lr
        print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
    else:
        print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")

# GradScaler for fp16 training (bf16/fp32 don't need it)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
    print0("GradScaler enabled for fp16 training")

# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]
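# e.g. with the default --init-lr-frac=0.8 and an inherited matrix_lr of 0.02,
# the Muon groups start SFT at 0.02 * 0.8 = 0.016; get_lr_multiplier() below
# then scales each group's "initial_lr" over the course of training.
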
# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_tasks = [
    SmolTalk(split="train"),  # 460K rows of general conversations
    CustomJSON(filepath=identity_conversations_filepath),  # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath),  # 2 epochs of these
    *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)],  # 100K rows per epoch
    *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)],  # 8K rows per epoch
    SimpleSpelling(size=200000, split="train"),  # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"),  # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
val_dataset = TaskMixture([
    SmolTalk(split="test"),  # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200),  # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420),  # 1.32K rows in test set, use only 420 to match the train ratios
])  # total: 24K + 14K + 1.32K ~= 39K rows

# Pre-tokenize and pre-pack all conversations into batch plans.
# This runs the same best-fit packing algorithm offline at startup, so we know
# num_iterations and max_num_docs exactly before training starts.
def tokenize_and_pack_sft(dataset, tokenizer, B, T, bos_token, ddp_rank, ddp_world_size, buffer_size=100):
    """
    Pre-tokenize and pre-pack SFT conversations using best-fit packing.

    Preserves TaskMixture's shuffled ordering (no length sorting) and uses the
    same buffer-based best-fit algorithm as the original inline dataloader.

    Returns (conversations, batch_plans, total_micro_batches, max_num_docs).
    """
    dataset_size = len(dataset)
    buffer_capacity = B * T + 1
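    # the +1: a packed row of B*T training positions needs B*T + 1 tokens so
    # that targets can be the inputs shifted left by one position
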
    conversations = []
    num_convs = (dataset_size - ddp_rank + ddp_world_size - 1) // ddp_world_size
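    # strided sharding across ranks, e.g. dataset_size=10, ddp_world_size=4,
    # ddp_rank=1 visits indices 1, 5, 9 => num_convs = (10 - 1 + 3) // 4 = 3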
    cursor = ddp_rank
    while cursor < dataset_size:
        ids, mask = tokenizer.render_conversation(dataset[cursor])
        conversations.append((ids, mask))
        cursor += ddp_world_size
        if len(conversations) % 5000 == 0:
            print0(f"\r\033[KTokenizing: {len(conversations):,}/{num_convs:,} ({100*len(conversations)/num_convs:.0f}%)", end='', flush=True)
    print0(f"\r\033[KTokenized {len(conversations):,} conversations", flush=True)

    batch_plans = []
    conv_buffer = []
    fetch_cursor = 0
    max_doc_count = 0

    def refill():
        nonlocal fetch_cursor
        while len(conv_buffer) < buffer_size and fetch_cursor < len(conversations):
            conv_buffer.append(fetch_cursor)
            fetch_cursor += 1

    while True:
        refill()
        if not conv_buffer:
            break
        batch_indices = []
        pos = 0
        while pos < buffer_capacity:
            refill()
            if not conv_buffer:
                break
            remaining = buffer_capacity - pos
            best_buf_idx = -1
            best_len = 0
            for i, conv_idx in enumerate(conv_buffer):
                conv_len = len(conversations[conv_idx][0])
                if conv_len <= remaining and conv_len > best_len:
                    best_buf_idx = i
                    best_len = conv_len
            if best_buf_idx >= 0:
                batch_indices.append(conv_buffer.pop(best_buf_idx))
                pos += best_len
            else:
                break
        if batch_indices:
            doc_count = len(batch_indices) + (1 if pos < buffer_capacity else 0)
            max_doc_count = max(max_doc_count, doc_count)
            batch_plans.append(batch_indices)

    max_num_docs = math.ceil((max_doc_count + 1) / 16) * 16
    return conversations, batch_plans, len(batch_plans), max(max_num_docs, 16)
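# Best-fit packing in miniature (hypothetical lengths): with a row capacity of
# 10 tokens and buffered docs of lengths [6, 3, 5, 4], the inner loop always
# takes the longest doc that still fits: row 1 packs 6 then 4 (full, zero
# padding); row 2 packs 5 then 3, and its 2 leftover tokens become padding,
# which is why doc_count adds one extra "doc" whenever pos < buffer_capacity.
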
bos_token = tokenizer.get_bos_token_id()
t_pack_start = time.time()
train_convs, train_plans, train_micro_batches, train_max_docs = tokenize_and_pack_sft(
    train_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
    bos_token, ddp_rank, ddp_world_size)
t_pack_train = time.time()
val_convs, val_plans, val_micro_batches, val_max_docs = tokenize_and_pack_sft(
    val_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
    bos_token, ddp_rank, ddp_world_size)
t_pack_val = time.time()
max_num_docs = max(train_max_docs, val_max_docs)
print0(f"Pre-tokenize & pack: train {t_pack_train - t_pack_start:.1f}s, val {t_pack_val - t_pack_train:.1f}s, total {t_pack_val - t_pack_start:.1f}s")

# Document length and packing statistics
import numpy as np
train_doc_lens = [len(ids) for ids, _ in train_convs]
train_docs_per_batch = [len(plan) for plan in train_plans]
train_tokens_per_batch = [sum(len(train_convs[i][0]) for i in plan) for plan in train_plans]
buffer_capacity = args.device_batch_size * args.max_seq_len + 1
train_packing_eff = [t / buffer_capacity for t in train_tokens_per_batch]
dl = np.array(train_doc_lens)
dpb = np.array(train_docs_per_batch)
pe = np.array(train_packing_eff)
print0(f"Train doc lengths: n={len(dl):,} | mean={dl.mean():.0f} median={np.median(dl):.0f} "
       f"min={dl.min()} max={dl.max()} p5={np.percentile(dl,5):.0f} p95={np.percentile(dl,95):.0f}")
print0(f"Train docs/batch: n={len(dpb):,} | mean={dpb.mean():.1f} median={np.median(dpb):.0f} "
       f"min={dpb.min()} max={dpb.max()} p5={np.percentile(dpb,5):.0f} p95={np.percentile(dpb,95):.0f}")
print0(f"Train packing eff: mean={pe.mean():.3f} median={np.median(pe):.3f} "
       f"min={pe.min():.3f} max={pe.max():.3f}")

# num_iterations: exact count of optimization steps. The -1 accounts for the
# prefetch batch that the training loop requests but never trains on.
data_num_iterations = (train_micro_batches - 1) // grad_accum_steps
if args.num_iterations > 0:
    num_iterations = min(args.num_iterations, data_num_iterations)
else:
    num_iterations = data_num_iterations
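# Each rank packed its own shard of the data, so ranks can end up with slightly
# different micro-batch counts; take the min across ranks below so that every
# rank runs the same number of optimization steps.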
if ddp:
    num_iter_tensor = torch.tensor([num_iterations], dtype=torch.long, device=device)
    dist.all_reduce(num_iter_tensor, op=dist.ReduceOp.MIN)
    num_iterations = num_iter_tensor.item()
print0(f"Pre-packed {len(train_convs):,} train conversations into {train_micro_batches:,} micro-batches "
       f"=> {num_iterations:,} optimization steps (max {max_num_docs} docs/batch)")

train_loader = sft_data_loader_varlen(
    train_convs, train_plans, args.device_batch_size, args.max_seq_len,
    max_num_docs, bos_token, device=device)
build_val_loader = lambda: sft_data_loader_varlen(
    val_convs, val_plans, args.device_batch_size, args.max_seq_len,
    max_num_docs, bos_token, device=device, cycle=True)

# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
    warmup_iters = round(args.warmup_ratio * num_iterations)
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * args.final_lr_frac
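# With the defaults (warmup_ratio=0.0, warmdown_ratio=0.5, final_lr_frac=0.0)
# this is: no warmup, a constant multiplier of 1.0 for the first half of
# training, then linear decay to 0 over the second half, e.g. a multiplier
# of 0.5 at 75% of num_iterations.
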
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
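# e.g. get_muon_momentum(0) = 0.85, ramping linearly to 0.95 by step 300 and
# constant thereafter (lower momentum early on, while update statistics settle)
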
# -----------------------------------------------------------------------------
# Training loop
x, y, cu_seqlens = next(train_loader)  # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0  # EMA of training loss
ema_beta = 0.9  # EMA decay factor
total_training_time = 0  # total wall-clock time of training
step = 0
while True:
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
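        # e.g. the default 40*524288 eval tokens with 32 x 2048 tokens per rank
        # on 8 ranks works out to exactly 40 evaluation micro-steps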
        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the ChatCORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    chatcore_results = {}
    if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
        model.eval()
        engine = Engine(orig_model, tokenizer)
        all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
        categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
        baseline_accuracies = {
            'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
            'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
        }
        task_results = {}
        for task_name in all_tasks:
            limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
            max_problems = None if limit < 0 else limit  # -1 means no limit
            acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
                                batch_size=args.device_batch_size, max_problems=max_problems)
            task_results[task_name] = acc
            print0(f" {task_name}: {100*acc:.2f}%")
        # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
        def centered_mean(tasks):
            return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
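        # e.g. (hypothetical numbers) an MMLU accuracy of 0.40 against its 0.25
        # random-guessing baseline contributes (0.40 - 0.25) / (1 - 0.25) = 0.2,
        # so a model at chance scores ~0 and a perfect model scores 1 per task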
        chatcore = centered_mean(all_tasks)
        chatcore_cat = centered_mean(categorical_tasks)
        print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "chatcore_metric": chatcore,
            "chatcore_cat": chatcore_cat,
            **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
        })
        model.train()

    # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
    if last_step:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}"  # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb,  # loss at last step
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                    "window_pattern": model.config.window_pattern,
                },
                "user_config": user_config,  # inputs to the training script
            },
            rank=ddp_rank,
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        loss = model(x, y, cu_seqlens=cu_seqlens)
        train_loss = loss.detach()  # for logging
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        x, y, cu_seqlens = next(train_loader)  # prefetch the next batch while the GPU is busy with forward/backward
    # step the optimizer
    lrm = get_lr_multiplier(step)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
    if scaler is not None:
        scaler.unscale_(optimizer)
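        # Each rank's GradScaler inspects only its own gradients for inf/nan,
        # so ranks could disagree on whether to skip this step and silently
        # desync the optimizer state. Max-reducing the found_inf flags makes
        # every rank skip (or take) the step together.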
        if is_ddp_initialized():
            for v in scaler._found_inf_per_device(optimizer).values():
                dist.all_reduce(v, op=dist.ReduceOp.MAX)
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()  # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step)  # debias the EMA (step updates have occurred so far)
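    # e.g. after the first step (step=1, ema_beta=0.9) the EMA holds only
    # 0.1 * train_loss; dividing by 1 - 0.9**1 = 0.1 recovers a full-scale value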
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
    if step > 10:
        total_training_time += dt  # only count the time after the first 10 steps
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })

    # The garbage collector spends ~500ms scanning for cycles quite frequently.
    # We manage it manually to avoid these pauses during training.
    if step == 1:
        gc.collect()  # manually collect a lot of garbage from setup
        gc.freeze()  # freeze all currently surviving objects and exclude them from GC
        gc.disable()  # disable automatic GC from here on; we collect manually below
    elif step % 5000 == 0:  # every 5000 steps...
        gc.collect()  # manually collect, just to be safe for very long runs

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="SFT", data=[
    user_config,  # CLI args
    {  # stats about the training setup
        "Number of iterations": step,
        "DDP world size": ddp_world_size,
    },
    {  # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
    },
])

# cleanup
wandb_run.finish()
compute_cleanup()