""" Train model. From root directory of the project, run as: python -m scripts.base_train or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import gc import json import time import math import argparse from dataclasses import asdict from contextlib import nullcontext, contextmanager import wandb import torch from nanochat.gpt import GPT, GPTConfig from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from nanochat.flash_attention import HAS_FA3 from scripts.base_eval import evaluate_core print_banner() # ----------------------------------------------------------------------------- # CLI arguments parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. 
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging

# -----------------------------------------------------------------------------
# Compute init and wandb logging
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
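# Note (illustrative, numbers not from this repo): gpu_peak_flops is only used for MFU
# reporting in the training loop, as mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size).
# So e.g. sustaining 4e14 FLOPS/s per GPU against a ~1e15 FLOPS BF16 peak would log as bf16_mfu ~= 40.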
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)

# Flash Attention status
if HAS_FA3:
    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
    print0("!" * 80)
    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
    print0("WARNING: Training will be less efficient without FA3")
    if args.window_pattern != "L":
        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
    print0("!" * 80)

# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# -----------------------------------------------------------------------------
# Initialize the Model

def build_model_meta(depth):
    """Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
    # Model dim is nudged up to the nearest multiple of head_dim for clean division
    # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
    base_dim = depth * args.aspect_ratio
    model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
    num_heads = model_dim // args.head_dim
    config = GPTConfig(
        sequence_len=args.max_seq_len,
        vocab_size=vocab_size,
        n_layer=depth,
        n_head=num_heads,
        n_kv_head=num_heads,
        n_embd=model_dim,
        window_pattern=args.window_pattern,
    )
    with torch.device("meta"):
        model_meta = GPT(config)
    return model_meta
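# Worked example of the nudging above (illustrative only): with the default --aspect-ratio=64
# and --head-dim=128, a depth of 21 gives base_dim = 21*64 = 1344, which is not a multiple
# of 128, so model_dim is rounded up to 1408 = 11*128 and num_heads = 11. Even depths land
# exactly, e.g. depth 20 -> 1280 = 10*128, 10 heads.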
# Build the model, move to device, init the weights
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config)
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = args.resume_from_step != -1
if resuming:
    print0(f"Resuming optimization from step {args.resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # free up this memory after the copy

# -----------------------------------------------------------------------------
# FP8 training initialization and management (this has to be done before torch.compile)

# Convert Linear layers to Float8Linear if --fp8 is set
if args.fp8:
    if device_type != "cuda":
        print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
    else:
        # our custom fp8 is simpler than torchao, written for exact API compatibility
        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        import torch.nn as nn
        # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough
        def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
            if not isinstance(mod, nn.Linear):
                return False
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
            if min(mod.in_features, mod.out_features) < 128:
                return False
            return True
        fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
        num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
        num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
        num_skipped = num_linear - num_fp8
        print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
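# A few illustrative filter decisions (hypothetical shapes, not taken from the model above):
# - nn.Linear(1280, 5120): both dims divisible by 16 and >= 128 -> converted to Float8Linear
# - nn.Linear(1280, 100):  100 % 16 != 0                        -> stays BF16
# - nn.Linear(64, 64):     dims divisible by 16 but < 128       -> stays BF16 (too small to pay off)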
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.

    CastConfig is a frozen dataclass, so we can't mutate scaling_type.
    Instead, we swap out Float8Linear modules entirely and restore them after.
    """
    import torch.nn as nn
    # Find all Float8Linear modules and their locations
    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
    for name, module in model.named_modules():
        if 'Float8' in type(module).__name__:
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent = model
                attr_name = name
            fp8_locations.append((parent, attr_name, module))
    if not fp8_locations:
        yield # No FP8 modules, nothing to do
        return
    # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
    for parent, attr_name, fp8_module in fp8_locations:
        linear = nn.Linear(
            fp8_module.in_features,
            fp8_module.out_features,
            bias=fp8_module.bias is not None,
            device=fp8_module.weight.device,
            dtype=fp8_module.weight.dtype,
        )
        linear.weight = fp8_module.weight # share, don't copy
        if fp8_module.bias is not None:
            linear.bias = fp8_module.bias
        setattr(parent, attr_name, linear)
    try:
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)

# -----------------------------------------------------------------------------
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the inputs may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe

# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

# Get the parameter counts of our model
param_counts = model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
    print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# 1) Use scaling laws to determine the optimal training horizon in tokens
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
def get_scaling_params(m):
    # As for which params to use exactly, transformer matrices + lm_head gives the cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
    params_counts = m.num_scaling_params()
    scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
    return scaling_params
num_scaling_params = get_scaling_params(model)
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train

# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)

# 2) Now that we have the token horizon, we can calculate the optimal batch size
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
total_batch_size = args.total_batch_size # user-provided override is possible
if total_batch_size == -1:
    batch_size_ratio = target_tokens / D_REF
    predicted_batch_size = B_REF * batch_size_ratio ** 0.383
    total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")
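# Worked example of the extrapolation above (illustrative numbers): if the target horizon is
# 4x the d12 reference horizon (target_tokens / D_REF = 4), then
#   predicted_batch_size = B_REF * 4**0.383 ≈ 524288 * 1.70 ≈ 891k tokens,
# and clamping to the nearest power of 2 gives total_batch_size = 2**20 = 1,048,576.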
# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
batch_lr_scale = 1.0
batch_ratio = total_batch_size / B_REF # B/B_ref
if batch_ratio != 1.0:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
    # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
    batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")

# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling, hoping it ~works for Muon too.
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
if weight_decay_scaled != args.weight_decay:
    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
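# For reference, the derivation sketched above (our own algebra, not quoted from the paper):
# keeping T_epoch = B/(η·λ·D) equal to the reference value B_ref/(η_ref·λ_ref·D_ref) and
# solving for λ gives
#   λ = λ_ref · (B/B_ref) · (η_ref/η) · (D_ref/D)
# and substituting the LR rule η = η_ref·√(B/B_ref) collapses (B/B_ref)·(η_ref/η) into √(B/B_ref),
# which is exactly the λ = λ_ref · √(B/B_ref) · (D_ref/D) implemented above.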
# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(
    # AdamW hyperparameters
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    scalar_lr=args.scalar_lr * batch_lr_scale,
    adam_betas=(args.adam_beta1, args.adam_beta2),
    # Muon hyperparameters
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=weight_decay_scaled,
)
if resuming:
    optimizer.load_state_dict(optimizer_data)
    del optimizer_data

# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Calculate the number of iterations we will train for and set up the various schedulers

# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
    # Override num_iterations to a specific value if given
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # Calculate the number of iterations from the target param data ratio (the most common use case)
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
    warmup_iters = round(args.warmup_ratio * num_iterations)
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * args.final_lr_frac

# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
def get_weight_decay(it):
    return weight_decay_scaled * (1 - it / num_iterations)
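# Shape of the resulting LR schedule with the defaults (illustrative): warmup_ratio=0.0 and
# warmdown_ratio=0.5 on a hypothetical 1000-iteration run means no warmup, a constant
# multiplier of 1.0 for steps 0..500, then a linear ramp from 1.0 down to final_lr_frac=0.0
# over steps 500..1000, i.e. a trapezoid that spends the first half of training at full LR.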
# -----------------------------------------------------------------------------
# Training loop

# Loop state (variables updated by the training loop)
if not resuming:
    step = 0
    val_bpb = None # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]

# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
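# Worked example (illustrative): on a single GPU with --device-batch-size=32 and
# --max-seq-len=2048, each forward/backward consumes 32*2048 = 65,536 tokens, so a total
# batch size of 524,288 tokens requires 524288/65536 = 8 accumulation micro-steps; on 8 GPUs
# the same budget needs just 1 (which is why the assert demands clean divisibility).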
# Go!
while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with disable_fp8(model), autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    # disable FP8 for evaluation to use BF16 for more consistent/accurate results
    results = {}
    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
        model.eval()
        with disable_fp8(orig_model), autocast_ctx:
            results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with disable_fp8(orig_model), autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(), # model parameters
            optimizer.state_dict(), # optimizer state
            { # metadata saved as json
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": args.device_batch_size,
                "max_seq_len": args.max_seq_len,
                "total_batch_size": total_batch_size,
                "dataloader_state_dict": dataloader_state_dict,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "total_training_time": total_training_time,
                },
            },
            rank=ddp_rank,
        )

    # termination conditions (TODO: possibly also add loss explosions etc.)
    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step

    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward

    # step the optimizer
    lrm = get_lr_multiplier(step)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    optimizer.step()
    model.zero_grad(set_to_none=True)
    train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
    synchronize()
    t1 = time.time()
    dt = t1 - t0

    # -------------------------------------------------------------------------
    # logging (CPU action only)
    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
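    # Quick sanity check of the debiasing above (illustrative arithmetic): at step 0 the EMA
    # holds 0.1*loss after one update, and dividing by (1 - 0.9**1) = 0.1 recovers the raw
    # loss exactly; as step grows the correction factor approaches 1 and becomes a no-op.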
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    # Calculate ETA based on average time per step (excluding first 10 steps)
    steps_done = step - 10
    if steps_done > 0:
        avg_time_per_step = total_training_time / steps_done
        remaining_steps = num_iterations - step
        eta_seconds = remaining_steps * avg_time_per_step
        eta_str = f" | eta: {eta_seconds/60:.1f}m"
    else:
        eta_str = ""
    epoch = dataloader_state_dict["epoch"]
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
    if step % 100 == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": epoch,
        }
        wandb_run.log(log_data)

    # state update
    first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
    step += 1

    # The garbage collector is sadly a little bit overactive and for some poorly understood reason,
    # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
    # So we manually manage and help it out here
    if first_step_of_run:
        gc.collect() # manually collect a lot of garbage from setup
        gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
        gc.disable() # nuclear intervention here: disable GC entirely
    elif step % 5000 == 0: # every 5000 steps...
        gc.collect() # manually collect, just to be safe for very, very long runs

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
    print0(f"Minimum validation bpb: {min_val_bpb:.6f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": args.warmup_ratio,
        "warmdown_ratio": args.warmdown_ratio,
        "final_lr_frac": args.final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()