mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-16 01:02:18 +00:00
Compare commits
3 Commits
ed63afa1cb
...
0a6dbb3361
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a6dbb3361 | ||
|
|
190d9515d0 | ||
|
|
b8076dd367 |
|
|
@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
|||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import copy
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
|
|
@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
|
|||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture, TaskSequence
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.smoltalk import SmolTalk
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -186,7 +183,7 @@ for step in range(num_iterations):
|
|||
})
|
||||
model.train()
|
||||
|
||||
# evlauate MMLU accuracy
|
||||
# evlauate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
|
|
@ -194,8 +191,6 @@ for step in range(num_iterations):
|
|||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
|
||||
metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ embedding_lr = 0.2
|
|||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
eval_every = 150
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
|
|
@ -141,7 +141,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
|
|||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -185,7 +186,7 @@ while True:
|
|||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
|
|
@ -272,17 +273,18 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
|
|||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
if not dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user