Mirror of https://github.com/karpathy/nanochat.git, synced 2026-03-07 09:50:28 +00:00
nit: don't mutate args, create new var for total_batch_size
parent f41dd3cbd7
commit 2c062aaa94
@@ -256,13 +256,15 @@ num_scaling_params = get_scaling_params(model)
 target_tokens = int(args.target_param_data_ratio * num_scaling_params)
 
 # Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
-if args.total_batch_size == -1:
+total_batch_size = args.total_batch_size
+if total_batch_size == -1:
     d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens
     d12_num_scaling_params = get_scaling_params(d12_ref)
     D_REF = args.target_param_data_ratio * d12_num_scaling_params
     B_REF = 2**19
-    args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2
-    print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens")
+    batch_size_ratio = target_tokens / D_REF
+    total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2
+    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")
 
 # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
 assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
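For context, the rule being refactored above can be run standalone; the sketch below uses made-up parameter counts (not nanochat's actual get_scaling_params() output) to show how the Power Lines exponent turns a larger data budget into a larger power-of-two batch size.

# Standalone sketch of the auto batch size rule (B_opt ∝ D^0.383), with illustrative numbers.
import math

def auto_batch_size(target_tokens, d_ref_tokens, b_ref=2**19):
    # scale the reference batch size by (D / D_ref)^0.383, then snap to the nearest power of two
    ratio = target_tokens / d_ref_tokens
    return 2 ** round(math.log2(b_ref * ratio ** 0.383))

if __name__ == "__main__":
    d_ref = 20 * 124e6   # hypothetical: d12-sized reference at a 20:1 token:param ratio
    target = 20 * 1.0e9  # hypothetical: 1B scaling params at the same ratio
    print(auto_batch_size(target, d_ref))  # 2**20: larger data budget -> larger optimal batch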
@@ -272,17 +274,17 @@ if args.num_iterations > 0:
     print0(f"Using user-provided number of iterations: {num_iterations:,}")
 elif args.target_flops > 0:
     # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
-    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
+    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
     print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
 elif args.target_param_data_ratio > 0:
     # Calculate the number of iterations from the target param data ratio (the most common use case)
-    num_iterations = target_tokens // args.total_batch_size
+    num_iterations = target_tokens // total_batch_size
     print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
 else:
     raise ValueError("No training horizon specified")
-total_tokens = args.total_batch_size * num_iterations
+total_tokens = total_batch_size * num_iterations
 print0(f"Total number of training tokens: {total_tokens:,}")
-print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
+print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 
 # -----------------------------------------------------------------------------
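A quick worked example of the training-horizon arithmetic above, using made-up sizes rather than the script's defaults:

# Standalone sketch of the data:param horizon calculation, illustrative numbers only.
num_scaling_params = 124_000_000   # hypothetical model size
target_param_data_ratio = 20       # Chinchilla-style 20 tokens per parameter
total_batch_size = 2**19           # 524,288 tokens per optimizer step

target_tokens = int(target_param_data_ratio * num_scaling_params)  # 2,480,000,000 tokens
num_iterations = target_tokens // total_batch_size                 # 4,730 steps
total_tokens = total_batch_size * num_iterations                   # 2,479,882,240 tokens actually trained on
print(num_iterations, total_tokens)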
@@ -290,22 +292,22 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 # figure out the needed gradient accumulation to reach the desired total batch size
 tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
 world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert args.total_batch_size % world_tokens_per_fwdbwd == 0
-grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
+assert total_batch_size % world_tokens_per_fwdbwd == 0
+grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
-print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
+print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 
 # Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
 batch_lr_scale = 1.0
 reference_batch_size = 2**19
-batch_ratio = args.total_batch_size / reference_batch_size
+batch_ratio = total_batch_size / reference_batch_size
 if batch_ratio != 1.0:
     # SGD: linear scaling with batch size is standard (not used in nanochat)
     # AdamW: sqrt scaling is standard
     # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
     batch_lr_scale = batch_ratio ** 0.5
-    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
+    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})")
 
 # Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
 weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
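And a standalone sketch of the gradient-accumulation and sqrt-LR-scaling math above; device_batch_size, max_seq_len, and the world size here are illustrative assumptions, not the script's defaults.

# How the desired total batch size is split into micro-batches, and how LRs are rescaled.
device_batch_size = 32        # sequences per rank per forward/backward
max_seq_len = 2048            # tokens per sequence
ddp_world_size = 8            # number of GPUs
total_batch_size = 2**21      # desired tokens per optimizer step

tokens_per_fwdbwd = device_batch_size * max_seq_len            # 65,536 tokens on one rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size   # 524,288 tokens across all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd # 4 micro-steps per optimizer step

# hyperparameters were tuned at 2**19 tokens, so a 4x larger batch scales LRs by sqrt(4) = 2
batch_lr_scale = (total_batch_size / 2**19) ** 0.5
print(grad_accum_steps, batch_lr_scale)  # 4 2.0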
@@ -381,7 +383,7 @@ else:
 # Training loop
 while True:
     last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
-    flops_so_far = num_flops_per_token * args.total_batch_size * step
+    flops_so_far = num_flops_per_token * total_batch_size * step
 
     # once in a while: evaluate the val bpb (all ranks participate)
     if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
@@ -501,8 +503,8 @@ while True:
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
-    tok_per_sec = int(args.total_batch_size / dt)
-    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
+    tok_per_sec = int(total_batch_size / dt)
+    flops_per_sec = num_flops_per_token * total_batch_size / dt
     mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
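Finally, a rough sketch of the throughput/MFU bookkeeping above; the FLOPs-per-token estimate, step time, and per-GPU peak figure are illustrative assumptions, not values the script hardcodes.

# Tokens/sec, FLOPs/sec, and model FLOPs utilization for one timed optimizer step.
num_flops_per_token = 6 * 124e6   # rough 6*N forward+backward FLOPs for a 124M-param model
total_batch_size = 2**19          # tokens per optimizer step
dt = 0.5                          # measured seconds for one optimizer step
gpu_peak_flops = 989e12           # e.g. H100 BF16 dense peak, per GPU
ddp_world_size = 8

tok_per_sec = int(total_batch_size / dt)                       # 1,048,576 tokens/s
flops_per_sec = num_flops_per_token * total_batch_size / dt    # ~7.8e14 FLOPs/s
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)  # ~9.9% utilization
print(tok_per_sec, f"{mfu:.1f}%")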
@@ -560,7 +562,7 @@ get_report().log(section="Base model training", data=[
     "Number of FLOPs per token": f"{num_flops_per_token:e}",
     "Calculated number of iterations": num_iterations,
     "Number of training tokens": total_tokens,
-    "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
+    "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
     "DDP world size": ddp_world_size,
     "warmup_ratio": args.warmup_ratio,
     "warmdown_ratio": args.warmdown_ratio,