mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-20 03:43:20 +00:00
auto-calculate optimal batch size. the original setting of 0.5M was only optimal for d12, but d26 prefers 1M and so on
This commit is contained in:
parent
98eed6df18
commit
f41dd3cbd7
46
dev/LOG.md
46
dev/LOG.md
|
|
@ -4,6 +4,52 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||
|
||||
---
|
||||
|
||||
## 2026-02-05: Auto Batch Size Scaling
|
||||
|
||||
### Background
|
||||
|
||||
So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches.
|
||||
|
||||
### Implementation
|
||||
|
||||
Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch:
|
||||
|
||||
```python
|
||||
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
|
||||
if args.total_batch_size == -1:
|
||||
D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
|
||||
B_REF = 2**19
|
||||
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
|
||||
```
|
||||
|
||||
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
|
||||
|
||||
### Results
|
||||
|
||||
With this formula, we currently get:
|
||||
|
||||
| Depth | Scaling Params | Target Tokens | Auto Batch |
|
||||
|-------|---------------|---------------|------------|
|
||||
| d=8 | 42M | 0.44B | 2^18 = 262K |
|
||||
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
|
||||
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
|
||||
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
|
||||
|
||||
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
|
||||
|
||||
### Code Cleanup
|
||||
|
||||
Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling.
|
||||
|
||||
### Useful references
|
||||
|
||||
- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738)
|
||||
- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162)
|
||||
- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
|
||||
- [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971)
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-05: SwiGLU Activation (Negative Result)
|
||||
|
||||
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
|
||||
|
|
|
|||
|
|
@ -11,11 +11,14 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex
|
|||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
from contextlib import nullcontext, contextmanager
|
||||
|
||||
import wandb
|
||||
|
|
@ -53,8 +56,8 @@ parser.add_argument("--num-iterations", type=int, default=-1, help="explicit num
|
|||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
|
||||
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
|
|
@ -78,8 +81,8 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model
|
|||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compute init and wandb logging
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
|
|
@ -109,65 +112,39 @@ else:
|
|||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
|
||||
num_layers = args.depth
|
||||
base_dim = args.depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
head_dim = model_dim // num_heads
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"head_dim: {head_dim}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||
with torch.device("meta"):
|
||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # All tensors get initialized
|
||||
def build_model_meta(depth):
|
||||
"""Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
|
||||
# Model dim is nudged up to nearest multiple of head_dim for clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
base_dim = depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
return model_meta
|
||||
|
||||
# Build the model, move to device, init the weights
|
||||
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
|
||||
model_config = model.config
|
||||
model_config_kwargs = asdict(model_config)
|
||||
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
|
||||
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # 3) All tensors get initialized
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
|
|
@ -181,41 +158,7 @@ if resuming:
|
|||
del model_data # free up this memory after the copy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the length of the training run based on model size
|
||||
|
||||
# Detailed parameter counts
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif args.target_flops > 0:
|
||||
# calculate the number of iterations from the target flops
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# FP8 training initialization and management (has to be done before torch.compile)
|
||||
# FP8 training initialization and management (this has to be done before torch.compile)
|
||||
|
||||
# Convert Linear layers to Float8Linear if --fp8 is set
|
||||
if args.fp8:
|
||||
|
|
@ -293,6 +236,82 @@ def disable_fp8(model):
|
|||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the optimization horizon based on the model size
|
||||
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
|
||||
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
|
||||
|
||||
# Get the parameter counts of the model
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026)
|
||||
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
|
||||
num_scaling_params = get_scaling_params(model)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
|
||||
# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
|
||||
if args.total_batch_size == -1:
|
||||
d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens
|
||||
d12_num_scaling_params = get_scaling_params(d12_ref)
|
||||
D_REF = args.target_param_data_ratio * d12_num_scaling_params
|
||||
B_REF = 2**19
|
||||
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2
|
||||
print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
# Override num_iterations to a specific value if given
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif args.target_flops > 0:
|
||||
# Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# Calculate the number of iterations from the target param data ratio (the most common use case)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user