diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index 5a95fbf..e24533a 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
     base_dir = get_base_dir()
     checkpoints_dir = os.path.join(base_dir, model_dir)
     return load_model_from_dir(checkpoints_dir, *args, **kwargs)
+
+def load_optimizer_state(source, device, rank, model_tag=None, step=None):
+    """Load just the optimizer shard for a given rank, without re-loading the model."""
+    model_dir = {
+        "base": "base_checkpoints",
+        "sft": "chatsft_checkpoints",
+        "rl": "chatrl_checkpoints",
+    }[source]
+    base_dir = get_base_dir()
+    checkpoints_dir = os.path.join(base_dir, model_dir)
+    if model_tag is None:
+        model_tag = find_largest_model(checkpoints_dir)
+    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+    if step is None:
+        step = find_last_step(checkpoint_dir)
+    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+    log0(f"Loading optimizer state from {optimizer_path}")
+    optimizer_data = torch.load(optimizer_path, map_location=device)
+    return optimizer_data
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 996b2ba..bb76e90 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -468,6 +468,7 @@ while True:
             "user_config": user_config, # inputs to the training script
             "device_batch_size": args.device_batch_size,
             "max_seq_len": args.max_seq_len,
+            "total_batch_size": total_batch_size,
             "dataloader_state_dict": dataloader_state_dict,
             "loop_state": { # all loop state (other than step) so that we can resume training
                 "min_val_bpb": min_val_bpb,
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 4c81f06..edac3d8 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -9,6 +9,7 @@ Or torchrun for training:
 torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
 """

+import gc
 import argparse
 import os
 os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@@ -16,12 +17,14 @@ import time
 import wandb
 import torch
 from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
 from nanochat.tokenizer import get_token_bytes
-from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
 from nanochat.loss_eval import evaluate_bpb
-from nanochat.checkpoint_manager import load_model
 import torch.distributed as dist
+from nanochat.flash_attention import HAS_FA3
+from nanochat.engine import Engine
+from scripts.chat_eval import run_chat_eval

 from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
@@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
 parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
+parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
 # Training horizon
 parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
-# Batch sizes
-parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
-parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
-# Optimization
-parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
+# Batch sizes (default: inherit from pretrained checkpoint)
+parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
+parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
+parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
+# Optimization (default: inherit from pretrained checkpoint)
+parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
+parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
+parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
+parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
+parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
+parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
+parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
 # Evaluation
-parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
-# Output
-parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
+parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
+parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
+parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
@@ -66,20 +72,48 @@ user_config = vars(args).copy()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+if device_type == "cuda":
+    gpu_device_name = torch.cuda.get_device_name(0)
+    gpu_peak_flops = get_peak_flops(gpu_device_name)
+    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
+else:
+    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS

 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)

+# Flash Attention status
+if not HAS_FA3:
+    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
+
 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
-pretrain_batch_size = meta.get("device_batch_size", None)
-if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
-    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
+
+# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
+pretrain_user_config = meta.get("user_config", {})
+for name, fallback, source in [
+    ("max_seq_len", 2048, meta),
+    ("device_batch_size", 32, meta),
+    ("total_batch_size", 524288, meta),
+    ("embedding_lr", 0.3, pretrain_user_config),
+    ("unembedding_lr", 0.004, pretrain_user_config),
+    ("matrix_lr", 0.02, pretrain_user_config),
+]:
+    arg_val = getattr(args, name)
+    pretrain_val = source.get(name)
+    if arg_val is None:
+        resolved = pretrain_val if pretrain_val is not None else fallback
+        setattr(args, name, resolved)
+        print0(f"Inherited {name}={resolved} from pretrained checkpoint")
+    elif pretrain_val is not None and arg_val != pretrain_val:
+        print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
+    else:
+        print0(f"Using {name}={arg_val}")
+
 orig_model = model
 model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
@@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste
 token_bytes = get_token_bytes(device=device)

 # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
-optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
+# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
+optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr,
+                                  matrix_lr=args.matrix_lr, weight_decay=0.0)
+
+# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
+base_dir = get_base_dir()
+if args.load_optimizer:
+    optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
+    optimizer.load_state_dict(optimizer_data)
+    del optimizer_data
+    print0("Loaded optimizer state from pretrained checkpoint")
+
 # Override the initial learning rate as a fraction of the base learning rate
 for group in optimizer.param_groups:
     group["lr"] = group["lr"] * args.init_lr_frac
     group["initial_lr"] = group["lr"]

 # SFT data mixture and DataLoader
-base_dir = get_base_dir()
 identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
 train_dataset = TaskMixture([
     SmolTalk(split="train"), # 460K rows of general conversations
@@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train")
 build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
 progress = 0 # will go from 0 to 1 over the course of the epoch

-# Learning rate scheduler
+# Learning rate schedule (linear warmup, constant, linear warmdown)
+# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
+# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
 def get_lr_multiplier(progress):
-    # first 80% of training: no decay, then linearly ramp down to 0.
-    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
+    if progress < args.warmup_ratio:
+        return (progress + 1e-8) / args.warmup_ratio
+    elif progress <= 1.0 - args.warmdown_ratio:
+        return 1.0
+    else:
+        decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
+        return (1 - decay) * 1.0 + decay * args.final_lr_frac

 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
@@ -282,8 +332,44 @@ while True:
         })
         model.train()

-    # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step and not args.dry_run:
+    # once in a while: estimate the ChatCORE metric (all ranks participate)
+    # use the original uncompiled model because the inputs keep changing shape
+    chatcore_results = {}
+    if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
+        model.eval()
+        engine = Engine(orig_model, tokenizer)
+        all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
+        categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
+        baseline_accuracies = {
+            'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
+            'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
+        }
+        task_results = {}
+        for task_name in all_tasks:
+            limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
+            max_problems = None if limit < 0 else limit # -1 means no limit
+            with autocast_ctx:
+                acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
+                                    batch_size=args.device_batch_size, max_problems=max_problems)
+            task_results[task_name] = acc
+            print0(f"  {task_name}: {100*acc:.2f}%")
+        # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
+        def centered_mean(tasks):
+            return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
+        chatcore = centered_mean(all_tasks)
+        chatcore_cat = centered_mean(categorical_tasks)
+        print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
+        wandb_run.log({
+            "step": step,
+            "total_training_flops": flops_so_far,
+            "chatcore_metric": chatcore,
+            "chatcore_cat": chatcore_cat,
+            **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
+        })
+        model.train()
+
+    # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
+    if last_step:
         output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
         checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
         save_checkpoint(
@@ -304,7 +390,8 @@ while True:
                 "window_pattern": model.config.window_pattern,
             },
             "user_config": user_config, # inputs to the training script
-            }
+            },
+            rank=ddp_rank,
         )

     if last_step:
@@ -346,8 +433,7 @@ while True:
     pct_done = 100 * progress
     tok_per_sec = int(args.total_batch_size / dt)
     flops_per_sec = num_flops_per_token * args.total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
+    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
     print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
@@ -364,24 +450,32 @@ while True:
         "train/epoch": current_epoch,
     })

+    # The garbage collector spends ~500ms scanning for cycles quite frequently.
+    # We manually manage it to avoid these pauses during training.
+    if step == 1:
+        gc.collect() # manually collect a lot of garbage from setup
+        gc.freeze() # freeze all currently surviving objects and exclude them from GC
+        gc.disable() # disable GC entirely except:
+    elif step % 5000 == 0: # every 5000 steps...
+        gc.collect() # manually collect, just to be safe for very long runs
+
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

 # Log to report
-if not args.dry_run:
-    from nanochat.report import get_report
-    get_report().log(section="SFT", data=[
-        user_config, # CLI args
-        { # stats about the training setup
-            "Number of iterations": step,
-            "DDP world size": ddp_world_size,
-        },
-        { # stats about training outcomes
-            "Minimum validation bpb": min_val_bpb,
-        }
-    ])
+from nanochat.report import get_report
+get_report().log(section="SFT", data=[
+    user_config, # CLI args
+    { # stats about the training setup
+        "Number of iterations": step,
+        "DDP world size": ddp_world_size,
+    },
+    { # stats about training outcomes
+        "Minimum validation bpb": min_val_bpb,
+    }
+])
 # cleanup
 wandb_run.finish() # wandb run finish
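
For intuition, the new get_lr_multiplier traces a trapezoid over normalized progress in [0, 1]: linear warmup over the first warmup_ratio of the run, a constant plateau, then a linear warmdown to final_lr_frac. A minimal standalone sketch; the ratios below are illustrative, not the script's defaults (which disable warmup entirely):

# Standalone sketch of the trapezoidal LR multiplier above.
# warmup_ratio / warmdown_ratio / final_lr_frac mirror the new CLI flags;
# the values here are illustrative, not the defaults.
warmup_ratio, warmdown_ratio, final_lr_frac = 0.1, 0.5, 0.0

def get_lr_multiplier(progress):
    if progress < warmup_ratio:
        return (progress + 1e-8) / warmup_ratio  # linear warmup: 0 -> 1
    elif progress <= 1.0 - warmdown_ratio:
        return 1.0  # constant plateau
    else:
        decay = (progress - (1.0 - warmdown_ratio)) / warmdown_ratio
        return (1 - decay) * 1.0 + decay * final_lr_frac  # linear warmdown: 1 -> final_lr_frac

for p in (0.0, 0.05, 0.3, 0.75, 1.0):
    print(f"progress={p:.2f} -> lrm={get_lr_multiplier(p):.3f}")
# progress=0.05 sits halfway through warmup (lrm=0.500), progress=0.75 sits
# halfway through warmdown (lrm=0.500), and progress=1.00 lands on final_lr_frac.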
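
Similarly, the ChatCORE number logged above is the mean of per-task accuracies after re-centering each against its random-guess baseline, so 0 is chance level and 1 is perfect. A worked example of the same formula with made-up accuracies:

# Worked example of the centered-mean ChatCORE metric; accuracies are made up.
baseline = {'ARC-Easy': 0.25, 'MMLU': 0.25, 'GSM8K': 0.0}
acc      = {'ARC-Easy': 0.70, 'MMLU': 0.40, 'GSM8K': 0.30}

# Rescale each task so that baseline accuracy maps to 0 and perfect maps to 1.
centered = {t: (acc[t] - baseline[t]) / (1.0 - baseline[t]) for t in acc}
# ARC-Easy: (0.70 - 0.25) / 0.75 = 0.60; MMLU: 0.20; GSM8K: 0.30
chatcore = sum(centered.values()) / len(centered)
print(f"ChatCORE: {chatcore:.4f}")  # (0.60 + 0.20 + 0.30) / 3 = 0.3667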
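
Finally, the MFU line now divides by the detected GPU's peak throughput instead of a hard-coded H100 constant. A quick arithmetic check of the formula; every number below is illustrative:

# Illustrative MFU arithmetic for the new formula (all numbers made up).
num_flops_per_token = 3.5e9   # roughly 6N for a ~580M-parameter model
total_batch_size = 524288     # tokens per optimization step
dt = 2.0                      # seconds per step
gpu_peak_flops = 989e12       # BF16 peak of an H100 SXM, dense (no 2:4 sparsity)
ddp_world_size = 8

flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
print(f"mfu: {mfu:.2f}%")  # ~11.60%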