# From 15e15f3c1457f568508226f3db7e6f2a04724610 Mon Sep 17 00:00:00 2001
# From: Chris McCormick
# Date: Mon, 2 Feb 2026 08:27:29 -0800
# Subject: [PATCH] Run command and refs
# ---
#  log/base_train-gc-fixes.py | 507 +++++++++++++++++++++++++++++++++++++
#  log/dataloader-gc-fixes.py | 239 +++++++++++++++++
#  log/speedrun-feb01.sh      |  30 +++
#  3 files changed, 776 insertions(+)
#
# File: log/base_train-gc-fixes.py (new file, mode 100644, index 00000000..9cab2215)
"""
Train model. From root directory of the project, run as:

python -m scripts.base_train

or distributed as:

torchrun --nproc_per_node=8 -m scripts.base_train

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
import gc
from contextlib import nullcontext

# GC timing callback to detect if garbage collection is causing training stalls
def _gc_callback(phase, info):
    """
    Time each garbage collection and print a line for the slow ones.

    Intended to be registered in `gc.callbacks`, which invokes it as
    `callback(phase, info)` with phase "start" before a collection and "stop" after it.

    State lives on function attributes so that no extra module globals are needed:
    - start_time:   perf_counter() timestamp of the collection in flight, None when idle
    - threshold_ms: only collections slower than this many milliseconds are reported
    - rank:         label used in the log line ('?' until the caller overwrites it)

    Args:
        phase: "start" or "stop"; any other value is ignored.
        info: dict supplied by the gc module; 'generation' is read directly,
            'collected' is optional and falls back to '?'.
    """
    if phase == "start":
        _gc_callback.start_time = time.perf_counter()
        return
    if phase != "stop":
        return
    # Consume the timestamp so that a "stop" without a matching "start" (or a repeated
    # "stop") cannot be measured against a stale or never-set start time.
    started = _gc_callback.start_time
    _gc_callback.start_time = None
    if started is None:
        return
    duration_ms = (time.perf_counter() - started) * 1000
    if duration_ms > _gc_callback.threshold_ms:  # Only log slow collections
        rank = getattr(_gc_callback, 'rank', '?')
        # flush so the message shows up next to the stalled step even when stdout is piped
        print(f"[GC rank{rank}] gen{info['generation']}: {duration_ms:.1f}ms collected {info.get('collected', '?')} objects", flush=True)
_gc_callback.start_time = None  # no collection in flight yet
_gc_callback.threshold_ms = 10  # report collections that take longer than 10ms
_gc_callback.rank = '?'
# Will be set after compute_init +gc.callbacks.append(_gc_callback) + +import wandb +import torch + +from nanochat.gpt import GPT, GPTConfig +from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops +from nanochat.tokenizer import get_tokenizer, get_token_bytes +from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint +from nanochat.loss_eval import evaluate_bpb +from nanochat.engine import Engine +from nanochat.flash_attention import HAS_FA3 +from scripts.base_eval import evaluate_core +print_banner() + +# ----------------------------------------------------------------------------- +# CLI arguments +parser = argparse.ArgumentParser(description="Pretrain base model") +# Logging +parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") +# Runtime +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +# Model architecture +parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") +parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") +parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") +parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") +parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 
'SSL')") +# Training horizon (only one used, in order of precedence) +parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +# Optimization +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") +parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") +parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") +parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") +parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") +# Evaluation 
+parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") +parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") +parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") +parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") +# Output +parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") +args = parser.parse_args() +user_config = vars(args).copy() # for logging +# ----------------------------------------------------------------------------- + +# Compute init +device_type = autodetect_device_type() if args.device_type == "" else args.device_type +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) +master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
+_gc_callback.rank = ddp_rank # Store rank for GC log printouts +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None +get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 +if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") +else: + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS + +# wandb logging init +use_dummy_wandb = args.run == "dummy" or not master_process +if use_dummy_wandb: + wandb_run = DummyWandb() +else: + try: + wandb_run = wandb.init(project="nanochat", name=args.run, config=user_config) + except wandb.errors.UsageError as e: + print0(f"Warning: wandb initialization failed ({e}), logging disabled. Run 'wandb login' to enable.") + wandb_run = DummyWandb() + +# Flash Attention status +if HAS_FA3: + print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") +else: + print0("!" * 80) + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") + print0("WARNING: Training will be less efficient without FA3") + if args.window_pattern != "L": + print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") + print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") + print0("!" 
* 80) + +# Tokenizer will be useful for evaluation, also we need the vocab size +tokenizer = get_tokenizer() +token_bytes = get_token_bytes(device=device) +vocab_size = tokenizer.get_vocab_size() +print0(f"Vocab size: {vocab_size:,}") + +# Model kwargs are derived from the desired depth of the model +# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division +# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) +# (For very small depths, this gives a slight "unfair" advantage to models with odd depths) +num_layers = args.depth +base_dim = args.depth * args.aspect_ratio +model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim +num_heads = model_dim // args.head_dim +num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) +head_dim = model_dim // num_heads +print0(f"num_layers: {num_layers}") +print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") +print0(f"num_heads: {num_heads}") +print0(f"head_dim: {head_dim}") +print0(f"num_kv_heads: {num_kv_heads}") + +# Optimizer / data / training length related hyperparameters +# figure out the needed gradient accumulation to reach the desired total batch size +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert args.total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + +# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 
2^19) +batch_lr_scale = 1.0 +reference_batch_size = 2**19 +batch_ratio = args.total_batch_size / reference_batch_size +if batch_ratio != 1.0: + # SGD: linear scaling with batch size is standard (not used in nanochat) + # AdamW: sqrt scaling is standard + # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer + batch_lr_scale = batch_ratio ** 0.5 + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") + +# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) +weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 +if args.depth != 12: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") + +# ----------------------------------------------------------------------------- +# Initialize the Model + +# Create a new model with random weights +model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern) +with torch.device("meta"): + # All tensors are created as meta tensors (they have shape/dtype but no data) + model_config = GPTConfig(**model_config_kwargs) + model = GPT(model_config) +model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data +model.init_weights() # All tensors get initialized + +# If we are resuming, overwrite the model parameters with those of the checkpoint +base_dir = get_base_dir() +output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. 
d12 +checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) +resuming = args.resume_from_step != -1 +if resuming: + print0(f"Resuming optimization from step {args.resume_from_step}") + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank) + model.load_state_dict(model_data, strict=True, assign=True) + del model_data # free up this memory after the copy + +orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe + +# Detailed parameter counts +param_counts = orig_model.num_scaling_params() +print0(f"Parameter counts:") +for key, value in param_counts.items(): + print0(f"{key:24s}: {value:,}") +num_params = param_counts['total'] +num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026 +num_flops_per_token = model.estimate_flops() +print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") + +# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) +assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 +if args.num_iterations > 0: + num_iterations = args.num_iterations + print0(f"Using user-provided number of iterations: {num_iterations:,}") +elif args.target_flops > 0: + # calculate the number of iterations from the target flops + num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) + print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") +elif args.target_param_data_ratio > 0: + # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.) 
+ target_tokens = int(args.target_param_data_ratio * num_scaling_params) + num_iterations = target_tokens // args.total_batch_size + print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") +else: + raise ValueError("No training horizon specified") +total_tokens = args.total_batch_size * num_iterations +print0(f"Total number of training tokens: {total_tokens:,}") +print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") + +# ----------------------------------------------------------------------------- +# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +adam_betas = (args.adam_beta1, args.adam_beta2) +optimizer = model.setup_optimizer( + unembedding_lr=args.unembedding_lr * batch_lr_scale, + embedding_lr=args.embedding_lr * batch_lr_scale, + matrix_lr=args.matrix_lr * batch_lr_scale, + weight_decay=weight_decay_scaled, + adam_betas=adam_betas, + scalar_lr=args.scalar_lr * batch_lr_scale, +) + +if resuming: + optimizer.load_state_dict(optimizer_data) + del optimizer_data + +# ----------------------------------------------------------------------------- +# Initialize the DataLoaders for train/val +dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] +train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) +build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data + +# ----------------------------------------------------------------------------- +# Set up hyperparameter 
schedulers + +# Learning rate scheduler +def get_lr_multiplier(it): + warmup_iters = round(args.warmup_ratio * num_iterations) + warmdown_iters = round(args.warmdown_ratio * num_iterations) + if it < warmup_iters: + return (it + 1) / warmup_iters + elif it <= num_iterations - warmdown_iters: + return 1.0 + else: + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * args.final_lr_frac + +# Momentum scheduler for Muon optimizer +def get_muon_momentum(it): + frac = min(it / 300, 1) + momentum = (1 - frac) * 0.85 + frac * 0.95 + return momentum + +# Weight decay scheduler for Muon optimizer (linear to zero over the course of training) +def get_weight_decay(it): + return weight_decay_scaled * (1 - it / num_iterations) + +# ----------------------------------------------------------------------------- +# Loop state (variables updated by the training loop) + +if not resuming: + step = 0 + val_bpb = None # will be set if eval_every > 0 + min_val_bpb = float("inf") + smooth_train_loss = 0 # EMA of training loss + total_training_time = 0 # total wall-clock time of training +else: + step = meta_data["step"] + loop_state = meta_data["loop_state"] + val_bpb = meta_data["val_bpb"] + min_val_bpb = loop_state["min_val_bpb"] + smooth_train_loss = loop_state["smooth_train_loss"] + total_training_time = loop_state["total_training_time"] + +# ----------------------------------------------------------------------------- +# Training loop +while True: + last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end + flops_so_far = num_flops_per_token * args.total_batch_size * step + + # once in a while: evaluate the val bpb (all ranks participate) + if args.eval_every > 0 and (last_step or step % args.eval_every == 0): + model.eval() + val_loader = build_val_loader() + eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) + with autocast_ctx: + val_bpb = 
evaluate_bpb(model, val_loader, eval_steps, token_bytes) + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") + if val_bpb < min_val_bpb: + min_val_bpb = val_bpb + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + }) + model.train() + + # once in a while: estimate the CORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + results = {} + if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): + model.eval() + with autocast_ctx: + results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) + print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "core_metric": results["core_metric"], + "centered_results": results["centered_results"], + }) + model.train() + + # once in a while: sample from the model (only on master process) + # use the original uncompiled model because the inputs keep changing shape + if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)): + model.eval() + prompts = [ + "The capital of France is", + "The chemical symbol of gold is", + "If yesterday was Friday, then tomorrow will be", + "The opposite of hot is", + "The planets of the solar system are:", + "My favorite color is", + "If 5*x + 3 = 13, then x is", + ] + engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation + for prompt in prompts: + tokens = tokenizer(prompt, prepend="<|bos|>") + with autocast_ctx: + sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) + print0(tokenizer.decode(sample[0])) + model.train() + + # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step + if last_step or (step > 0 and 
step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), # model parameters + optimizer.state_dict(), # optimizer state + { # metadata saved as json + "step": step, + "val_bpb": val_bpb, # loss at last step + "model_config": model_config_kwargs, + "user_config": user_config, # inputs to the training script + "device_batch_size": args.device_batch_size, + "max_seq_len": args.max_seq_len, + "dataloader_state_dict": dataloader_state_dict, + "loop_state": { # all loop state (other than step) so that we can resume training + "min_val_bpb": min_val_bpb, + "smooth_train_loss": smooth_train_loss, + "total_training_time": total_training_time, + }, + }, + rank=ddp_rank, + ) + + # termination conditions (TODO: possibly also add loss explosions etc.) + if last_step: + break + + # ------------------------------------------------------------------------- + # single training step + # evaluate the gradient + synchronize() + t0 = time.time() + for micro_step in range(grad_accum_steps): + with autocast_ctx: + loss = model(x, y) + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + loss.backward() + x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + # step the optimizer + lrm = get_lr_multiplier(step) + muon_momentum = get_muon_momentum(step) + muon_weight_decay = get_weight_decay(step) + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + if group['kind'] == 'muon': + group["momentum"] = muon_momentum + group["weight_decay"] = muon_weight_decay + optimizer.step() + model.zero_grad(set_to_none=True) + train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point + synchronize() + t1 = time.time() + dt = t1 - t0 + # ------------------------------------------------------------------------- + + # logging 
(CPU action only) + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * step / num_iterations + tok_per_sec = int(args.total_batch_size / dt) + flops_per_sec = num_flops_per_token * args.total_batch_size / dt + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + # Calculate ETA based on average time per step (excluding first 10 steps) + steps_done = step - 10 + if steps_done > 0: + avg_time_per_step = total_training_time / steps_done + remaining_steps = num_iterations - step + eta_seconds = remaining_steps * avg_time_per_step + eta_str = f" | eta: {eta_seconds/60:.1f}m" + else: + eta_str = "" + epoch = dataloader_state_dict["epoch"] + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + if step % 100 == 0: + log_data = { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": epoch, + } + wandb_run.log(log_data) + + # Set 'first_step_of_run' flag + first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) + + # state update + step += 1 + + # # TEMP - Bail at 1000 steps for benchmarking. 
+ # if step == 1001: + # print0(f"Elapsed + ETA: {total_training_time + eta_seconds:.0f}s") + # break + + # Help out the garbage collector by flushing garbage and then freezing long-lived objects + # This eliminates random ~500ms pauses during training steps as the GC scans ~millions of objects for cycles + if first_step_of_run: + gc.collect() + gc.freeze() + gc.disable() # nuclear option: disable GC for the run + elif step % 2000 == 0: + gc.collect() # manual GC to keep memory usage in check for very long runs + +# print a few more stats +print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") +print0(f"Total training time: {total_training_time/60:.2f}m") +if val_bpb is not None: + print0(f"Minimum validation bpb: {min_val_bpb:.6f}") + +# Log to report +from nanochat.report import get_report +get_report().log(section="Base model training", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of parameters": num_params, + "Number of FLOPs per token": f"{num_flops_per_token:e}", + "Calculated number of iterations": num_iterations, + "Number of training tokens": total_tokens, + "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, + "DDP world size": ddp_world_size, + "warmup_ratio": args.warmup_ratio, + "warmdown_ratio": args.warmdown_ratio, + "final_lr_frac": args.final_lr_frac, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb if val_bpb is not None else None, + "Final validation bpb": val_bpb, + "CORE metric estimate": results.get("core_metric", None), + "MFU %": f"{mfu:.2f}%", + "Total training flops": f"{flops_so_far:e}", + "Total training time": f"{total_training_time/60:.2f}m", + "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", + } +]) + +# cleanup +wandb_run.finish() # wandb run finish +compute_cleanup() diff --git a/log/dataloader-gc-fixes.py b/log/dataloader-gc-fixes.py new file mode 100644 index 00000000..544d9901 --- /dev/null 
# File: log/dataloader-gc-fixes.py (new file in this patch, index 00000000..544d9901)
"""
Distributed dataloaders for pretraining.

