mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-19 19:33:15 +00:00
better comments/flow on all the hyperparameter transfer stuff, and change the WD scaling from my empirical 1/d^2 to a bit more principled version based on Tepoch. All of that theory is based on AdamW and could be suboptimal for Muon
This commit is contained in:
parent
685271dc8d
commit
aeff095e97
|
|
@ -237,11 +237,9 @@ orig_model = model # original, uncompiled model, for saving raw model state_dict
|
|||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the optimization horizon based on the model size
|
||||
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
|
||||
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
||||
# Get the parameter counts of the model
|
||||
# Get the parameter counts of our model
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
|
|
@ -250,23 +248,80 @@ num_params = param_counts['total']
|
|||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026)
|
||||
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
|
||||
# 1) Use scaling laws to determine the optimal training horizon in tokens
|
||||
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
|
||||
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
|
||||
def get_scaling_params(m):
|
||||
# As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
|
||||
params_counts = m.num_scaling_params()
|
||||
scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
|
||||
return scaling_params
|
||||
num_scaling_params = get_scaling_params(model)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train
|
||||
|
||||
# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
|
||||
total_batch_size = args.total_batch_size
|
||||
# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style)
|
||||
d12_ref = build_model_meta(12) # creates the model on meta device
|
||||
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
|
||||
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
|
||||
|
||||
# 2) Now that we have the token horizon, we can calculate the optimal batch size
|
||||
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
|
||||
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
|
||||
total_batch_size = args.total_batch_size # user-provided override is possible
|
||||
if total_batch_size == -1:
|
||||
d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens
|
||||
d12_num_scaling_params = get_scaling_params(d12_ref)
|
||||
D_REF = args.target_param_data_ratio * d12_num_scaling_params
|
||||
B_REF = 2**19
|
||||
batch_size_ratio = target_tokens / D_REF
|
||||
total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2
|
||||
predicted_batch_size = B_REF * batch_size_ratio ** 0.383
|
||||
total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
|
||||
print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
|
||||
batch_lr_scale = 1.0
|
||||
batch_ratio = total_batch_size / B_REF # B/B_ref
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
|
||||
# Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
|
||||
batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")
|
||||
|
||||
# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
|
||||
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
|
||||
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
|
||||
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
|
||||
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
|
||||
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
|
||||
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
|
||||
if weight_decay_scaled != args.weight_decay:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
optimizer = model.setup_optimizer(
|
||||
# AdamW hyperparameters
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
adam_betas=(args.adam_beta1, args.adam_beta2),
|
||||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
)
|
||||
|
||||
if resuming:
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculate the number of iterations we will train for and set up the various schedulers
|
||||
|
||||
# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
# Override num_iterations to a specific value if given
|
||||
|
|
@ -282,65 +337,12 @@ elif args.target_param_data_ratio > 0:
|
|||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
|
||||
if resuming:
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
# Learning rate scheduler
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
|
|
@ -352,19 +354,20 @@ def get_lr_multiplier(it):
|
|||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
|
||||
# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * (1 - it / num_iterations)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Loop state (variables updated by the training loop)
|
||||
# Training loop
|
||||
|
||||
# Loop state (variables updated by the training loop)
|
||||
if not resuming:
|
||||
step = 0
|
||||
val_bpb = None # will be set if eval_every > 0
|
||||
|
|
@ -379,8 +382,16 @@ else:
|
|||
smooth_train_loss = loop_state["smooth_train_loss"]
|
||||
total_training_time = loop_state["total_training_time"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Go!
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user