From 2c062aaa949536c1cff2ffb3df2ae4aeba20dc4b Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 5 Feb 2026 19:59:46 +0000
Subject: [PATCH] nit: don't mutate args, create new var for total_batch_size

---
 scripts/base_train.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index 97264c8..a3774e6 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -256,13 +256,15 @@ num_scaling_params = get_scaling_params(model)
 target_tokens = int(args.target_param_data_ratio * num_scaling_params)

 # Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
-if args.total_batch_size == -1:
+total_batch_size = args.total_batch_size
+if total_batch_size == -1:
     d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens
     d12_num_scaling_params = get_scaling_params(d12_ref)
     D_REF = args.target_param_data_ratio * d12_num_scaling_params
     B_REF = 2**19
-    args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2
-    print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens")
+    batch_size_ratio = target_tokens / D_REF
+    total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2
+    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")

 # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
 assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
@@ -272,17 +274,17 @@ if args.num_iterations > 0:
     print0(f"Using user-provided number of iterations: {num_iterations:,}")
 elif args.target_flops > 0:
     # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
-    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
+    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
     print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
 elif args.target_param_data_ratio > 0:
     # Calculate the number of iterations from the target param data ratio (the most common use case)
-    num_iterations = target_tokens // args.total_batch_size
+    num_iterations = target_tokens // total_batch_size
     print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
 else:
     raise ValueError("No training horizon specified")
-total_tokens = args.total_batch_size * num_iterations
+total_tokens = total_batch_size * num_iterations
 print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
+print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

 # -----------------------------------------------------------------------------
@@ -290,22 +292,22 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 # figure out the needed gradient accumulation to reach the desired total batch size
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
 world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert args.total_batch_size % world_tokens_per_fwdbwd == 0
-grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
+assert total_batch_size % world_tokens_per_fwdbwd == 0
+grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
-print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
+print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

 # Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
 batch_lr_scale = 1.0
 reference_batch_size = 2**19
-batch_ratio = args.total_batch_size / reference_batch_size
+batch_ratio = total_batch_size / reference_batch_size
 if batch_ratio != 1.0:
     # SGD: linear scaling with batch size is standard (not used in nanochat)
     # AdamW: sqrt scaling is standard
     # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
     batch_lr_scale = batch_ratio ** 0.5
-    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
+    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})")

 # Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
 weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
@@ -381,7 +383,7 @@ else:
 # Training loop
 while True:
     last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
-    flops_so_far = num_flops_per_token * args.total_batch_size * step
+    flops_so_far = num_flops_per_token * total_batch_size * step

     # once in a while: evaluate the val bpb (all ranks participate)
     if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
@@ -501,8 +503,8 @@ while True:
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
-    tok_per_sec = int(args.total_batch_size / dt)
-    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
+    tok_per_sec = int(total_batch_size / dt)
+    flops_per_sec = num_flops_per_token * total_batch_size / dt
     mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
@@ -560,7 +562,7 @@ get_report().log(section="Base model training", data=[
     "Number of FLOPs per token": f"{num_flops_per_token:e}",
     "Calculated number of iterations": num_iterations,
     "Number of training tokens": total_tokens,
-    "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
+    "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
     "DDP world size": ddp_world_size,
     "warmup_ratio": args.warmup_ratio,
     "warmdown_ratio": args.warmdown_ratio,
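
For a quick sanity check of the scaling math this patch touches, here is a minimal standalone sketch, not the script itself: D_REF and target_tokens below are made-up example values, while the 2**19 reference batch size, the 0.383 exponent, and the sqrt LR rule are the constants quoted in the diff above.

    # Standalone check of the batch-size / LR scaling math (illustrative values only)
    import math

    B_REF = 2**19                   # reference batch size (tokens) where hyperparameters were tuned
    D_REF = 1_000_000_000           # hypothetical token horizon of the d12 reference model
    target_tokens = 8_000_000_000   # hypothetical token horizon of the model being trained

    # Power Lines rule of thumb: Bopt ∝ D^0.383, then clamp to the nearest power of 2
    batch_size_ratio = target_tokens / D_REF
    total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383))
    print(f"auto-computed batch size: {total_batch_size:,} tokens")

    # sqrt LR scaling relative to the reference batch size (as used for AdamW/Muon in the diff)
    batch_ratio = total_batch_size / B_REF
    batch_lr_scale = batch_ratio ** 0.5 if batch_ratio != 1.0 else 1.0
    print(f"LR scale: {batch_lr_scale:.4f}")

With these example numbers the sketch yields a 2**20-token batch and an LR scale of sqrt(2) ≈ 1.4142, which is the behavior the renamed total_batch_size variable preserves.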