Two implementations are provided:

1. Original (tokenizing_distributed_data_loader):
   - Streams tokens into a flat buffer, reshapes to (B, T)
   - Rows may start mid-document (no guaranteed BOS at position 0)
   - 100% token utilization, simple and efficient

2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
   - Every row starts with BOS token
   - Documents packed using best-fit algorithm to minimize cropping
   - When no document fits remaining space, crops a document to fill exactly
   - 100% utilization (no padding), ~35% tokens cropped at T=2048

The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches as every token can
now attend back to the BOS token and sees the full context of the document.
(2) is the new default if you have enough data.
Fallback to (1) if you have very limited data AND long documents.
"""

import torch
import pyarrow.parquet as pq

from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files

def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    Infinite iterator over document batches (list of text strings) from parquet files.

    Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
    where text_batch is a list of document strings, indices track position for resumption,
    and epoch counts how many times we've cycled through the dataset (starts at 1).

    The last parquet file is the "val" split, all files before it are the "train" split.
    Each rank reads row groups ddp_rank, ddp_rank + world_size, ... of every file.

    Raises:
        RuntimeError: if a complete pass over the split produced no documents for this
            rank (e.g. only one parquet file exists so "train" is empty, or every file
            has fewer row groups than ddp_rank). Without this the generator would spin
            forever without yielding and the training process would hang silently.
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

    parquet_paths = list_parquet_files()
    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]

    resuming = resume_state_dict is not None
    resume_pq_idx = resume_state_dict["pq_idx"] if resuming else 0
    resume_rg_idx = resume_state_dict["rg_idx"] if resuming else None
    resume_epoch = resume_state_dict.get("epoch", 1) if resuming else 1
    first_pass = True
    pq_idx = resume_pq_idx
    epoch = resume_epoch

    while True: # iterate infinitely (multi-epoch)
        # A resumed first pass starts mid-dataset and may legitimately be empty;
        # every other pass covers the whole split and must produce something.
        covers_whole_split = not (first_pass and resuming)
        produced_any = False
        pq_idx = resume_pq_idx if first_pass else 0
        while pq_idx < len(parquet_paths):
            filepath = parquet_paths[pq_idx]
            pf = pq.ParquetFile(filepath)
            # Start from resume point if resuming on same file, otherwise from DDP rank
            if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
                base_idx = resume_rg_idx // ddp_world_size
                base_idx += 1 # advance by 1 so we don't repeat data after resuming
                rg_idx = base_idx * ddp_world_size + ddp_rank
                if rg_idx >= pf.num_row_groups:
                    pq_idx += 1
                    continue
                resume_rg_idx = None # only do this once
            else:
                rg_idx = ddp_rank
            while rg_idx < pf.num_row_groups:
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    produced_any = True
                    yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
                rg_idx += ddp_world_size
            pq_idx += 1
        if covers_whole_split and not produced_any:
            raise RuntimeError(
                f"No documents available for split '{split}' on rank {ddp_rank}: "
                f"{len(parquet_paths)} parquet file(s) produced no row groups for this rank"
            )
        first_pass = False
        epoch += 1


def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
    """
    Stream pretraining text from parquet files, tokenize, yield training batches.

    This is the original dataloader that streams tokens into a flat buffer and reshapes.
    Rows may start mid-document (no guaranteed BOS at position 0).

    Supports approximate resume via state_dict.

    Yields (inputs, targets, state_dict) forever, where inputs and targets are (B, T) long
    tensors on `device` (targets are inputs shifted by one token) and state_dict records
    the position of the most recently read document batch.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    needed_tokens = B * T + 1 # +1 for target at last position
    bos_token = tokenizer.get_bos_token_id()
    token_buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1

    while True:

        # Accumulate enough tokens
        while len(token_buffer) < needed_tokens:
            doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
        token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over

        # Package tokens into inputs and targets, yield
        use_cuda = device == "cuda"
        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
        inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
        targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
        yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}


