From 4468824a6effed7fdd89c9f62e537d1ba0b8a1e2 Mon Sep 17 00:00:00 2001 From: Sushrut Karnik Date: Fri, 13 Mar 2026 23:23:21 +0100 Subject: [PATCH] reduce eval problems for mmlu, humaneval, gsm8k, spellingbee, super slow --- nanochat/common.py | 2 +- nanochat/execution.py | 4 +- runs/runcpu.sh | 19 +- scripts/base_train.py | 2 + scripts/chat_eval.py | 8 +- scripts/chat_sft.py | 961 +++++++++++++++++++++--------------------- 6 files changed, 512 insertions(+), 484 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index f7ace6e..39a48af 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -31,7 +31,7 @@ def _detect_compute_dtype(): if torch.backends.mps.is_available(): # torch.float16 # return torch.float16, "auto-detected: mps, float16" - return torch.float32, "auto-detected: mps, float16" + return torch.float32, "auto-detected: mps, float32" return torch.float32, "auto-detected: no CUDA (CPU/MPS)" COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() diff --git a/nanochat/execution.py b/nanochat/execution.py index 6f50c74..3d9db98 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -306,7 +306,6 @@ def execute_code( >>> result.stdout 'hello world\\n' """ - manager = multiprocessing.Manager() result_dict = manager.dict() @@ -347,3 +346,6 @@ def execute_code( memory_exceeded=result_dict["memory_exceeded"], ) +# if __name__ == '__main__': +# multiprocessing.freeze_support() +# # manager = multiprocessing.Manager() diff --git a/runs/runcpu.sh b/runs/runcpu.sh index a156979..4e3807e 100755 --- a/runs/runcpu.sh +++ b/runs/runcpu.sh @@ -31,16 +31,16 @@ if [ -z "$WANDB_RUN" ]; then WANDB_RUN=dummy fi -# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max) +# # train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max) python -m nanochat.dataset -n 8 python -m scripts.tok_train --max-chars=2000000000 python -m scripts.tok_eval -# train a small 4 layer model -# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max. -# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM. +# # train a small 4 layer model +# # I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max. +# # To get better results, try increasing num_iterations, or get other ideas from your favorite LLM. python -m scripts.base_train \ - --depth=6 \ + --depth=24 \ --head-dim=64 \ --window-pattern=L \ --max-seq-len=512 \ @@ -50,7 +50,7 @@ python -m scripts.base_train \ --eval-tokens=524288 \ --core-metric-every=-1 \ --sample-every=100 \ - --num-iterations=100 \ + --num-iterations=1000 \ --run=$WANDB_RUN python -m scripts.base_eval --device-batch-size=1 --split-tokens=16384 --max-per-task=16 @@ -60,16 +60,19 @@ python -m scripts.chat_sft \ --max-seq-len=512 \ --device-batch-size=32 \ --total-batch-size=16384 \ - --eval-every=200 \ + --eval-every=-1 \ --eval-tokens=524288 \ --num-iterations=1500 \ --run=$WANDB_RUN +#reduced mmlu humaneval and spellingbee to just 10 problems, slower than expected +python -m scripts.chat_eval -i sft + # Chat with the model over CLI # The model should be able to say that it is Paris. # It might even know that the color of the sky is blue. # Sometimes the model likes it if you first say Hi before you ask it questions. -# python -m scripts.chat_cli -p "What is the capital of France?" +python -m scripts.chat_cli -p "What is the capital of France?" # Chat with the model over a pretty WebUI ChatGPT style # python -m scripts.chat_web diff --git a/scripts/base_train.py b/scripts/base_train.py index a9a4534..bab8388 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -88,6 +88,8 @@ device_type = autodetect_device_type() if args.device_type == "" else args.devic ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None +synchronize = torch.mps.synchronize if device_type == "mps" else synchronize + get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 if device_type == "cuda": gpu_device_name = torch.cuda.get_device_name(0) diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index 858d4c2..4a1164c 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -119,6 +119,10 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way. for idx, conversation in enumerate(conversations): + + if i % 10 == 0 and idx % 10 == 0: + print0(f"run_categorical_eval Batch {i}/{num_batches} Processing conversation idx: {idx}/{batch_size}") + # get the token ids of all the available letters of this problem letters = conversation['letters'] letter_ids = [] @@ -215,6 +219,8 @@ if __name__ == "__main__": # Run all the task evaluations sequentially results = {} for task_name in task_names: + max_problems = 10 if (device_type != "cuda" and task_name in ["MMLU", "GSM8K", "HumanEval", "SpellingBee"]) else args.max_problems + acc = run_chat_eval( task_name, model, tokenizer, engine, @@ -223,7 +229,7 @@ if __name__ == "__main__": max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, - max_problems=args.max_problems, + max_problems=max_problems, ) results[task_name] = acc print0(f"{task_name} accuracy: {100 * acc:.2f}%") diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c8382f9..3d125df 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -8,512 +8,527 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ +if __name__ == '__main__': + import gc + import argparse + import os + os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" + import time + import wandb + import torch + from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized + from nanochat.tokenizer import get_token_bytes + from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state + from nanochat.loss_eval import evaluate_bpb + import torch.distributed as dist + from nanochat.flash_attention import HAS_FA3 + from nanochat.engine import Engine + from scripts.chat_eval import run_chat_eval -import gc -import argparse -import os -os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" -import time -import wandb -import torch -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized -from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state -from nanochat.loss_eval import evaluate_bpb -import torch.distributed as dist -from nanochat.flash_attention import HAS_FA3 -from nanochat.engine import Engine -from scripts.chat_eval import run_chat_eval + from tasks.common import TaskMixture + from tasks.gsm8k import GSM8K + from tasks.mmlu import MMLU + from tasks.smoltalk import SmolTalk + from tasks.customjson import CustomJSON + from tasks.spellingbee import SimpleSpelling, SpellingBee -from tasks.common import TaskMixture -from tasks.gsm8k import GSM8K -from tasks.mmlu import MMLU -from tasks.smoltalk import SmolTalk -from tasks.customjson import CustomJSON -from tasks.spellingbee import SimpleSpelling, SpellingBee + # ----------------------------------------------------------------------------- + # CLI arguments + parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model") + # Logging + parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") + # Runtime + parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") + # Model loading + parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") + parser.add_argument("--model-step", type=int, default=None, help="model step to load from") + parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") + # Training horizon + parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") + # Batch sizes (default: inherit from pretrained checkpoint) + parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") + parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") + parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") + # Optimization (default: inherit from pretrained checkpoint) + parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") + parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") + parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") + parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR") + parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") + parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") + parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") + # Evaluation + parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") + parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") + parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") + parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") + parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") + # Data mixture + parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") + parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") + args = parser.parse_args() + user_config = vars(args).copy() + # ----------------------------------------------------------------------------- -# ----------------------------------------------------------------------------- -# CLI arguments -parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model") -# Logging -parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") -# Runtime -parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -# Model loading -parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") -parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") -# Training horizon -parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") -# Batch sizes (default: inherit from pretrained checkpoint) -parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") -parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") -parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") -# Optimization (default: inherit from pretrained checkpoint) -parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") -parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") -parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") -parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR") -parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") -parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") -parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") -# Evaluation -parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") -parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") -parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") -parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") -# Data mixture -parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") -parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") -args = parser.parse_args() -user_config = vars(args).copy() -# ----------------------------------------------------------------------------- + # Compute init + device_type = autodetect_device_type() if args.device_type == "" else args.device_type -# Compute init -device_type = autodetect_device_type() if args.device_type == "" else args.device_type -ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) -master_process = ddp_rank == 0 -print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") -synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None -get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 -if device_type == "cuda": - gpu_device_name = torch.cuda.get_device_name(0) - gpu_peak_flops = get_peak_flops(gpu_device_name) - print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") -else: - gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) + master_process = ddp_rank == 0 + print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") + synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None + synchronize = torch.mps.synchronize if device_type == "mps" else synchronize -# wandb logging init -use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) - -# Flash Attention status -if not HAS_FA3: - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") - -# Load the model and tokenizer -model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) - -# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) -pretrain_user_config = meta.get("user_config", {}) -for name, fallback, source in [ - ("max_seq_len", 2048, meta), - ("device_batch_size", 32, meta), - ("total_batch_size", 524288, meta), - ("embedding_lr", 0.3, pretrain_user_config), - ("unembedding_lr", 0.004, pretrain_user_config), - ("matrix_lr", 0.02, pretrain_user_config), -]: - arg_val = getattr(args, name) - pretrain_val = source.get(name) - if arg_val is None: - resolved = pretrain_val if pretrain_val is not None else fallback - setattr(args, name, resolved) - print0(f"Inherited {name}={resolved} from pretrained checkpoint") - elif pretrain_val is not None and arg_val != pretrain_val: - print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 + if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") else: - print0(f"Using {name}={arg_val}") + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS -orig_model = model -# model = torch.compile(model, dynamic=False) -depth = model.config.n_layer -num_flops_per_token = model.estimate_flops() -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") -token_bytes = get_token_bytes(device=device) + # wandb logging init + use_dummy_wandb = args.run == "dummy" or not master_process + wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) -# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type) + # Flash Attention status + if not HAS_FA3: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") -# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) -# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the -# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and -# restore our fresh SFT LRs after loading. -base_dir = get_base_dir() -if args.load_optimizer: - optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - if optimizer_data is not None: - base_lrs = [group["lr"] for group in optimizer.param_groups] - optimizer.load_state_dict(optimizer_data) - del optimizer_data - for group, base_lr in zip(optimizer.param_groups, base_lrs): - group["lr"] = base_lr - print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") - else: - print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") + # Load the model and tokenizer + model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -# GradScaler for fp16 training (bf16/fp32 don't need it) -scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None -if scaler is not None: - print0("GradScaler enabled for fp16 training") + # Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) + pretrain_user_config = meta.get("user_config", {}) + for name, fallback, source in [ + ("max_seq_len", 2048, meta), + ("device_batch_size", 32, meta), + ("total_batch_size", 524288, meta), + ("embedding_lr", 0.3, pretrain_user_config), + ("unembedding_lr", 0.004, pretrain_user_config), + ("matrix_lr", 0.02, pretrain_user_config), + ]: + arg_val = getattr(args, name) + pretrain_val = source.get(name) + if arg_val is None: + resolved = pretrain_val if pretrain_val is not None else fallback + setattr(args, name, resolved) + print0(f"Inherited {name}={resolved} from pretrained checkpoint") + elif pretrain_val is not None and arg_val != pretrain_val: + print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + else: + print0(f"Using {name}={arg_val}") -# Override the initial learning rate as a fraction of the base learning rate -for group in optimizer.param_groups: - group["lr"] = group["lr"] * args.init_lr_frac - group["initial_lr"] = group["lr"] + orig_model = model + # model = torch.compile(model, dynamic=False) + depth = model.config.n_layer + num_flops_per_token = model.estimate_flops() + tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank + world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks + assert args.total_batch_size % world_tokens_per_fwdbwd == 0 + grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd + print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") + print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") + print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + token_bytes = get_token_bytes(device=device) -# SFT data mixture and DataLoader -identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_tasks = [ - SmolTalk(split="train"), # 460K rows of general conversations - CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these - *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch - *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch - SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') - SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -] -train_dataset = TaskMixture(train_tasks) -print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") -val_dataset = TaskMixture([ - SmolTalk(split="test"), # 24K rows in test set - MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios - GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios -]) # total: 24K + 14K + 1.32K ~= 39K rows -# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) -# A big problem is that we don't know the final num_iterations in advance. So we create -# these two global variables and update them from within the data generator. -last_step = False # we will toggle this to True when we reach the end of the training dataset -approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch -current_epoch = 1 # track epoch for logging -def sft_data_generator_bos_bestfit(split, buffer_size=100): - """ - BOS-aligned dataloader for SFT with bestfit-pad packing. + # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) + # Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero + optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type) - Each row in the batch starts with BOS (beginning of a conversation). - Conversations are packed using best-fit algorithm. When no conversation fits, - the row is padded (instead of cropping) to ensure no tokens are ever discarded. - Padding positions have targets masked with -1 (ignore_index for cross-entropy). - """ - global last_step, approx_progress, current_epoch - assert split in {"train", "val"}, "split must be 'train' or 'val'" - dataset = train_dataset if split == "train" else val_dataset - dataset_size = len(dataset) - assert dataset_size > 0 - row_capacity = args.max_seq_len + 1 # +1 for target at last position - bos_token = tokenizer.get_bos_token_id() + # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) + # Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the + # pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and + # restore our fresh SFT LRs after loading. + base_dir = get_base_dir() + if args.load_optimizer: + optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") - # Conversation buffer: list of (token_ids, loss_mask) tuples - conv_buffer = [] - cursor = ddp_rank # Each rank processes different conversations (for fetching) - consumed = ddp_rank # Track actual consumption separately from buffering - epoch = 1 - it = 0 # iteration counter + # GradScaler for fp16 training (bf16/fp32 don't need it) + scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None + if scaler is not None: + print0("GradScaler enabled for fp16 training") - def refill_buffer(): - nonlocal cursor, epoch - while len(conv_buffer) < buffer_size: - conversation = dataset[cursor] - ids, mask = tokenizer.render_conversation(conversation) - conv_buffer.append((ids, mask)) - cursor += ddp_world_size - if cursor >= dataset_size: - cursor = cursor % dataset_size - epoch += 1 - # Note: last_step is now triggered based on consumption, not fetching + # Override the initial learning rate as a fraction of the base learning rate + for group in optimizer.param_groups: + group["lr"] = group["lr"] * args.init_lr_frac + group["initial_lr"] = group["lr"] - while True: - rows = [] - mask_rows = [] - row_lengths = [] # Track actual content length (excluding padding) for each row - for _ in range(args.device_batch_size): - row = [] - mask_row = [] - padded = False - while len(row) < row_capacity: - # Ensure buffer has conversations - while len(conv_buffer) < buffer_size: - refill_buffer() + # SFT data mixture and DataLoader + identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") + train_tasks = [ + SmolTalk(split="train"), # 460K rows of general conversations + CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch + SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) + ] + train_dataset = TaskMixture(train_tasks) + print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") + val_dataset = TaskMixture([ + SmolTalk(split="test"), # 24K rows in test set + MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios + GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios + ]) # total: 24K + 14K + 1.32K ~= 39K rows + # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) + # A big problem is that we don't know the final num_iterations in advance. So we create + # these two global variables and update them from within the data generator. + last_step = False # we will toggle this to True when we reach the end of the training dataset + approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch + current_epoch = 1 # track epoch for logging + def sft_data_generator_bos_bestfit(split, buffer_size=100): + """ + BOS-aligned dataloader for SFT with bestfit-pad packing. - remaining = row_capacity - len(row) + Each row in the batch starts with BOS (beginning of a conversation). + Conversations are packed using best-fit algorithm. When no conversation fits, + the row is padded (instead of cropping) to ensure no tokens are ever discarded. + Padding positions have targets masked with -1 (ignore_index for cross-entropy). + """ + global last_step, approx_progress, current_epoch + assert split in {"train", "val"}, "split must be 'train' or 'val'" + dataset = train_dataset if split == "train" else val_dataset + dataset_size = len(dataset) + assert dataset_size > 0 + row_capacity = args.max_seq_len + 1 # +1 for target at last position + bos_token = tokenizer.get_bos_token_id() - # Find largest conversation that fits entirely - best_idx = -1 - best_len = 0 - for i, (conv, _) in enumerate(conv_buffer): - conv_len = len(conv) - if conv_len <= remaining and conv_len > best_len: - best_idx = i - best_len = conv_len + # Conversation buffer: list of (token_ids, loss_mask) tuples + conv_buffer = [] + cursor = ddp_rank # Each rank processes different conversations (for fetching) + consumed = ddp_rank # Track actual consumption separately from buffering + epoch = 1 + it = 0 # iteration counter - if best_idx >= 0: - # Found a conversation that fits - use it entirely - conv, conv_mask = conv_buffer.pop(best_idx) - row.extend(conv) - mask_row.extend(conv_mask) - consumed += ddp_world_size # Track actual consumption + def refill_buffer(): + nonlocal cursor, epoch + while len(conv_buffer) < buffer_size: + conversation = dataset[cursor] + ids, mask = tokenizer.render_conversation(conversation) + conv_buffer.append((ids, mask)) + cursor += ddp_world_size + if cursor >= dataset_size: + cursor = cursor % dataset_size + epoch += 1 + # Note: last_step is now triggered based on consumption, not fetching + + while True: + rows = [] + mask_rows = [] + row_lengths = [] # Track actual content length (excluding padding) for each row + for _ in range(args.device_batch_size): + row = [] + mask_row = [] + padded = False + while len(row) < row_capacity: + # Ensure buffer has conversations + while len(conv_buffer) < buffer_size: + refill_buffer() + + remaining = row_capacity - len(row) + + # Find largest conversation that fits entirely + best_idx = -1 + best_len = 0 + for i, (conv, _) in enumerate(conv_buffer): + conv_len = len(conv) + if conv_len <= remaining and conv_len > best_len: + best_idx = i + best_len = conv_len + + if best_idx >= 0: + # Found a conversation that fits - use it entirely + conv, conv_mask = conv_buffer.pop(best_idx) + row.extend(conv) + mask_row.extend(conv_mask) + consumed += ddp_world_size # Track actual consumption + else: + # No conversation fits - pad the remainder instead of cropping + # This ensures we never discard any tokens + content_len = len(row) + row.extend([bos_token] * remaining) # Pad with BOS tokens + mask_row.extend([0] * remaining) + padded = True + break # Row is now full (with padding) + + # Track content length: full row if no padding, otherwise the length before padding + if padded: + row_lengths.append(content_len) else: - # No conversation fits - pad the remainder instead of cropping - # This ensures we never discard any tokens - content_len = len(row) - row.extend([bos_token] * remaining) # Pad with BOS tokens - mask_row.extend([0] * remaining) - padded = True - break # Row is now full (with padding) + row_lengths.append(row_capacity) + rows.append(row[:row_capacity]) + mask_rows.append(mask_row[:row_capacity]) - # Track content length: full row if no padding, otherwise the length before padding - if padded: - row_lengths.append(content_len) - else: - row_lengths.append(row_capacity) - rows.append(row[:row_capacity]) - mask_rows.append(mask_row[:row_capacity]) - - # Stopping condition to respect num_iterations, if given - it += 1 - if 0 < args.num_iterations <= it and split == "train": - last_step = True - - # Update progress tracking (based on consumed, not cursor, to account for buffering) - if split == "train": - current_epoch = epoch - if args.num_iterations > 0: - approx_progress = it / args.num_iterations - else: - approx_progress = consumed / dataset_size - # Trigger last_step when we've consumed enough (instead of when cursor wraps) - if consumed >= dataset_size: + # Stopping condition to respect num_iterations, if given + it += 1 + if 0 < args.num_iterations <= it and split == "train": last_step = True - # Build tensors - use_cuda = device_type == "cuda" - batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) - inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() - targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() + # Update progress tracking (based on consumed, not cursor, to account for buffering) + if split == "train": + current_epoch = epoch + if args.num_iterations > 0: + approx_progress = it / args.num_iterations + else: + approx_progress = consumed / dataset_size + # Trigger last_step when we've consumed enough (instead of when cursor wraps) + if consumed >= dataset_size: + last_step = True - # Apply the loss mask from render_conversation (mask=1 for assistant completions, - # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns - # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). - mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) - mask_targets = mask_tensor[:, 1:].to(device=device) - targets[mask_targets == 0] = -1 + # Build tensors + use_cuda = device_type == "cuda" + batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() + targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() - # Mask out padding positions in targets (set to -1 = ignore_index) - # For each row, positions >= (content_length - 1) in targets should be masked - for i, content_len in enumerate(row_lengths): - if content_len < row_capacity: - targets[i, content_len-1:] = -1 + # Apply the loss mask from render_conversation (mask=1 for assistant completions, + # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns + # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). + mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) + mask_targets = mask_tensor[:, 1:].to(device=device) + targets[mask_targets == 0] = -1 - yield inputs, targets + # Mask out padding positions in targets (set to -1 = ignore_index) + # For each row, positions >= (content_length - 1) in targets should be masked + for i, content_len in enumerate(row_lengths): + if content_len < row_capacity: + targets[i, content_len-1:] = -1 -train_loader = sft_data_generator_bos_bestfit("train") -build_val_loader = lambda: sft_data_generator_bos_bestfit("val") -progress = 0 # will go from 0 to 1 over the course of the epoch + yield inputs, targets -# Learning rate schedule (linear warmup, constant, linear warmdown) -# Same shape as base_train but uses progress (0→1) instead of absolute step counts, -# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). -def get_lr_multiplier(progress): - if progress < args.warmup_ratio: - return (progress + 1e-8) / args.warmup_ratio - elif progress <= 1.0 - args.warmdown_ratio: - return 1.0 - else: - decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio - return (1 - decay) * 1.0 + decay * args.final_lr_frac + train_loader = sft_data_generator_bos_bestfit("train") + build_val_loader = lambda: sft_data_generator_bos_bestfit("val") + progress = 0 # will go from 0 to 1 over the course of the epoch -# Momentum scheduler for Muon optimizer -def get_muon_momentum(it): - frac = min(it / 300, 1) - momentum = (1 - frac) * 0.85 + frac * 0.95 - return momentum - -# ----------------------------------------------------------------------------- -# Training loop -x, y = next(train_loader) # prefetch the very first batch of data -min_val_bpb = float("inf") -smooth_train_loss = 0 # EMA of training loss -ema_beta = 0.9 # EMA decay factor -total_training_time = 0 # total wall-clock time of training -step = 0 -while True: - flops_so_far = num_flops_per_token * args.total_batch_size * step - - # Synchronize last_step across all ranks to avoid hangs in the distributed setting - if ddp: - last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) - dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) - last_step = bool(last_step_tensor.item()) - - # once in a while: evaluate the val bpb (all ranks participate) - if last_step or (args.eval_every > 0 and step % args.eval_every == 0): - model.eval() - val_loader = build_val_loader() - eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) - val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) - print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") - if val_bpb < min_val_bpb: - min_val_bpb = val_bpb - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "val/bpb": val_bpb, - }) - model.train() - - # once in a while: estimate the ChatCORE metric (all ranks participate) - # use the original uncompiled model because the inputs keep changing shape - chatcore_results = {} - if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): - model.eval() - engine = Engine(orig_model, tokenizer) - all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] - categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'} - baseline_accuracies = { - 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, - 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, - } - task_results = {} - for task_name in all_tasks: - limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample - max_problems = None if limit < 0 else limit # -1 means no limit - acc = run_chat_eval(task_name, orig_model, tokenizer, engine, - batch_size=args.device_batch_size, max_problems=max_problems) - task_results[task_name] = acc - print0(f" {task_name}: {100*acc:.2f}%") - # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) - def centered_mean(tasks): - return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) - chatcore = centered_mean(all_tasks) - chatcore_cat = centered_mean(categorical_tasks) - print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "chatcore_metric": chatcore, - "chatcore_cat": chatcore_cat, - **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, - }) - model.train() - - # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) - if last_step: - output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) - save_checkpoint( - checkpoint_dir, - step, - orig_model.state_dict(), - optimizer.state_dict(), - { - "step": step, - "val_bpb": val_bpb, # loss at last step - "model_config": { - "sequence_len": args.max_seq_len, - "vocab_size": tokenizer.get_vocab_size(), - "n_layer": depth, - "n_head": model.config.n_head, - "n_kv_head": model.config.n_kv_head, - "n_embd": model.config.n_embd, - "window_pattern": model.config.window_pattern, - }, - "user_config": user_config, # inputs to the training script - }, - rank=ddp_rank, - ) - - if last_step: - break - - # ------------------------------------------------------------------------- - # single training step - # evaluate the gradient - synchronize() - t0 = time.time() - for micro_step in range(grad_accum_steps): - loss = model(x, y) - train_loss = loss.detach() # for logging - loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - if scaler is not None: - scaler.scale(loss).backward() + # Learning rate schedule (linear warmup, constant, linear warmdown) + # Same shape as base_train but uses progress (0→1) instead of absolute step counts, + # because SFT doesn't always know num_iterations in advance (dataset-driven stopping). + def get_lr_multiplier(progress): + if progress < args.warmup_ratio: + return (progress + 1e-8) / args.warmup_ratio + elif progress <= 1.0 - args.warmdown_ratio: + return 1.0 else: - loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward - progress = max(progress, approx_progress) # only increase progress monotonically - # step the optimizer - lrm = get_lr_multiplier(progress) - muon_momentum = get_muon_momentum(step) - for group in optimizer.param_groups: - group["lr"] = group["initial_lr"] * lrm - if group['kind'] == 'muon': - group["momentum"] = muon_momentum - if scaler is not None: - scaler.unscale_(optimizer) - if is_ddp_initialized(): - for v in scaler._found_inf_per_device(optimizer).values(): - dist.all_reduce(v, op=dist.ReduceOp.MAX) - scaler.step(optimizer) - scaler.update() - else: - optimizer.step() - model.zero_grad(set_to_none=True) - synchronize() - t1 = time.time() - dt = t1 - t0 - # ------------------------------------------------------------------------- + decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio + return (1 - decay) * 1.0 + decay * args.final_lr_frac - # State - step += 1 + # Momentum scheduler for Muon optimizer + def get_muon_momentum(it): + frac = min(it / 300, 1) + momentum = (1 - frac) * 0.85 + frac * 0.95 + return momentum - # logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA - pct_done = 100 * progress - tok_per_sec = int(args.total_batch_size / dt) - flops_per_sec = num_flops_per_token * args.total_batch_size / dt - mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) - if step > 10: - total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") - if step % 10 == 0: - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "train/loss": debiased_smooth_loss, - "train/lrm": lrm, - "train/dt": dt, - "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, - "train/epoch": current_epoch, - }) + # ----------------------------------------------------------------------------- + # Training loop + x, y = next(train_loader) # prefetch the very first batch of data + min_val_bpb = float("inf") + smooth_train_loss = 0 # EMA of training loss + ema_beta = 0.9 # EMA decay factor + total_training_time = 0 # total wall-clock time of training + step = 0 + while True: + flops_so_far = num_flops_per_token * args.total_batch_size * step - # The garbage collector spends ~500ms scanning for cycles quite frequently. - # We manually manage it to avoid these pauses during training. - if step == 1: - gc.collect() # manually collect a lot of garbage from setup - gc.freeze() # freeze all currently surviving objects and exclude them from GC - gc.disable() # disable GC entirely except: - elif step % 5000 == 0: # every 5000 steps... - gc.collect() # manually collect, just to be safe for very long runs + # Synchronize last_step across all ranks to avoid hangs in the distributed setting + if ddp: + last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) + dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) + last_step = bool(last_step_tensor.item()) -# print a few more stats -print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") -print0(f"Total training time: {total_training_time/60:.2f}m") -print0(f"Minimum validation bpb: {min_val_bpb:.4f}") + # once in a while: evaluate the val bpb (all ranks participate) + if last_step or (args.eval_every > 0 and step % args.eval_every == 0): + model.eval() + val_loader = build_val_loader() + eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) + val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") + if val_bpb < min_val_bpb: + min_val_bpb = val_bpb + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + }) + model.train() -# Log to report -from nanochat.report import get_report -get_report().log(section="SFT", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } -]) + # once in a while: estimate the ChatCORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + chatcore_results = {} + if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): + model.eval() + engine = Engine(orig_model, tokenizer) + # all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] + all_tasks = [ + 'ARC-Easy', 'ARC-Challenge', + 'MMLU', + 'GSM8K', + 'HumanEval', + 'SpellingBee' + ] + categorical_tasks = { + 'ARC-Easy', 'ARC-Challenge', + 'MMLU' + } + baseline_accuracies = { + 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, + 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, + } + task_results = {} + for task_name in all_tasks: + limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample + max_problems = None if limit < 0 else limit # -1 means no limit + max_problems = 10 if (device_type != "cuda" and task_name == "MMLU") else max_problems -# cleanup -wandb_run.finish() # wandb run finish -compute_cleanup() + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) + task_results[task_name] = acc + print0(f" {task_name}: {100*acc:.2f}%") + # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) + def centered_mean(tasks): + return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) + chatcore = centered_mean(all_tasks) + chatcore_cat = centered_mean(categorical_tasks) + print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "chatcore_metric": chatcore, + "chatcore_cat": chatcore_cat, + **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, + }) + model.train() + + # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) + if last_step: + output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 + checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), + optimizer.state_dict(), + { + "step": step, + "val_bpb": val_bpb, # loss at last step + "model_config": { + "sequence_len": args.max_seq_len, + "vocab_size": tokenizer.get_vocab_size(), + "n_layer": depth, + "n_head": model.config.n_head, + "n_kv_head": model.config.n_kv_head, + "n_embd": model.config.n_embd, + "window_pattern": model.config.window_pattern, + }, + "user_config": user_config, # inputs to the training script + }, + rank=ddp_rank, + ) + + if last_step: + break + + # ------------------------------------------------------------------------- + # single training step + # evaluate the gradient + synchronize() + t0 = time.time() + for micro_step in range(grad_accum_steps): + loss = model(x, y) + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + progress = max(progress, approx_progress) # only increase progress monotonically + # step the optimizer + lrm = get_lr_multiplier(progress) + muon_momentum = get_muon_momentum(step) + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + if group['kind'] == 'muon': + group["momentum"] = muon_momentum + if scaler is not None: + scaler.unscale_(optimizer) + if is_ddp_initialized(): + for v in scaler._found_inf_per_device(optimizer).values(): + dist.all_reduce(v, op=dist.ReduceOp.MAX) + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + model.zero_grad(set_to_none=True) + synchronize() + t1 = time.time() + dt = t1 - t0 + # ------------------------------------------------------------------------- + + # State + step += 1 + + # logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * progress + tok_per_sec = int(args.total_batch_size / dt) + flops_per_sec = num_flops_per_token * args.total_batch_size / dt + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") + if step % 10 == 0: + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": current_epoch, + }) + + # The garbage collector spends ~500ms scanning for cycles quite frequently. + # We manually manage it to avoid these pauses during training. + if step == 1: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # freeze all currently surviving objects and exclude them from GC + gc.disable() # disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... + gc.collect() # manually collect, just to be safe for very long runs + + # print a few more stats + print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") + print0(f"Total training time: {total_training_time/60:.2f}m") + print0(f"Minimum validation bpb: {min_val_bpb:.4f}") + + # Log to report + from nanochat.report import get_report + get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } + ]) + + # cleanup + wandb_run.finish() # wandb run finish + compute_cleanup()