From d9678ff0f9c5d9967512adce23cb60ea0a5cd3f3 Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 14:31:54 +0000 Subject: [PATCH 01/16] Save FP8 tensors in autograd ctx instead of full-precision inputs Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics. --- nanochat/fp8.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 9d9e9c3..8649760 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -123,19 +123,16 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward saves input and weight in their original precision for the - backward pass. Each GEMM independently re-quantizes its operands to FP8. - (We don't reuse the forward's FP8 tensors in backward — the backward might - want different precision, and saving FP8 would lose information.) + The forward quantizes input and weight to FP8 and saves + the quantized tensors + scales for backward. """ @staticmethod def forward(ctx, input_2d, weight): - ctx.save_for_backward(input_2d, weight) - # Quantize both operands to e4m3 (higher precision format) input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv) # output = input @ weight.T # input_fp8 is [B, K] contiguous = row-major (good for first arg) @@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - input_2d, weight = ctx.saved_tensors + in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors # === GEMM 1: grad_input = grad_output @ weight === # Shapes: [B, N] @ [N, K] -> [B, K] # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) - w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) # go_fp8 is [B, N] contiguous = row-major, good for first arg # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg w_col = _to_col_major(w_fp8) @@ -178,7 +174,6 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. From 124f49be98e53bf734e2918dc58a580dbf31a80c Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 15:41:33 +0000 Subject: [PATCH 02/16] Removed redundant qunatization of gradients --- nanochat/fp8.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 8649760..3e88285 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] - go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. - go_T = go_fp8_2.t().contiguous() # [N, B] row-major + go_T = go_fp8.t().contiguous() # [N, B] row-major in_col = _to_col_major(in_fp8) # [B, K] column-major grad_weight = torch._scaled_mm( go_T, in_col, - scale_a=go_inv_2, + scale_a=go_inv, scale_b=in_inv, out_dtype=grad_output.dtype, use_fast_accum=False, From 8180e1d8c1c3e561b751dcfec54a74b3122c0db5 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 20:23:04 +0000 Subject: [PATCH 03/16] tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft --- nanochat/checkpoint_manager.py | 3 +++ scripts/chat_sft.py | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e24533a..f71524e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): if step is None: step = find_last_step(checkpoint_dir) optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + if not os.path.exists(optimizer_path): + log0(f"Optimizer checkpoint not found: {optimizer_path}") + return None log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index edac3d8..a783ed2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") +parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes (default: inherit from pretrained checkpoint) @@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") +# Data mixture +parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") +parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the +# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and +# restore our fresh SFT LRs after loading. base_dir = get_base_dir() if args.load_optimizer: optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - optimizer.load_state_dict(optimizer_data) - del optimizer_data - print0("Loaded optimizer state from pretrained checkpoint") + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: @@ -146,16 +158,17 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ +train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - GSM8K(subset="main", split="train"), # 2 epochs of GSM8K CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +] +train_dataset = TaskMixture(train_tasks) +print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios From 4a6e47b0c68aa062c62fa859aed2e8dd0d59d684 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Feb 2026 15:44:54 +0000 Subject: [PATCH 04/16] update dev log with recent --- dev/LOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index dec2c06..c0d35e4 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,38 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data Mixture Experiment (negative) + +Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: + +- d26 (GPT-2): CORE 0.2602 → 0.2549 +- d18: CORE 0.199 → 0.192 + +This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + +## 2026-02-16: SFT Script Upgrades + +Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps. + +Tuning: + +- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much. +- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8. +- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results. +- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though. + +Quality of life, footguns, minor fixes: + +- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata. +- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining. +- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb. +- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value. +- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save. + +--- + ## 2026-02-05: Auto Batch Size Scaling ### Background From 4800c62f6ed598accb950a8b715a6cab8a264e1e Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 01:03:46 +0100 Subject: [PATCH 05/16] Fix MockModel's device definition (#535) * fix MockModel's device definition * cleanup --- tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..784ffcb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -31,7 +31,7 @@ class MockModel: def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() - self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device From 0a23f87643945410eb7c0e33951b5acfba05257c Mon Sep 17 00:00:00 2001 From: George Shakan <43767775+georgeshakan@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:42:11 -0500 Subject: [PATCH 06/16] Fix bug in setting precision (#538) --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..2dd0792 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From 77f8fb83037d4bb294fb97f987f27c98526c1d96 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 14:41:53 +0000 Subject: [PATCH 07/16] a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps --- nanochat/checkpoint_manager.py | 19 ++++ scripts/base_train.py | 1 + scripts/chat_sft.py | 184 +++++++++++++++++++++++++-------- 3 files changed, 159 insertions(+), 45 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 5a95fbf..e24533a 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs): base_dir = get_base_dir() checkpoints_dir = os.path.join(base_dir, model_dir) return load_model_from_dir(checkpoints_dir, *args, **kwargs) + +def load_optimizer_state(source, device, rank, model_tag=None, step=None): + """Load just the optimizer shard for a given rank, without re-loading the model.""" + model_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[source] + base_dir = get_base_dir() + checkpoints_dir = os.path.join(base_dir, model_dir) + if model_tag is None: + model_tag = find_largest_model(checkpoints_dir) + checkpoint_dir = os.path.join(checkpoints_dir, model_tag) + if step is None: + step = find_last_step(checkpoint_dir) + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + log0(f"Loading optimizer state from {optimizer_path}") + optimizer_data = torch.load(optimizer_path, map_location=device) + return optimizer_data diff --git a/scripts/base_train.py b/scripts/base_train.py index 996b2ba..bb76e90 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -468,6 +468,7 @@ while True: "user_config": user_config, # inputs to the training script "device_batch_size": args.device_batch_size, "max_seq_len": args.max_seq_len, + "total_batch_size": total_batch_size, "dataloader_state_dict": dataloader_state_dict, "loop_state": { # all loop state (other than step) so that we can resume training "min_val_bpb": min_val_bpb, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f06..edac3d8 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -9,6 +9,7 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ +import gc import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" @@ -16,12 +17,14 @@ import time import wandb import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model import torch.distributed as dist +from nanochat.flash_attention import HAS_FA3 +from nanochat.engine import Engine +from scripts.chat_eval import run_chat_eval from tasks.common import TaskMixture from tasks.gsm8k import GSM8K @@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") +parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") -# Batch sizes -parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") -# Optimization -parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") +# Batch sizes (default: inherit from pretrained checkpoint) +parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") +parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") +parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") +# Optimization (default: inherit from pretrained checkpoint) +parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") +parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") # Evaluation -parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -# Output -parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") +parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") +parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") +parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -66,20 +72,48 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 +if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") +else: + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) +# Flash Attention status +if not HAS_FA3: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") + # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -pretrain_batch_size = meta.get("device_batch_size", None) -if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") + +# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) +pretrain_user_config = meta.get("user_config", {}) +for name, fallback, source in [ + ("max_seq_len", 2048, meta), + ("device_batch_size", 32, meta), + ("total_batch_size", 524288, meta), + ("embedding_lr", 0.3, pretrain_user_config), + ("unembedding_lr", 0.004, pretrain_user_config), + ("matrix_lr", 0.02, pretrain_user_config), +]: + arg_val = getattr(args, name) + pretrain_val = source.get(name) + if arg_val is None: + resolved = pretrain_val if pretrain_val is not None else fallback + setattr(args, name, resolved) + print0(f"Inherited {name}={resolved} from pretrained checkpoint") + elif pretrain_val is not None and arg_val != pretrain_val: + print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + else: + print0(f"Using {name}={arg_val}") + orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) +# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) + +# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +base_dir = get_base_dir() +if args.load_optimizer: + optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) + optimizer.load_state_dict(optimizer_data) + del optimizer_data + print0("Loaded optimizer state from pretrained checkpoint") + # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac group["initial_lr"] = group["lr"] # SFT data mixture and DataLoader -base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") train_dataset = TaskMixture([ SmolTalk(split="train"), # 460K rows of general conversations @@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train") build_val_loader = lambda: sft_data_generator_bos_bestfit("val") progress = 0 # will go from 0 to 1 over the course of the epoch -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) +# Same shape as base_train but uses progress (0→1) instead of absolute step counts, +# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). def get_lr_multiplier(progress): - # first 80% of training: no decay, then linearly ramp down to 0. - return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + if progress < args.warmup_ratio: + return (progress + 1e-8) / args.warmup_ratio + elif progress <= 1.0 - args.warmdown_ratio: + return 1.0 + else: + decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio + return (1 - decay) * 1.0 + decay * args.final_lr_frac # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -282,8 +332,44 @@ while True: }) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step and not args.dry_run: + # once in a while: estimate the ChatCORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + chatcore_results = {} + if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): + model.eval() + engine = Engine(orig_model, tokenizer) + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] + categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'} + baseline_accuracies = { + 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, + 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, + } + task_results = {} + for task_name in all_tasks: + limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample + max_problems = None if limit < 0 else limit # -1 means no limit + with autocast_ctx: + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) + task_results[task_name] = acc + print0(f" {task_name}: {100*acc:.2f}%") + # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) + def centered_mean(tasks): + return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) + chatcore = centered_mean(all_tasks) + chatcore_cat = centered_mean(categorical_tasks) + print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "chatcore_metric": chatcore, + "chatcore_cat": chatcore_cat, + **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, + }) + model.train() + + # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) + if last_step: output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) save_checkpoint( @@ -304,7 +390,8 @@ while True: "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script - } + }, + rank=ddp_rank, ) if last_step: @@ -346,8 +433,7 @@ while True: pct_done = 100 * progress tok_per_sec = int(args.total_batch_size / dt) flops_per_sec = num_flops_per_token * args.total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") @@ -364,24 +450,32 @@ while True: "train/epoch": current_epoch, }) + # The garbage collector spends ~500ms scanning for cycles quite frequently. + # We manually manage it to avoid these pauses during training. + if step == 1: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # freeze all currently surviving objects and exclude them from GC + gc.disable() # disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... + gc.collect() # manually collect, just to be safe for very long runs + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report -if not args.dry_run: - from nanochat.report import get_report - get_report().log(section="SFT", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } - ]) +from nanochat.report import get_report +get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } +]) # cleanup wandb_run.finish() # wandb run finish From 1415fb761797f94a4933c1a79f8d1fc2e63b9793 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 20:23:04 +0000 Subject: [PATCH 08/16] tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft --- nanochat/checkpoint_manager.py | 3 +++ scripts/chat_sft.py | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e24533a..f71524e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): if step is None: step = find_last_step(checkpoint_dir) optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + if not os.path.exists(optimizer_path): + log0(f"Optimizer checkpoint not found: {optimizer_path}") + return None log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index edac3d8..a783ed2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") +parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes (default: inherit from pretrained checkpoint) @@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") +# Data mixture +parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") +parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the +# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and +# restore our fresh SFT LRs after loading. base_dir = get_base_dir() if args.load_optimizer: optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - optimizer.load_state_dict(optimizer_data) - del optimizer_data - print0("Loaded optimizer state from pretrained checkpoint") + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: @@ -146,16 +158,17 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ +train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - GSM8K(subset="main", split="train"), # 2 epochs of GSM8K CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +] +train_dataset = TaskMixture(train_tasks) +print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios From f5fe7925ed913fbddbc268043c79f82c354c43de Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Feb 2026 15:44:54 +0000 Subject: [PATCH 09/16] update dev log with recent --- dev/LOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index dec2c06..c0d35e4 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,38 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data Mixture Experiment (negative) + +Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: + +- d26 (GPT-2): CORE 0.2602 → 0.2549 +- d18: CORE 0.199 → 0.192 + +This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + +## 2026-02-16: SFT Script Upgrades + +Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps. + +Tuning: + +- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much. +- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8. +- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results. +- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though. + +Quality of life, footguns, minor fixes: + +- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata. +- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining. +- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb. +- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value. +- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save. + +--- + ## 2026-02-05: Auto Batch Size Scaling ### Background From cac43e851142289d565c2d22fdc9904ee8b62eb1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 01:03:46 +0100 Subject: [PATCH 10/16] Fix MockModel's device definition (#535) * fix MockModel's device definition * cleanup --- tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..784ffcb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -31,7 +31,7 @@ class MockModel: def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() - self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device From ad55575326443db6deda6e19126ebf136c66d8b2 Mon Sep 17 00:00:00 2001 From: George Shakan <43767775+georgeshakan@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:42:11 -0500 Subject: [PATCH 11/16] Fix bug in setting precision (#538) --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..2dd0792 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From bac5a35dd74e331ed6012142e0b4e8c0f0af48e8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:17:29 +0000 Subject: [PATCH 12/16] fix minor bug in fp8 application to skip tiny matmuls --- scripts/base_train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index bb76e90..24091b6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -170,20 +170,22 @@ if args.fp8: # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn - # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) + # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: if not isinstance(mod, nn.Linear): return False - # FP8 requires both in_features and out_features divisible by 16 if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: return False + if min(mod.in_features, mod.out_features) < 128: + return False return True fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) - num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) - num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers - print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") + num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = num_linear - num_fp8 + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") # Context manager to temporarily disable FP8 so that model evaluation remains in BF16 @contextmanager From bb5137860e24efa995b60468e7b867206ae9dd5c Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:26:22 +0000 Subject: [PATCH 13/16] fix comment --- runs/speedrun.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 62466c7..c757253 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,7 +69,7 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID -# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) +# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25) torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 From 48804bff3a487e43ee1e1533b3cfa0aa5ab0028f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:45:31 +0000 Subject: [PATCH 14/16] report negative result on fineweb dataset --- dev/LOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index c0d35e4..6ac027c 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,16 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data: FineWeb (negative) + +Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results: + +- d26 (GPT-2): CORE 0.2602 → 0.2241 + +This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + ## 2026-02-17: Pretraining Data Mixture Experiment (negative) Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: From 2dffdc8cf6953c5dc10f1caf37016e9daa675b09 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 19 Feb 2026 02:53:47 +0000 Subject: [PATCH 15/16] document MoE exploration --- dev/LOG.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index 6ac027c..0dfaa98 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,53 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-19: Mixture of Experts (negative) + +Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability). + +### Implementation + +Follows DeepSeekV3 and using torchtitan as reference: + +- **8 routed experts, top-2 routing** with sigmoid gating (not softmax) +- **1 shared expert** (dense MLP processing all tokens, following DeepSeekV3) +- **Auxiliary-loss-free load balancing** (DeepSeekV3's expert bias nudging) +- **Iso-FLOP sizing**: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP +- **`torch._grouped_mm`** for dispatching tokens to experts in a single kernel (instead of a Python for-loop) +- **3D expert weight tensors** `(num_experts, hidden, dim)` — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized +- **Active parameter counting** for scaling laws (only `top_k + shared` experts, not all 8) + +### What was easy + +- The core MoE forward pass: router, sort tokens by expert, grouped matmul, scatter back. Conceptually clean. +- Shared expert: just an `nn.Linear` MLP that runs on all tokens alongside the routed path. +- 3D expert params + Muon: only required fixing `second_momentum_buffer` shape to preserve leading dims. +- Load balancing: DeepSeekV3's bias nudging is simple and effective (~10 lines). + +### What was hard / ugly + +- **`torch._grouped_mm` quirks**: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error. +- **Token count padding**: torchtitan pads each expert's token count to alignment multiples (8 for bf16) for better grouped_mm throughput. We implemented this with both a pure PyTorch approach and a copy of torchtitan's Triton kernel. Both compiled cleanly (0 graph breaks), but with ~65K tokens across 8 experts, each expert already gets ~8K tokens which is well-aligned. The padding overhead (gather/scatter) actually regressed MFU from 35% to 33%. Reverted. +- **FP8 + MoE**: `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our `Float8Linear`). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see `dev/moe_fp8.md`) but did not implement — would require either depending on `torchao.prototype` (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's `nn.Linear` layers do get converted, but the routed experts (3D `nn.Parameter`) stay in bf16. + +### Results + +- d18: MFU dropped from ~46% to ~35% (the grouped_mm dispatch + token sorting overhead is significant) +- Per-step improvement in validation loss does not compensate for the throughput hit +- Net negative on wall clock time + +### What remains (if revisited) + +- **FP8 for routed experts**: Use `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for weight gradient (avoiding the per-group column-wise Triton kernels). + +What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute. + +### Verdict + +MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths. + +--- + ## 2026-02-17: Pretraining Data: FineWeb (negative) Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results: From c7ba25214276d165eeefca7cb2060587975db189 Mon Sep 17 00:00:00 2001 From: Dipesh Babu <59379458+dipeshbabu@users.noreply.github.com> Date: Fri, 20 Feb 2026 11:03:45 -0500 Subject: [PATCH 16/16] docs: fix typos in experiment log (#547) --- dev/LOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index 0dfaa98..fce90fd 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -749,7 +749,7 @@ See the branch `fp8_attempt_fail` for: ### Open Questions - Why does the custom op approach use more memory than vanilla BF16? -- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Ahmdal's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized. +- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Amdahl's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized. **Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao. @@ -913,7 +913,7 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i - Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay and is hardcoded to 0 weight decay, might try to re-tune this later. **4. Weight decay schedule** -- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, them ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is imlpemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule). +- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule). ### Weight Decay Scaling Experiments @@ -957,6 +957,6 @@ Muon was changed to use Polar Express, added NorMuon variance reduction, and cau **Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce. -**Observartion:** modded-nanogpt does not appear to clip either right now. +**Observation:** modded-nanogpt does not appear to clip either right now. **Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms.