From 8180e1d8c1c3e561b751dcfec54a74b3122c0db5 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 16 Feb 2026 20:23:04 +0000
Subject: [PATCH] tune the data mixture a bit, load optimizer by default when
 SFT. These were confirmed to be best settings from sweeps of sft

---
 nanochat/checkpoint_manager.py |  3 +++
 scripts/chat_sft.py            | 33 +++++++++++++++++++++++----------
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index e24533a..f71524e 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
     if step is None:
         step = find_last_step(checkpoint_dir)
     optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+    if not os.path.exists(optimizer_path):
+        log0(f"Optimizer checkpoint not found: {optimizer_path}")
+        return None
     log0(f"Loading optimizer state from {optimizer_path}")
     optimizer_data = torch.load(optimizer_path, map_location=device)
     return optimizer_data
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index edac3d8..a783ed2 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
 parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
-parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
+parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
 # Training horizon
 parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
 # Batch sizes (default: inherit from pretrained checkpoint)
@@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o
 parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
 parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
 parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
+# Data mixture
+parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
+parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
@@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device)
 optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
 
 # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
+# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
+# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
+# restore our fresh SFT LRs after loading.
 base_dir = get_base_dir()
 if args.load_optimizer:
     optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
-    optimizer.load_state_dict(optimizer_data)
-    del optimizer_data
-    print0("Loaded optimizer state from pretrained checkpoint")
+    if optimizer_data is not None:
+        base_lrs = [group["lr"] for group in optimizer.param_groups]
+        optimizer.load_state_dict(optimizer_data)
+        del optimizer_data
+        for group, base_lr in zip(optimizer.param_groups, base_lrs):
+            group["lr"] = base_lr
+        print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
+    else:
+        print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
 
 # Override the initial learning rate as a fraction of the base learning rate
 for group in optimizer.param_groups:
@@ -146,16 +158,17 @@ for group in optimizer.param_groups:
 
 # SFT data mixture and DataLoader
 identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
-train_dataset = TaskMixture([
+train_tasks = [
     SmolTalk(split="train"), # 460K rows of general conversations
-    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
-    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
-    GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
     CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
-    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
+    CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
+    *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
+    *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
     SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
     SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
-]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
+]
+train_dataset = TaskMixture(train_tasks)
+print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
 val_dataset = TaskMixture([
     SmolTalk(split="test"), # 24K rows in test set
     MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
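
Reviewer note (not part of the patch): the save-and-restore of learning rates above is needed because torch.optim.Optimizer.load_state_dict() restores the param_group hyperparameters (lr, betas, weight_decay, ...) from the checkpoint, not just the per-parameter state, and the pretraining warmdown leaves those LRs near zero. A minimal standalone sketch of the pattern, using a toy Linear model and made-up LR values as stand-ins for nanochat's real model and optim_*.pt state:

import torch

model = torch.nn.Linear(4, 4)                     # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

pretrained_state = optimizer.state_dict()         # stand-in for the loaded optim_*.pt
pretrained_state["param_groups"][0]["lr"] = 1e-7  # pretend warmdown drove the LR to ~0

fresh_lrs = [group["lr"] for group in optimizer.param_groups]  # remember the SFT LRs
optimizer.load_state_dict(pretrained_state)                    # brings back state AND old hyperparameters
for group, lr in zip(optimizer.param_groups, fresh_lrs):
    group["lr"] = lr                                           # put the SFT LRs back

assert optimizer.param_groups[0]["lr"] == 1e-3

Without the restore loop, SFT would silently start at the near-zero pretraining LRs; with it, the loaded checkpoint contributes its optimizer state (momentum buffers) only, which is what the new print0 message describes.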