From d9678ff0f9c5d9967512adce23cb60ea0a5cd3f3 Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 14:31:54 +0000 Subject: [PATCH 1/8] Save FP8 tensors in autograd ctx instead of full-precision inputs Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics. --- nanochat/fp8.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 9d9e9c3..8649760 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -123,19 +123,16 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward saves input and weight in their original precision for the - backward pass. Each GEMM independently re-quantizes its operands to FP8. - (We don't reuse the forward's FP8 tensors in backward — the backward might - want different precision, and saving FP8 would lose information.) + The forward quantizes input and weight to FP8 and saves + the quantized tensors + scales for backward. 
""" @staticmethod def forward(ctx, input_2d, weight): - ctx.save_for_backward(input_2d, weight) - # Quantize both operands to e4m3 (higher precision format) input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv) # output = input @ weight.T # input_fp8 is [B, K] contiguous = row-major (good for first arg) @@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - input_2d, weight = ctx.saved_tensors + in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors # === GEMM 1: grad_input = grad_output @ weight === # Shapes: [B, N] @ [N, K] -> [B, K] # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) - w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) # go_fp8 is [B, N] contiguous = row-major, good for first arg # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg w_col = _to_col_major(w_fp8) @@ -178,7 +174,6 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. 
From 124f49be98e53bf734e2918dc58a580dbf31a80c Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 15:41:33 +0000 Subject: [PATCH 2/8] Removed redundant quantization of gradients --- nanochat/fp8.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 8649760..3e88285 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] - go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. - go_T = go_fp8_2.t().contiguous() # [N, B] row-major + go_T = go_fp8.t().contiguous() # [N, B] row-major in_col = _to_col_major(in_fp8) # [B, K] column-major grad_weight = torch._scaled_mm( go_T, in_col, - scale_a=go_inv_2, + scale_a=go_inv, scale_b=in_inv, out_dtype=grad_output.dtype, use_fast_accum=False, From 77f8fb83037d4bb294fb97f987f27c98526c1d96 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 14:41:53 +0000 Subject: [PATCH 3/8] a number of upgrades to SFT script to bring it up to date w.r.t. 
pretraining and tuning some of its kwargs based on sweeps --- nanochat/checkpoint_manager.py | 19 ++++ scripts/base_train.py | 1 + scripts/chat_sft.py | 184 +++++++++++++++++++++++++-------- 3 files changed, 159 insertions(+), 45 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 5a95fbf..e24533a 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs): base_dir = get_base_dir() checkpoints_dir = os.path.join(base_dir, model_dir) return load_model_from_dir(checkpoints_dir, *args, **kwargs) + +def load_optimizer_state(source, device, rank, model_tag=None, step=None): + """Load just the optimizer shard for a given rank, without re-loading the model.""" + model_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[source] + base_dir = get_base_dir() + checkpoints_dir = os.path.join(base_dir, model_dir) + if model_tag is None: + model_tag = find_largest_model(checkpoints_dir) + checkpoint_dir = os.path.join(checkpoints_dir, model_tag) + if step is None: + step = find_last_step(checkpoint_dir) + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + log0(f"Loading optimizer state from {optimizer_path}") + optimizer_data = torch.load(optimizer_path, map_location=device) + return optimizer_data diff --git a/scripts/base_train.py b/scripts/base_train.py index 996b2ba..bb76e90 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -468,6 +468,7 @@ while True: "user_config": user_config, # inputs to the training script "device_batch_size": args.device_batch_size, "max_seq_len": args.max_seq_len, + "total_batch_size": total_batch_size, "dataloader_state_dict": dataloader_state_dict, "loop_state": { # all loop state (other than step) so that we can resume training "min_val_bpb": min_val_bpb, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 
4c81f06..edac3d8 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -9,6 +9,7 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ +import gc import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" @@ -16,12 +17,14 @@ import time import wandb import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model import torch.distributed as dist +from nanochat.flash_attention import HAS_FA3 +from nanochat.engine import Engine +from scripts.chat_eval import run_chat_eval from tasks.common import TaskMixture from tasks.gsm8k import GSM8K @@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") +parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, 
help="number of optimization steps (-1 = full epoch)") -# Batch sizes -parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") -# Optimization -parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") +# Batch sizes (default: inherit from pretrained checkpoint) +parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") +parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") +parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") +# Optimization (default: inherit from pretrained checkpoint) +parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") +parser.add_argument("--init-lr-frac", type=float, default=0.8, 
help="initial LR as fraction of base LR") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") # Evaluation -parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -# Output -parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") +parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") +parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") +parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -66,20 +72,48 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type 
== "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 +if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") +else: + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) +# Flash Attention status +if not HAS_FA3: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") + # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -pretrain_batch_size = meta.get("device_batch_size", None) -if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") + +# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) +pretrain_user_config = meta.get("user_config", {}) +for name, fallback, source in [ + ("max_seq_len", 2048, meta), + ("device_batch_size", 32, meta), + ("total_batch_size", 524288, meta), + ("embedding_lr", 0.3, pretrain_user_config), + ("unembedding_lr", 0.004, pretrain_user_config), + ("matrix_lr", 0.02, pretrain_user_config), +]: + arg_val = getattr(args, name) + pretrain_val = source.get(name) + if arg_val is None: + resolved = pretrain_val if pretrain_val is not None else fallback + setattr(args, name, resolved) + print0(f"Inherited {name}={resolved} from pretrained checkpoint") + elif 
pretrain_val is not None and arg_val != pretrain_val: + print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + else: + print0(f"Using {name}={arg_val}") + orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) +# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) + +# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) 
+base_dir = get_base_dir() +if args.load_optimizer: + optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) + optimizer.load_state_dict(optimizer_data) + del optimizer_data + print0("Loaded optimizer state from pretrained checkpoint") + # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac group["initial_lr"] = group["lr"] # SFT data mixture and DataLoader -base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") train_dataset = TaskMixture([ SmolTalk(split="train"), # 460K rows of general conversations @@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train") build_val_loader = lambda: sft_data_generator_bos_bestfit("val") progress = 0 # will go from 0 to 1 over the course of the epoch -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) +# Same shape as base_train but uses progress (0→1) instead of absolute step counts, +# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). def get_lr_multiplier(progress): - # first 80% of training: no decay, then linearly ramp down to 0. 
- return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + if progress < args.warmup_ratio: + return (progress + 1e-8) / args.warmup_ratio + elif progress <= 1.0 - args.warmdown_ratio: + return 1.0 + else: + decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio + return (1 - decay) * 1.0 + decay * args.final_lr_frac # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -282,8 +332,44 @@ while True: }) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step and not args.dry_run: + # once in a while: estimate the ChatCORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + chatcore_results = {} + if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): + model.eval() + engine = Engine(orig_model, tokenizer) + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] + categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'} + baseline_accuracies = { + 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, + 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, + } + task_results = {} + for task_name in all_tasks: + limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample + max_problems = None if limit < 0 else limit # -1 means no limit + with autocast_ctx: + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) + task_results[task_name] = acc + print0(f" {task_name}: {100*acc:.2f}%") + # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) + def centered_mean(tasks): + return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) + chatcore = centered_mean(all_tasks) + chatcore_cat = centered_mean(categorical_tasks) + print0(f"Step {step:05d} | ChatCORE: 
{chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "chatcore_metric": chatcore, + "chatcore_cat": chatcore_cat, + **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, + }) + model.train() + + # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) + if last_step: output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) save_checkpoint( @@ -304,7 +390,8 @@ while True: "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script - } + }, + rank=ddp_rank, ) if last_step: @@ -346,8 +433,7 @@ while True: pct_done = 100 * progress tok_per_sec = int(args.total_batch_size / dt) flops_per_sec = num_flops_per_token * args.total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") @@ -364,24 +450,32 @@ while True: "train/epoch": current_epoch, }) + # The garbage collector spends ~500ms scanning for cycles quite frequently. + # We manually manage it to avoid these pauses during training. + if step == 1: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # freeze all currently surviving objects and exclude them from GC + gc.disable() # disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... 
+ gc.collect() # manually collect, just to be safe for very long runs + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report -if not args.dry_run: - from nanochat.report import get_report - get_report().log(section="SFT", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } - ]) +from nanochat.report import get_report +get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } +]) # cleanup wandb_run.finish() # wandb run finish From 1415fb761797f94a4933c1a79f8d1fc2e63b9793 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 20:23:04 +0000 Subject: [PATCH 4/8] tune the data mixture a bit, load optimizer by default when SFT. 
These were confirmed to be best settings from sweeps of sft --- nanochat/checkpoint_manager.py | 3 +++ scripts/chat_sft.py | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e24533a..f71524e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): if step is None: step = find_last_step(checkpoint_dir) optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + if not os.path.exists(optimizer_path): + log0(f"Optimizer checkpoint not found: {optimizer_path}") + return None log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index edac3d8..a783ed2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") +parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes (default: inherit from pretrained checkpoint) @@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = 
disable)") parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") +# Data mixture +parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") +parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the +# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and +# restore our fresh SFT LRs after loading. 
base_dir = get_base_dir() if args.load_optimizer: optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - optimizer.load_state_dict(optimizer_data) - del optimizer_data - print0("Loaded optimizer state from pretrained checkpoint") + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: @@ -146,16 +158,17 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ +train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - GSM8K(subset="main", split="train"), # 2 epochs of GSM8K CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. 
spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +] +train_dataset = TaskMixture(train_tasks) +print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios From f5fe7925ed913fbddbc268043c79f82c354c43de Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Feb 2026 15:44:54 +0000 Subject: [PATCH 5/8] update dev log with recent --- dev/LOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index dec2c06..c0d35e4 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,38 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data Mixture Experiment (negative) + +Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: + +- d26 (GPT-2): CORE 0.2602 → 0.2549 +- d18: CORE 0.199 → 0.192 + +This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + +## 2026-02-16: SFT Script Upgrades + +Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps. + +Tuning: + +- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much. 
+- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, a warmdown ratio of 0.5 worked best. `--init-lr-frac` was lowered slightly, from 1.0 to 0.8. +- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results. +- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though. + +Quality of life, footguns, minor fixes: + +- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). `base_train.py` now also saves `total_batch_size` to its checkpoint metadata. +- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining. +- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb. +- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value. +- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save. 
+ +--- + ## 2026-02-05: Auto Batch Size Scaling ### Background From cac43e851142289d565c2d22fdc9904ee8b62eb1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 01:03:46 +0100 Subject: [PATCH 6/8] Fix MockModel's device definition (#535) * fix MockModel's device definition * cleanup --- tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..784ffcb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -31,7 +31,7 @@ class MockModel: def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() - self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device From ad55575326443db6deda6e19126ebf136c66d8b2 Mon Sep 17 00:00:00 2001 From: George Shakan <43767775+georgeshakan@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:42:11 -0500 Subject: [PATCH 7/8] Fix bug in setting precision (#538) --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..2dd0792 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From bac5a35dd74e331ed6012142e0b4e8c0f0af48e8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:17:29 +0000 Subject: [PATCH 8/8] fix minor bug in fp8 application to skip tiny matmuls --- scripts/base_train.py | 12 
+++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index bb76e90..24091b6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -170,20 +170,22 @@ if args.fp8: # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn - # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) + # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: if not isinstance(mod, nn.Linear): return False - # FP8 requires both in_features and out_features divisible by 16 if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: return False + if min(mod.in_features, mod.out_features) < 128: + return False return True fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) - num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) - num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers - print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") + num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = num_linear - num_fp8 + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") # Context manager to temporarily disable FP8 so that model evaluation remains in BF16 @contextmanager