Mirror of https://github.com/karpathy/nanochat.git, synced 2026-03-15 01:13:15 +00:00
tune the data mixture a bit, load optimizer by default when doing SFT. These were confirmed to be the best settings from SFT sweeps.
Parent: 788dadeb88
Commit: 8180e1d8c1
@@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
     if step is None:
         step = find_last_step(checkpoint_dir)
     optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+    if not os.path.exists(optimizer_path):
+        log0(f"Optimizer checkpoint not found: {optimizer_path}")
+        return None
     log0(f"Loading optimizer state from {optimizer_path}")
     optimizer_data = torch.load(optimizer_path, map_location=device)
     return optimizer_data
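The new guard turns a missing per-rank optimizer file into a soft failure (return None) instead of a torch.load crash. For reference, the f-string above builds per-step, per-rank filenames; a tiny illustration with made-up values:

    step, rank = 765, 3
    print(f"optim_{step:06d}_rank{rank:d}.pt")  # -> optim_000765_rank3.pt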
@@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
 parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
-parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
+parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
 # Training horizon
 parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
 # Batch sizes (default: inherit from pretrained checkpoint)
@@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o
 parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
 parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
 parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
+# Data mixture
+parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
+parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
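Together with the flipped --load-optimizer default, the data mixture is now tunable from the command line. A hypothetical invocation (the torchrun launcher and the scripts.chat_sft module path are assumptions, not shown in this diff):

    torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft --load-optimizer 1 --mmlu-epochs 3 --gsm8k-epochs 4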
@@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device)
 optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
 
 # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
+# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
+# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
+# restore our fresh SFT LRs after loading.
 base_dir = get_base_dir()
 if args.load_optimizer:
     optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
-    optimizer.load_state_dict(optimizer_data)
-    del optimizer_data
-    print0("Loaded optimizer state from pretrained checkpoint")
+    if optimizer_data is not None:
+        base_lrs = [group["lr"] for group in optimizer.param_groups]
+        optimizer.load_state_dict(optimizer_data)
+        del optimizer_data
+        for group, base_lr in zip(optimizer.param_groups, base_lrs):
+            group["lr"] = base_lr
+        print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
+    else:
+        print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
 
 # Override the initial learning rate as a fraction of the base learning rate
 for group in optimizer.param_groups:
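The save/restore of base_lrs is needed because torch.optim.Optimizer.load_state_dict restores param_group hyperparameters (lr, betas, ...) along with the per-parameter state, so the near-zero LR saved at the end of pretraining warmdown would silently replace the fresh SFT LR. A minimal standalone sketch of the same pattern, using a generic SGD optimizer rather than nanochat's setup:

    import torch

    params = [torch.nn.Parameter(torch.zeros(2))]
    pretrain_opt = torch.optim.SGD(params, lr=1e-6, momentum=0.9)  # stands in for the end-of-warmdown optimizer
    sft_opt = torch.optim.SGD(params, lr=1e-3, momentum=0.9)       # fresh SFT optimizer

    base_lrs = [g["lr"] for g in sft_opt.param_groups]   # remember the SFT learning rates
    sft_opt.load_state_dict(pretrain_opt.state_dict())   # loads buffers AND the stale lr=1e-6
    for g, lr in zip(sft_opt.param_groups, base_lrs):
        g["lr"] = lr                                      # put the SFT learning rates back
    print([g["lr"] for g in sft_opt.param_groups])        # [0.001]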
@@ -146,16 +158,17 @@ for group in optimizer.param_groups:
 
 # SFT data mixture and DataLoader
 identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
-train_dataset = TaskMixture([
+train_tasks = [
     SmolTalk(split="train"), # 460K rows of general conversations
-    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
-    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
-    GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
     CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
-    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
+    CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
+    *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
+    *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
     SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
     SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
-]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
+]
+train_dataset = TaskMixture(train_tasks)
+print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
 val_dataset = TaskMixture([
     SmolTalk(split="test"), # 24K rows in test set
     MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
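The starred comprehensions in the new train_tasks list are plain list repetition: each epoch appends another full copy of the task, so the old hard-coded second GSM8K entry generalizes to the --mmlu-epochs / --gsm8k-epochs knobs. A toy illustration with strings standing in for the real task objects:

    mmlu_epochs, gsm8k_epochs = 3, 4  # the new defaults
    tasks = [
        "SmolTalk",
        *["MMLU" for _ in range(mmlu_epochs)],
        *["GSM8K" for _ in range(gsm8k_epochs)],
    ]
    print(tasks)  # ['SmolTalk', 'MMLU', 'MMLU', 'MMLU', 'GSM8K', 'GSM8K', 'GSM8K', 'GSM8K']

With the defaults, the per-task comments imply roughly 460K (SmolTalk) + 2x1K (identity) + 3x100K (MMLU) + 4x8K (GSM8K) + 200K (SimpleSpelling) + 80K (SpellingBee) ≈ 1.07M training rows, versus the ~856K of the old fixed mixture.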