mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 09:50:28 +00:00
a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps
This commit is contained in:
parent
124f49be98
commit
77f8fb8303
|
|
@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
|
|||
base_dir = get_base_dir()
|
||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
||||
|
||||
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
||||
"""Load just the optimizer shard for a given rank, without re-loading the model."""
|
||||
model_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[source]
|
||||
base_dir = get_base_dir()
|
||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||
if model_tag is None:
|
||||
model_tag = find_largest_model(checkpoints_dir)
|
||||
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
|
||||
if step is None:
|
||||
step = find_last_step(checkpoint_dir)
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
log0(f"Loading optimizer state from {optimizer_path}")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
return optimizer_data
|
||||
|
|
|
|||
|
|
@ -468,6 +468,7 @@ while True:
|
|||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": args.device_batch_size,
|
||||
"max_seq_len": args.max_seq_len,
|
||||
"total_batch_size": total_batch_size,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Or torchrun for training:
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
|
||||
"""
|
||||
|
||||
import gc
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
|
@ -16,12 +17,14 @@ import time
|
|||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
import torch.distributed as dist
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
|
@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m
|
|||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
# Batch sizes (default: inherit from pretrained checkpoint)
|
||||
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
|
||||
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
|
||||
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
|
||||
# Optimization (default: inherit from pretrained checkpoint)
|
||||
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
|
||||
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -66,20 +72,48 @@ user_config = vars(args).copy()
|
|||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
gpu_device_name = torch.cuda.get_device_name(0)
|
||||
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if not HAS_FA3:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
|
||||
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
|
||||
pretrain_user_config = meta.get("user_config", {})
|
||||
for name, fallback, source in [
|
||||
("max_seq_len", 2048, meta),
|
||||
("device_batch_size", 32, meta),
|
||||
("total_batch_size", 524288, meta),
|
||||
("embedding_lr", 0.3, pretrain_user_config),
|
||||
("unembedding_lr", 0.004, pretrain_user_config),
|
||||
("matrix_lr", 0.02, pretrain_user_config),
|
||||
]:
|
||||
arg_val = getattr(args, name)
|
||||
pretrain_val = source.get(name)
|
||||
if arg_val is None:
|
||||
resolved = pretrain_val if pretrain_val is not None else fallback
|
||||
setattr(args, name, resolved)
|
||||
print0(f"Inherited {name}={resolved} from pretrained checkpoint")
|
||||
elif pretrain_val is not None and arg_val != pretrain_val:
|
||||
print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
|
||||
else:
|
||||
print0(f"Using {name}={arg_val}")
|
||||
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
|
|
@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste
|
|||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
||||
|
||||
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
||||
base_dir = get_base_dir()
|
||||
if args.load_optimizer:
|
||||
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
print0("Loaded optimizer state from pretrained checkpoint")
|
||||
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
# SFT data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
|
|
@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train")
|
|||
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
|
||||
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
|
||||
def get_lr_multiplier(progress):
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
|
||||
if progress < args.warmup_ratio:
|
||||
return (progress + 1e-8) / args.warmup_ratio
|
||||
elif progress <= 1.0 - args.warmdown_ratio:
|
||||
return 1.0
|
||||
else:
|
||||
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
|
||||
return (1 - decay) * 1.0 + decay * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -282,8 +332,44 @@ while True:
|
|||
})
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step and not args.dry_run:
|
||||
# once in a while: estimate the ChatCORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
chatcore_results = {}
|
||||
if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
|
||||
model.eval()
|
||||
engine = Engine(orig_model, tokenizer)
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
|
||||
'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
|
||||
}
|
||||
task_results = {}
|
||||
for task_name in all_tasks:
|
||||
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
||||
max_problems = None if limit < 0 else limit # -1 means no limit
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||
task_results[task_name] = acc
|
||||
print0(f" {task_name}: {100*acc:.2f}%")
|
||||
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
||||
def centered_mean(tasks):
|
||||
return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
|
||||
chatcore = centered_mean(all_tasks)
|
||||
chatcore_cat = centered_mean(categorical_tasks)
|
||||
print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"chatcore_metric": chatcore,
|
||||
"chatcore_cat": chatcore_cat,
|
||||
**{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
|
||||
})
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
|
||||
if last_step:
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
|
|
@ -304,7 +390,8 @@ while True:
|
|||
"window_pattern": model.config.window_pattern,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
}
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
if last_step:
|
||||
|
|
@ -346,8 +433,7 @@ while True:
|
|||
pct_done = 100 * progress
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
|
|
@ -364,24 +450,32 @@ while True:
|
|||
"train/epoch": current_epoch,
|
||||
})
|
||||
|
||||
# The garbage collector spends ~500ms scanning for cycles quite frequently.
|
||||
# We manually manage it to avoid these pauses during training.
|
||||
if step == 1:
|
||||
gc.collect() # manually collect a lot of garbage from setup
|
||||
gc.freeze() # freeze all currently surviving objects and exclude them from GC
|
||||
gc.disable() # disable GC entirely except:
|
||||
elif step % 5000 == 0: # every 5000 steps...
|
||||
gc.collect() # manually collect, just to be safe for very long runs
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
if not args.dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="SFT", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="SFT", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user