def tokenizing_distributed_data_loader(*args, **kwargs):
    """Helper that omits state_dict from yields."""
    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
        yield inputs, targets


def tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, B, T, split,
    tokenizer_threads=4, tokenizer_batch_size=128,
    device="cuda", resume_state_dict=None,
    buffer_size=1000
):
    """
    BOS-aligned dataloader with Best-Fit Cropping.

    Reduces token waste compared to simple greedy cropping by searching a buffer
    for documents that fit well, while maintaining 100% utilization (no padding).

    Algorithm for each row:
    1. From buffered docs, pick the LARGEST doc that fits entirely
    2. Repeat until no doc fits
    3. When nothing fits, crop a doc to fill remaining space exactly

    Key properties:
    - Every row starts with BOS
    - 100% utilization (no padding, every token is trained on)
    - Approximately 35% of all tokens are discarded due to cropping

    Yields (inputs, targets, state_dict) forever. NOTE: `inputs` and `targets` are
    allocated once and overwritten in place on every iteration, so the tensors yielded
    by successive iterations are the same objects; copy them if a batch must outlive
    the next call to next().
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    row_capacity = T + 1 # T inputs plus one extra token for the last target
    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    bos_token = tokenizer.get_bos_token_id()
    pq_idx, rg_idx, epoch = 0, 0, 1

    # Token pool: single tensor holding all buffered tokens
    # Documents tracked as (start, length) tuples
    pool = torch.empty(buffer_size * 512, dtype=torch.long)
    pool_end = 0
    docs = [] # [(start, length), ...]

    def compact_pool():
        """Shift active documents to front of pool, reclaiming space."""
        nonlocal pool_end
        if not docs:
            pool_end = 0
            return
        write_pos = 0
        for i, (start, length) in enumerate(docs):
            if start != write_pos:
                # clone() because source and destination ranges may overlap
                pool[write_pos:write_pos + length] = pool[start:start + length].clone()
                docs[i] = (write_pos, length)
            write_pos += length
        pool_end = write_pos

    def refill_buffer():
        """Retrieve more docs and add them to the pool"""
        nonlocal pq_idx, rg_idx, epoch, pool, pool_end
        doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
        # Number of new tokens to store
        total_new = sum(len(t) for t in token_lists)
        # If there's not enough space at the end,
        if pool_end + total_new > pool.size(0):
            compact_pool() # Try compacting first.
            # If still not enough,
            if pool_end + total_new > pool.size(0):
                # Allocate a new, larger pool.
                new_size = max(pool.size(0) * 2, pool_end + total_new)
                new_pool = torch.empty(new_size, dtype=torch.long)
                new_pool[:pool_end] = pool[:pool_end]
                pool = new_pool
        # Write tokens to pool
        for tokens in token_lists:
            n = len(tokens)
            pool[pool_end:pool_end + n] = torch.tensor(tokens, dtype=torch.long)
            docs.append((pool_end, n))
            pool_end += n

    # Pre-allocate buffers once
    use_cuda = device == "cuda"
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    inputs = torch.empty((B, T), dtype=torch.long, device=device)
    targets = torch.empty((B, T), dtype=torch.long, device=device)

    while True:
        for row_idx in range(B):
            col = 0
            while col < row_capacity:
                # Ensure buffer has documents
                while len(docs) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - col

                # Find largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, (start, length) in enumerate(docs):
                    if length <= remaining and length > best_len:
                        best_idx = i
                        best_len = length

                if best_idx >= 0:
                    start, length = docs.pop(best_idx)
                    row_buffer[row_idx, col:col + length] = pool[start:start + length]
                    col += length
                else:
                    # No doc fits - crop shortest to fill remaining
                    shortest_idx = min(range(len(docs)), key=lambda i: docs[i][1])
                    start, length = docs.pop(shortest_idx)
                    row_buffer[row_idx, col:col + remaining] = pool[start:start + remaining]
                    col += remaining

        # Copy to GPU
        inputs.copy_(row_buffer[:, :-1], non_blocking=use_cuda)
        targets.copy_(row_buffer[:, 1:], non_blocking=use_cuda)

        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
        yield inputs, targets, state_dict

def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Helper that omits state_dict from yields."""
    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
        yield inputs, targets
#!/bin/bash

# Speedrun launcher: sets up the Python environment with uv, then trains the
# depth-24 base model on 8 processes while capturing everything to a log file.

# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
# tee does not create missing directories; without this the run would proceed
# but the log file would never be written.
mkdir -p ./logs

# -----------------------------------------------------------------------------
# Python venv setup with uv

# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

# Dump the exact source files used for this run into the log first, then the
# training output, so the log is self-contained. stderr is merged into stdout.
( cat ./nanochat/gpt.py ./nanochat/optim.py ./nanochat/dataloader.py ./scripts/base_train.py; echo -e "\n\n===== TRAINING OUTPUT =====\n\n"; OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=24 \
    --run=d24-feb01 \
    --model-tag=d24_feb01 \
    --device-batch-size=16 \
    --sample-every=-1 \
    --save-every=-1 \
    --core-metric-max-per-task=-1 \
    --core-metric-every=3000 \
    --target-param-data-ratio=12 ) \
    2>&1 | tee ./logs/speedrun_d24_feb01-rope_chunk_mlp_lr_1x2x.log