""" Supervised fine-tuning (SFT) the model. Run as: python -m scripts.chat_sft Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time import wandb import torch from contextlib import nullcontext from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.checkpoint_manager import load_model import torch.distributed as dist from tasks.common import TaskMixture from tasks.gsm8k import GSM8K from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee # ----------------------------------------------------------------------------- # CLI arguments parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") # Optimization parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") # Evaluation parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") # Output parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- # Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else 
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)

# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per micro-batch for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # tokens per micro-batch across all ranks
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)

# Initialize the optimizer (combined MuonAdamW: Muon for matrix params, AdamW for the rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]

# SFT data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 16K + 2K + 200K + 80K = 858K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in the test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in the test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in the test set, use only 420 to match the train ratios
]) # total: 24K + 5.2K + 0.42K ~= 30K rows

# The DataLoader is defined here; it emits inputs, targets: 2D tensors of shape (device_batch_size, max_seq_len).
# A big problem is that we don't know the final num_iterations in advance, so we create
# these global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track the epoch for logging
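# A minimal sketch of the best-fit packing idea used by the generator below,
# operating on conversation lengths only (a toy illustration for reference;
# it is never called during training): repeatedly pick the largest buffered
# conversation that still fits in the remaining row capacity.
def _bestfit_packing_demo(capacity=10, lengths=(4, 7, 3, 2)):
    buffer = list(lengths)
    row_len, picked = 0, []
    while True:
        fits = [n for n in buffer if n <= capacity - row_len] # conversations that still fit
        if not fits:
            break
        best = max(fits) # best fit = the largest conversation that fits entirely
        buffer.remove(best)
        picked.append(best)
        row_len += best
    return picked, buffer # here: ([7, 3], [4, 2]), i.e. 7 + 3 fills the row exactly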
def sft_data_generator_bos_bestfit(split, buffer_size=100):
    """
    BOS-aligned dataloader for SFT with bestfit-pad packing.
    Each row in the batch starts with BOS (the beginning of a conversation).
    Conversations are packed into rows with a best-fit algorithm. When no
    conversation fits, the row is padded. If the row is still empty, we fall
    back to truncating one conversation to avoid fully padded rows.
    Padding positions have targets masked with -1 (ignore_index for cross-entropy).
    """
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1 # +1 for the target at the last position
    bos_token = tokenizer.get_bos_token_id()
    # Conversation buffer: list of token lists
    conv_buffer = []
    cursor = ddp_rank # each rank processes different conversations (for fetching)
    consumed = ddp_rank # track actual consumption separately from buffering
    epoch = 1
    it = 0 # iteration counter

    def refill_buffer():
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            conv_buffer.append(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
                # Note: last_step is triggered based on consumption, not fetching
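    # Example of the rank striding in refill_buffer above: with ddp_world_size=8,
    # rank 0 fetches dataset indices 0, 8, 16, ... while rank 1 fetches 1, 9, 17, ...,
    # so the ranks partition the dataset among themselves without any communication.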
    while True:
        rows = []
        row_lengths = [] # track the actual content length (excluding padding) for each row
        for _ in range(args.device_batch_size):
            row = []
            padded = False
            while len(row) < row_capacity:
                # Ensure the buffer has conversations
                while len(conv_buffer) < buffer_size:
                    refill_buffer()
                remaining = row_capacity - len(row)
                # Find the largest conversation that fits entirely
                best_idx = -1
                best_len = 0
                for i, conv in enumerate(conv_buffer):
                    conv_len = len(conv)
                    if conv_len <= remaining and conv_len > best_len:
                        best_idx = i
                        best_len = conv_len
                if best_idx >= 0:
                    # Found a conversation that fits - use it entirely
                    conv = conv_buffer.pop(best_idx)
                    row.extend(conv)
                    consumed += ddp_world_size # track actual consumption
                else:
                    if len(row) == 0:
                        # No conversation fits even on an empty row.
                        # Truncate the shortest one to avoid fully padded rows.
                        trunc_idx = -1
                        trunc_len = None
                        for i, conv in enumerate(conv_buffer):
                            conv_len = len(conv)
                            if conv_len > remaining and (trunc_len is None or conv_len < trunc_len):
                                trunc_idx = i
                                trunc_len = conv_len
                        conv = conv_buffer.pop(trunc_idx)
                        row.extend(conv[:row_capacity])
                        consumed += ddp_world_size
                    else:
                        # No conversation fits - pad the remainder
                        content_len = len(row)
                        row.extend([bos_token] * remaining) # pad with BOS tokens
                        padded = True
                    break # the row is now full (with padding or truncation)
            # Track content length: the full row if no padding, otherwise the length before padding
            if padded:
                row_lengths.append(content_len)
            else:
                row_lengths.append(row_capacity)
            rows.append(row[:row_capacity])
        # Stopping condition to respect num_iterations, if given
        it += 1
        if 0 < args.num_iterations <= it and split == "train":
            last_step = True
        # Update progress tracking (based on consumed, not cursor, to account for buffering)
        if split == "train":
            current_epoch = epoch
            if args.num_iterations > 0:
                approx_progress = it / args.num_iterations
            else:
                approx_progress = consumed / dataset_size
            # Trigger last_step when we've consumed enough (instead of when the cursor wraps)
            if consumed >= dataset_size:
                last_step = True
        # Build the tensors
        use_cuda = device_type == "cuda"
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
        # Mask out padding positions in targets (set to -1 = ignore_index).
        # For each row, positions >= (content_length - 1) in targets should be masked.
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, content_len-1:] = -1
        yield inputs, targets

train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of training

# Learning rate scheduler
def get_lr_multiplier(progress):
    # first 80% of training: no decay, then linearly ramp down to 0
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
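# Example: at progress=0.9 the multiplier is 1 - (0.9 - 0.8) / 0.2 = 0.5, and it
# reaches 0 exactly at progress=1.0 (the end of the epoch, or of --num-iterations).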
# Momentum scheduler for the Muon optimizer: ramp momentum from 0.85 to 0.95 over the first 300 steps
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of the training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save a checkpoint at the end of the run (only on the master process)
    if master_process and last_step and not args.dry_run:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "sft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb, # loss at the last step
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config, # inputs to the training script
            },
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize the loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
    progress = max(progress, approx_progress) # only increase progress monotonically
    # step the optimizer
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group["kind"] == "muon":
            group["momentum"] = muon_momentum
    optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------
    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step) # debias the EMA (the EMA has had exactly `step` updates)
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
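    # Example MFU arithmetic (hypothetical numbers): with num_flops_per_token ~= 2.7e9
    # and dt = 2.0s, flops_per_sec ~= 2.7e9 * 524288 / 2.0 ~= 7.1e14, which against
    # 8 H100s (8 * 989e12 peak bf16 FLOPS) works out to an MFU of roughly 9%.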
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM, without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to the report
if not args.dry_run:
    from nanochat.report import get_report
    get_report().log(section="SFT", data=[
        user_config, # CLI args
        { # stats about the training setup
            "Number of iterations": step,
            "DDP world size": ddp_world_size,
        },
        { # stats about training outcomes
            "Minimum validation bpb": min_val_bpb,
        },
    ])

# cleanup
wandb_run.finish()
compute_cleanup()