Tune the data mixture a bit and load the optimizer state by default during SFT. These were confirmed to be the best settings in SFT sweeps.

This commit is contained in:
Andrej Karpathy 2026-02-16 20:23:04 +00:00
parent 788dadeb88
commit 8180e1d8c1
2 changed files with 26 additions and 10 deletions

View File

@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
if step is None:
step = find_last_step(checkpoint_dir)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
if not os.path.exists(optimizer_path):
log0(f"Optimizer checkpoint not found: {optimizer_path}")
return None
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data

View File

@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes (default: inherit from pretrained checkpoint)
@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
# restore our fresh SFT LRs after loading.
base_dir = get_base_dir()
if args.load_optimizer:
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
optimizer.load_state_dict(optimizer_data)
del optimizer_data
print0("Loaded optimizer state from pretrained checkpoint")
if optimizer_data is not None:
base_lrs = [group["lr"] for group in optimizer.param_groups]
optimizer.load_state_dict(optimizer_data)
del optimizer_data
for group, base_lr in zip(optimizer.param_groups, base_lrs):
group["lr"] = base_lr
print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
else:
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
@ -146,16 +158,17 @@ for group in optimizer.param_groups:
# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
train_tasks = [
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
]
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios