diff --git a/scripts/base_train.py b/scripts/base_train.py index a3774e6..ccf35e6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -237,11 +237,9 @@ orig_model = model # original, uncompiled model, for saving raw model state_dict model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- -# Determine the optimization horizon based on the model size -# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). -# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params +# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. -# Get the parameter counts of the model +# Get the parameter counts of our model param_counts = model.num_scaling_params() print0(f"Parameter counts:") for key, value in param_counts.items(): @@ -250,23 +248,80 @@ num_params = param_counts['total'] num_flops_per_token = model.estimate_flops() print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") -# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026) -get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head'] +# 1) Use scaling laws to determine the optimal training horizon in tokens +# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). +# We've already initialized the model so we have Params. 
Optimal Tokens is now simply target-param-data-ratio * Params +def get_scaling_params(m): + # As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026) + params_counts = m.num_scaling_params() + scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head'] + return scaling_params num_scaling_params = get_scaling_params(model) -target_tokens = int(args.target_param_data_ratio * num_scaling_params) +target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train -# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 -total_batch_size = args.total_batch_size +# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style) +d12_ref = build_model_meta(12) # creates the model on meta device +D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically) +B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically) + +# 2) Now that we have the token horizon, we can calculate the optimal batch size +# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 +# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x. 
+total_batch_size = args.total_batch_size # user-provided override is possible if total_batch_size == -1: - d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens - d12_num_scaling_params = get_scaling_params(d12_ref) - D_REF = args.target_param_data_ratio * d12_num_scaling_params - B_REF = 2**19 batch_size_ratio = target_tokens / D_REF - total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2 + predicted_batch_size = B_REF * batch_size_ratio ** 0.383 + total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens") -# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) +# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates) +batch_lr_scale = 1.0 +batch_ratio = total_batch_size / B_REF # B/B_ref +if batch_ratio != 1.0: + # SGD: linear scaling with batch size is standard (not used in nanochat) + # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref) + # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!) + batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref) + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})") + +# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling +# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698 +# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant. +# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need: +# λ = λ_ref · √(B/B_ref) · (D_ref/D) +# Note that these papers study AdamW, *not* Muon. 
We are blindly following AdamW theory for scaling hoping it ~works for Muon too. +weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens) +if weight_decay_scaled != args.weight_decay: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") + +# ----------------------------------------------------------------------------- +# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +optimizer = model.setup_optimizer( + # AdamW hyperparameters + unembedding_lr=args.unembedding_lr * batch_lr_scale, + embedding_lr=args.embedding_lr * batch_lr_scale, + scalar_lr=args.scalar_lr * batch_lr_scale, + adam_betas=(args.adam_beta1, args.adam_beta2), + # Muon hyperparameters + matrix_lr=args.matrix_lr * batch_lr_scale, + weight_decay=weight_decay_scaled, +) + +if resuming: + optimizer.load_state_dict(optimizer_data) + del optimizer_data + +# ----------------------------------------------------------------------------- +# Initialize the DataLoaders for train/val +dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] +train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) +build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data + +# ----------------------------------------------------------------------------- +# Calculate the number of iterations we will train for and set up the various schedulers + +# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order) assert args.num_iterations > 0 or args.target_param_data_ratio > 
0 or args.target_flops > 0 if args.num_iterations > 0: # Override num_iterations to a specific value if given @@ -282,65 +337,12 @@ elif args.target_param_data_ratio > 0: print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") else: raise ValueError("No training horizon specified") -total_tokens = total_batch_size * num_iterations +total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") -# ----------------------------------------------------------------------------- -# Optimizer / data / training length related hyperparameters -# figure out the needed gradient accumulation to reach the desired total batch size -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") - -# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) -batch_lr_scale = 1.0 -reference_batch_size = 2**19 -batch_ratio = total_batch_size / reference_batch_size -if batch_ratio != 1.0: - # SGD: linear scaling with batch size is standard (not used in 
nanochat) - # AdamW: sqrt scaling is standard - # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer - batch_lr_scale = batch_ratio ** 0.5 - print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})") - -# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) -weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 -if args.depth != 12: - print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") - -# ----------------------------------------------------------------------------- -# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -adam_betas = (args.adam_beta1, args.adam_beta2) -optimizer = model.setup_optimizer( - unembedding_lr=args.unembedding_lr * batch_lr_scale, - embedding_lr=args.embedding_lr * batch_lr_scale, - matrix_lr=args.matrix_lr * batch_lr_scale, - weight_decay=weight_decay_scaled, - adam_betas=adam_betas, - scalar_lr=args.scalar_lr * batch_lr_scale, -) - -if resuming: - optimizer.load_state_dict(optimizer_data) - del optimizer_data - -# ----------------------------------------------------------------------------- -# Initialize the DataLoaders for train/val -dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] -train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) -build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) -x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data - -# 
----------------------------------------------------------------------------- -# Set up hyperparameter schedulers - -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) def get_lr_multiplier(it): warmup_iters = round(args.warmup_ratio * num_iterations) warmdown_iters = round(args.warmdown_ratio * num_iterations) @@ -352,19 +354,20 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer +# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps) def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum -# Weight decay scheduler for Muon optimizer (linear to zero over the course of training) +# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training) def get_weight_decay(it): return weight_decay_scaled * (1 - it / num_iterations) # ----------------------------------------------------------------------------- -# Loop state (variables updated by the training loop) +# Training loop +# Loop state (variables updated by the training loop) if not resuming: step = 0 val_bpb = None # will be set if eval_every > 0 @@ -379,8 +382,16 @@ else: smooth_train_loss = loop_state["smooth_train_loss"] total_training_time = loop_state["total_training_time"] -# ----------------------------------------------------------------------------- -# Training loop +# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / 
micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + +# Go! while True: last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end flops_so_far = num_flops_per_token * total_batch_size * step