diff --git a/log/base_train-feb01.py b/log/base_train-feb01.py
deleted file mode 100644
index 9cab221..0000000
--- a/log/base_train-feb01.py
+++ /dev/null
@@ -1,507 +0,0 @@
-"""
-Train model. From root directory of the project, run as:
-
-python -m scripts.base_train
-
-or distributed as:
-
-torchrun --nproc_per_node=8 -m scripts.base_train
-
-If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
-"""
-
-import os
-os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
-import argparse
-import time
-import gc
-from contextlib import nullcontext
-
-# GC timing callback to detect if garbage collection is causing training stalls
-def _gc_callback(phase, info):
-    if phase == "start":
-        _gc_callback.start_time = time.perf_counter()
-    elif phase == "stop":
-        duration_ms = (time.perf_counter() - _gc_callback.start_time) * 1000
-        if duration_ms > 10: # Only log if GC took >10ms
-            rank = getattr(_gc_callback, 'rank', '?')
-            print(f"[GC rank{rank}] gen{info['generation']}: {duration_ms:.1f}ms collected {info.get('collected', '?')} objects")
-_gc_callback.start_time = 0
-_gc_callback.rank = '?' # Will be set after compute_init
-gc.callbacks.append(_gc_callback)
-
-import wandb
-import torch
-
-from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
-from nanochat.tokenizer import get_tokenizer, get_token_bytes
-from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
-from nanochat.loss_eval import evaluate_bpb
-from nanochat.engine import Engine
-from nanochat.flash_attention import HAS_FA3
-from scripts.base_eval import evaluate_core
-print_banner()
-
-# -----------------------------------------------------------------------------
-# CLI arguments
-parser = argparse.ArgumentParser(description="Pretrain base model")
-# Logging
-parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
-# Runtime
-parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-# Model architecture
-parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
-parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
-parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
-parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
-parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
-# Training horizon (only one used, in order of precedence)
-parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
-parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
-# Optimization
-parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
-parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
-parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
-parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
-parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
-parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
-parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
-# Evaluation
-parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
-parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
-parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
-# Output
-parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
-args = parser.parse_args()
-user_config = vars(args).copy() # for logging
-# -----------------------------------------------------------------------------
-
-# Compute init
-device_type = autodetect_device_type() if args.device_type == "" else args.device_type
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-_gc_callback.rank = ddp_rank # Store rank for GC log printouts
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
-get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
-if device_type == "cuda":
-    gpu_device_name = torch.cuda.get_device_name(0)
-    gpu_peak_flops = get_peak_flops(gpu_device_name)
-    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
-else:
-    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
-
-# wandb logging init
-use_dummy_wandb = args.run == "dummy" or not master_process
-if use_dummy_wandb:
-    wandb_run = DummyWandb()
-else:
-    try:
-        wandb_run = wandb.init(project="nanochat", name=args.run, config=user_config)
-    except wandb.errors.UsageError as e:
-        print0(f"Warning: wandb initialization failed ({e}), logging disabled. Run 'wandb login' to enable.")
-        wandb_run = DummyWandb()
-
-# Flash Attention status
-if HAS_FA3:
-    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
-else:
-    print0("!" * 80)
-    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
-    print0("WARNING: Training will be less efficient without FA3")
-    if args.window_pattern != "L":
-        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
-        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
-    print0("!" * 80)
-
-# Tokenizer will be useful for evaluation, also we need the vocab size
-tokenizer = get_tokenizer()
-token_bytes = get_token_bytes(device=device)
-vocab_size = tokenizer.get_vocab_size()
-print0(f"Vocab size: {vocab_size:,}")
-
-# Model kwargs are derived from the desired depth of the model
-# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
-# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
-# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
-num_layers = args.depth
-base_dim = args.depth * args.aspect_ratio
-model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
-num_heads = model_dim // args.head_dim
-num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
-head_dim = model_dim // num_heads
-print0(f"num_layers: {num_layers}")
-print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
-print0(f"num_heads: {num_heads}")
-print0(f"head_dim: {head_dim}")
-print0(f"num_kv_heads: {num_kv_heads}")
-
-# Optimizer / data / training length related hyperparameters
-# figure out the needed gradient accumulation to reach the desired total batch size
-tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
-world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert args.total_batch_size % world_tokens_per_fwdbwd == 0
-grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
-print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
-print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
-print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
-
-# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
-batch_lr_scale = 1.0
-reference_batch_size = 2**19
-batch_ratio = args.total_batch_size / reference_batch_size
-if batch_ratio != 1.0:
-    # SGD: linear scaling with batch size is standard (not used in nanochat)
-    # AdamW: sqrt scaling is standard
-    # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
-    batch_lr_scale = batch_ratio ** 0.5
-    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
-
-# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
-weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
-if args.depth != 12:
-    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
-
-# -----------------------------------------------------------------------------
-# Initialize the Model
-
-# Create a new model with random weights
-model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
-with torch.device("meta"):
-    # All tensors are created as meta tensors (they have shape/dtype but no data)
-    model_config = GPTConfig(**model_config_kwargs)
-    model = GPT(model_config)
-model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
-model.init_weights() # All tensors get initialized
-
-# If we are resuming, overwrite the model parameters with those of the checkpoint
-base_dir = get_base_dir()
-output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
-checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
-resuming = args.resume_from_step != -1
-if resuming:
-    print0(f"Resuming optimization from step {args.resume_from_step}")
-    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
-    model.load_state_dict(model_data, strict=True, assign=True)
-    del model_data # free up this memory after the copy
-
-orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change)
-model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
-
-# Detailed parameter counts
-param_counts = orig_model.num_scaling_params()
-print0(f"Parameter counts:")
-for key, value in param_counts.items():
-    print0(f"{key:24s}: {value:,}")
-num_params = param_counts['total']
-num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
-num_flops_per_token = model.estimate_flops()
-print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
-
-# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
-assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
-if args.num_iterations > 0:
-    num_iterations = args.num_iterations
-    print0(f"Using user-provided number of iterations: {num_iterations:,}")
-elif args.target_flops > 0:
-    # calculate the number of iterations from the target flops
-    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
-    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
-elif args.target_param_data_ratio > 0:
-    # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
-    target_tokens = int(args.target_param_data_ratio * num_scaling_params)
-    num_iterations = target_tokens // args.total_batch_size
-    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
-else:
-    raise ValueError("No training horizon specified")
-total_tokens = args.total_batch_size * num_iterations
-print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
-print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
-
-# -----------------------------------------------------------------------------
-# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
-adam_betas = (args.adam_beta1, args.adam_beta2)
-optimizer = model.setup_optimizer(
-    unembedding_lr=args.unembedding_lr * batch_lr_scale,
-    embedding_lr=args.embedding_lr * batch_lr_scale,
-    matrix_lr=args.matrix_lr * batch_lr_scale,
-    weight_decay=weight_decay_scaled,
-    adam_betas=adam_betas,
-    scalar_lr=args.scalar_lr * batch_lr_scale,
-)
-
-if resuming:
-    optimizer.load_state_dict(optimizer_data)
-    del optimizer_data
-
-# -----------------------------------------------------------------------------
-# Initialize the DataLoaders for train/val
-dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
-train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
-build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
-x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
-
-# -----------------------------------------------------------------------------
-# Set up hyperparameter schedulers
-
-# Learning rate scheduler
-def get_lr_multiplier(it):
-    warmup_iters = round(args.warmup_ratio * num_iterations)
-    warmdown_iters = round(args.warmdown_ratio * num_iterations)
-    if it < warmup_iters:
-        return (it + 1) / warmup_iters
-    elif it <= num_iterations - warmdown_iters:
-        return 1.0
-    else:
-        progress = (num_iterations - it) / warmdown_iters
-        return progress * 1.0 + (1 - progress) * args.final_lr_frac
-
-# Momentum scheduler for Muon optimizer
-def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
-    return momentum
-
-# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
-def get_weight_decay(it):
-    return weight_decay_scaled * (1 - it / num_iterations)
-
-# -----------------------------------------------------------------------------
-# Loop state (variables updated by the training loop)
-
-if not resuming:
-    step = 0
-    val_bpb = None # will be set if eval_every > 0
-    min_val_bpb = float("inf")
-    smooth_train_loss = 0 # EMA of training loss
-    total_training_time = 0 # total wall-clock time of training
-else:
-    step = meta_data["step"]
-    loop_state = meta_data["loop_state"]
-    val_bpb = meta_data["val_bpb"]
-    min_val_bpb = loop_state["min_val_bpb"]
-    smooth_train_loss = loop_state["smooth_train_loss"]
-    total_training_time = loop_state["total_training_time"]
-
-# -----------------------------------------------------------------------------
-# Training loop
-while True:
-    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
-    flops_so_far = num_flops_per_token * args.total_batch_size * step
-
-    # once in a while: evaluate the val bpb (all ranks participate)
-    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
-        model.eval()
-        val_loader = build_val_loader()
-        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with autocast_ctx:
-            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
-        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
-        if val_bpb < min_val_bpb:
-            min_val_bpb = val_bpb
-        wandb_run.log({
-            "step": step,
-            "total_training_flops": flops_so_far,
-            "total_training_time": total_training_time,
-            "val/bpb": val_bpb,
-        })
-        model.train()
-
-    # once in a while: estimate the CORE metric (all ranks participate)
-    # use the original uncompiled model because the inputs keep changing shape
-    results = {}
-    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
-        model.eval()
-        with autocast_ctx:
-            results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
-        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
-        wandb_run.log({
-            "step": step,
-            "total_training_flops": flops_so_far,
-            "core_metric": results["core_metric"],
-            "centered_results": results["centered_results"],
-        })
-        model.train()
-
-    # once in a while: sample from the model (only on master process)
-    # use the original uncompiled model because the inputs keep changing shape
-    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
-        model.eval()
-        prompts = [
-            "The capital of France is",
-            "The chemical symbol of gold is",
-            "If yesterday was Friday, then tomorrow will be",
-            "The opposite of hot is",
-            "The planets of the solar system are:",
-            "My favorite color is",
-            "If 5*x + 3 = 13, then x is",
-        ]
-        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
-        for prompt in prompts:
-            tokens = tokenizer(prompt, prepend="<|bos|>")
-            with autocast_ctx:
-                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
-            print0(tokenizer.decode(sample[0]))
-        model.train()
-
-    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
-    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
-        save_checkpoint(
-            checkpoint_dir,
-            step,
-            orig_model.state_dict(), # model parameters
-            optimizer.state_dict(), # optimizer state
-            { # metadata saved as json
-                "step": step,
-                "val_bpb": val_bpb, # loss at last step
-                "model_config": model_config_kwargs,
-                "user_config": user_config, # inputs to the training script
-                "device_batch_size": args.device_batch_size,
-                "max_seq_len": args.max_seq_len,
-                "dataloader_state_dict": dataloader_state_dict,
-                "loop_state": { # all loop state (other than step) so that we can resume training
-                    "min_val_bpb": min_val_bpb,
-                    "smooth_train_loss": smooth_train_loss,
-                    "total_training_time": total_training_time,
-                },
-            },
-            rank=ddp_rank,
-        )
-
-    # termination conditions (TODO: possibly also add loss explosions etc.)
-    if last_step:
-        break
-
-    # -------------------------------------------------------------------------
-    # single training step
-    # evaluate the gradient
-    synchronize()
-    t0 = time.time()
-    for micro_step in range(grad_accum_steps):
-        with autocast_ctx:
-            loss = model(x, y)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward()
-        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
-    # step the optimizer
-    lrm = get_lr_multiplier(step)
-    muon_momentum = get_muon_momentum(step)
-    muon_weight_decay = get_weight_decay(step)
-    for group in optimizer.param_groups:
-        group["lr"] = group["initial_lr"] * lrm
-        if group['kind'] == 'muon':
-            group["momentum"] = muon_momentum
-            group["weight_decay"] = muon_weight_decay
-    optimizer.step()
-    model.zero_grad(set_to_none=True)
-    train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
-    synchronize()
-    t1 = time.time()
-    dt = t1 - t0
-    # -------------------------------------------------------------------------
-
-    # logging (CPU action only)
-    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
-    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
-    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
-    pct_done = 100 * step / num_iterations
-    tok_per_sec = int(args.total_batch_size / dt)
-    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
-    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
-    if step > 10:
-        total_training_time += dt # only count the time after the first 10 steps
-    # Calculate ETA based on average time per step (excluding first 10 steps)
-    steps_done = step - 10
-    if steps_done > 0:
-        avg_time_per_step = total_training_time / steps_done
-        remaining_steps = num_iterations - step
-        eta_seconds = remaining_steps * avg_time_per_step
-        eta_str = f" | eta: {eta_seconds/60:.1f}m"
-    else:
-        eta_str = ""
-    epoch = dataloader_state_dict["epoch"]
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
-    if step % 100 == 0:
-        log_data = {
-            "step": step,
-            "total_training_flops": flops_so_far,
-            "total_training_time": total_training_time,
-            "train/loss": debiased_smooth_loss,
-            "train/lrm": lrm,
-            "train/dt": dt,
-            "train/tok_per_sec": tok_per_sec,
-            "train/mfu": mfu,
-            "train/epoch": epoch,
-        }
-        wandb_run.log(log_data)
-
-    # Set 'first_step_of_run' flag
-    first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
-
-    # state update
-    step += 1
-
-    # # TEMP - Bail at 1000 steps for benchmarking.
-    # if step == 1001:
-    #     print0(f"Elapsed + ETA: {total_training_time + eta_seconds:.0f}s")
-    #     break
-
-    # Help out the garbage collector by flushing garbage and then freezing long-lived objects
-    # This eliminates random ~500ms pauses during training steps as the GC scans ~millions of objects for cycles
-    if first_step_of_run:
-        gc.collect()
-        gc.freeze()
-        gc.disable() # nuclear option: disable GC for the run
-    elif step % 2000 == 0:
-        gc.collect() # manual GC to keep memory usage in check for very long runs
-
-# print a few more stats
-print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
-print0(f"Total training time: {total_training_time/60:.2f}m")
-if val_bpb is not None:
-    print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
-
-# Log to report
-from nanochat.report import get_report
-get_report().log(section="Base model training", data=[
-    user_config, # CLI args
-    { # stats about the training setup
-        "Number of parameters": num_params,
-        "Number of FLOPs per token": f"{num_flops_per_token:e}",
-        "Calculated number of iterations": num_iterations,
-        "Number of training tokens": total_tokens,
-        "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
-        "DDP world size": ddp_world_size,
-        "warmup_ratio": args.warmup_ratio,
-        "warmdown_ratio": args.warmdown_ratio,
-        "final_lr_frac": args.final_lr_frac,
-    },
-    { # stats about training outcomes
-        "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
-        "Final validation bpb": val_bpb,
-        "CORE metric estimate": results.get("core_metric", None),
-        "MFU %": f"{mfu:.2f}%",
-        "Total training flops": f"{flops_so_far:e}",
-        "Total training time": f"{total_training_time/60:.2f}m",
-        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
-    }
-])
-
-# cleanup
-wandb_run.finish() # wandb run finish
-compute_cleanup()