""" Finetune a base model to be a chat model. Run on one GPU e.g. for debugging: python -m scripts.chat_sft Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft """ import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import wandb import torch import torch.distributed as dist from contextlib import nullcontext from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import save_checkpoint from nanochat.engine import Engine from scripts.chat_eval import run_chat_eval from tasks.common import TaskMixture from tasks.arc import ARC from tasks.gsm8k import GSM8K from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.nemotron import Nemotron # ----------------------------------------------------------------------------- # SFT Hyperparameters run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) # input model options source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) model_tag = None # model tag to load the model from (base model or midtrained model) step = None # step to load the model from (base model or midtrained model) # compute/precision device_type = "" # cuda|cpu|mps (empty => autodetect) dtype = "bfloat16" device_batch_size = 4 # max to avoid OOM # optimization num_epochs = 1 num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) target_examples_per_step = 32 unembedding_lr = 0.004 embedding_lr = 0.2 matrix_lr = 0.02 weight_decay = 0.0 init_lr_frac = 0.02 # evaluation and logging there of eval_every = 100 eval_steps = 100 eval_metrics_every = 200 eval_metrics_max_problems = 1024 dataset_choice = "smoltalk" # dataset choice: "smoltalk" or "nemotron" # now allow CLI to override the settings via the configurator lol config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging # ----------------------------------------------------------------------------- # Compute init device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # wandb logging init use_dummy_wandb = run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) # Load the model and tokenizer model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) orig_model = model # original, uncompiled model # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs engine = Engine(model, tokenizer) # will be used for inline model evaluation only tokenizer_name = meta.get("tokenizer_name", "tokenizer") print0(f"Using tokenizer: {tokenizer_name}") # ----------------------------------------------------------------------------- # Task data mixture we'll train on # Select dataset based on dataset_choice parameter print0(f"SFT using dataset: {dataset_choice}") identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") if dataset_choice == "smoltalk": # Original: SmolTalk + ARC + GSM8K train_ds = TaskMixture([ ARC(subset="ARC-Easy", split="train"), # 2.3K rows ARC(subset="ARC-Challenge", split="train"), # 1.1K rows GSM8K(subset="main", split="train"), # 8K rows SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations ]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows val_ds = SmolTalk(split="test") # general conversations, 24K rows elif dataset_choice == "nemotron": # Ablation: Nemotron (sampled to match SmolTalk 10K) + ARC + GSM8K # SmolTalk has 10K samples, we sample Nemotron proportionally to match # Original Nemotron distribution: stem(25.4%), math(17.1%), chat(44.9%), code(12.5%) train_ds = TaskMixture([ ARC(subset="ARC-Easy", split="train"), # 2.3K rows ARC(subset="ARC-Challenge", split="train"), # 1.1K rows GSM8K(subset="main", split="train"), # 8K rows Nemotron(categories=["stem"], split="train", stop=3000), # 3K samples Nemotron(categories=["math"], split="train", stop=3000), # 3K samples Nemotron(categories=["chat"], split="train", stop=1000), # 1K samples Nemotron(categories=["code"], split="train", stop=3000), # 3K samples ]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) + 1K = 22.4K rows (similar to SmolTalk) # For validation, use a small subset of Nemotron mixed categories val_ds = TaskMixture([ Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples ]) # 300 + 300 + 100 + 300 = 1K samples for validation else: raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'") # ----------------------------------------------------------------------------- # DataLoader def sft_data_generator(dataset, batch_size): pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss # prepares a list of tokenized conversations into a batch and yields def collate_and_yield(batch): nrows = len(batch) ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index for i, (ids, mask) in enumerate(batch): n = len(ids) ids_tensor = torch.tensor(ids, dtype=torch.long) inputs[i, :n-1] = ids_tensor[:-1] # recall -1 is the ignore index, so mask out targets where mask is 0 row_targets = ids_tensor[1:] # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok mask_tensor = torch.tensor(mask[1:], dtype=torch.long) row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 targets[i, :n-1] = row_targets inputs = inputs.to(device) # move to device targets = targets.to(device) return inputs, targets # iterates over the dataset in epochs, tokenizes batch = [] while True: for i in range(ddp_rank, len(dataset), ddp_world_size): doc = dataset[i] ids, mask = tokenizer.render_conversation(doc) batch.append((ids, mask)) if len(batch) == batch_size: yield collate_and_yield(batch) batch = [] examples_per_step = device_batch_size * ddp_world_size print0(f"Target examples per step: {target_examples_per_step}") print0(f"Device batch size: {device_batch_size}") print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" grad_accum_steps = target_examples_per_step // examples_per_step print0(f"=> Setting grad accum steps: {grad_accum_steps}") if num_iterations == -1: # derive num_iterations from num_epochs and the size of the dataset assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1" num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size) # ----------------------------------------------------------------------------- # Initialize the Optimizer optimizers = model.setup_optimizers( unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay, ) # Set the initial learning rate as a fraction of the base learning rate for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # ----------------------------------------------------------------------------- # Training loop # Learning rate scheduler def get_lr_multiplier(it): lrm = 1.0 - it / num_iterations return lrm # Go! step = 0 train_iter = iter(train_loader) for step in range(num_iterations): last_step = step == num_iterations - 1 # evaluate the validation loss if last_step or step % eval_every == 0: model.eval() val_iter = iter(build_val_loader()) losses = [] for _ in range(eval_steps): val_inputs, val_targets = next(val_iter) with torch.no_grad(), autocast_ctx: loss = model(val_inputs, val_targets) losses.append(loss) val_loss = torch.stack(losses).mean() # average over eval_steps if ddp: dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks val_loss = val_loss.item() print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") wandb_run.log({ "step": step, "val_loss": val_loss, }) model.train() # evlauate accuracy of the multiple choice tasks (which are quick to run) if last_step or (step > 0 and step % eval_metrics_every == 0): model.eval() metrics = {} with torch.no_grad(), autocast_ctx: # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) print0(f"Step {step:05d} | {metrics_str}") wandb_run.log({ "step": step, **metrics, }) model.train() if last_step: break # evaluate the gradient num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen for micro_step in range(grad_accum_steps): train_inputs, train_targets = next(train_iter) with autocast_ctx: loss = model(train_inputs, train_targets) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() # accumulate the gradient num_tokens += (train_targets >= 0).sum() if ddp: dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks # learning rate scheduler lrm = get_lr_multiplier(step) for opt in optimizers: for group in opt.param_groups: group["lr"] = group["initial_lr"] * lrm # step the optimizers for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) # logging train_loss_item = train_loss.item() num_tokens_item = num_tokens.item() print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") wandb_run.log({ "step": step, "lrm": lrm, "train_loss": train_loss_item, "num_tokens": num_tokens_item, }) step += 1 # Save the model at the end of the run if master_process: base_dir = get_base_dir() depth = model.config.n_layer # Use model_tag from config if provided, otherwise default to d{depth} output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", run) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, step, model.state_dict(), None, # note: we don't bother to save the optimizer state { "step": step, "val_loss": val_loss, **metrics, "model_config": model_config_kwargs, "tokenizer_name": tokenizer_name, } ) print(f"✅ Saved model checkpoint to {checkpoint_dir}") # Log to report from nanochat.report import get_report get_report(exp_name=run).log(section="Chat SFT", data=[ user_config, # CLI args { "Training rows": len(train_ds), "Number of iterations": num_iterations, "Training loss": train_loss_item, "Validation loss": val_loss, }, ]) # Cleanup wandb_run.finish() compute_cleanup()