""" Loads a checkpoint, and: - Evaluates the loss on a larger chunk of train/val splits - Samples from the model Example run as: torchrun --standalone --nproc_per_node=8 -m scripts.base_loss """ import argparse from contextlib import nullcontext import torch from nanochat.checkpoint_manager import load_model from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type from nanochat.dataloader import tokenizing_distributed_data_loader from nanochat.tokenizer import get_token_bytes from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine # CLI arguments parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model") parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size") parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split") parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory") parser.add_argument("--model_step", type=int, default=None, help="model step to load") parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") args = parser.parse_args() # Load the base model and the tokenizer device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step) sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() # Evaluate the loss on each split tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step" steps = args.split_tokens // tokens_per_step token_bytes = get_token_bytes(device=device) bpb_results = {} for split_name in ["train", "val"]: loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device) with autocast_ctx: bpb = evaluate_bpb(model, loader, steps, token_bytes) print0(f"{split_name} bpb: {bpb:.4f}") bpb_results[split_name] = bpb # Master process also samples from the model samples = [] if ddp_rank == 0: prompts = [ "The capital of France is", "The chemical symbol of gold is", "If yesterday was Friday, then tomorrow will be", "The opposite of hot is", "The planets of the solar system are:", "My favorite color is", "If 5*x + 3 = 13, then x is", ] engine = Engine(model, tokenizer) for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") with autocast_ctx: sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) sample_str = tokenizer.decode(sample[0]) print0(sample_str) samples.append(sample_str) # Log to report from nanochat.report import get_report get_report().log(section="Base model loss", data=[ { "train bpb": bpb_results["train"], "val bpb": bpb_results["val"], }, {f"sample {i}": sample for i, sample in enumerate(samples)}, ]) # Cleanup compute_cleanup()