Merge pull request #18 from Dianababaei/feat/auto-batch-size-discovery-integration

Add batch sampling functionality for auto-discovery of optimal batch sizes across training scripts
This commit is contained in:
Dianababaei 2025-11-05 20:20:47 +03:30 committed by GitHub
commit 04e66eacfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 42 deletions

View File

@@ -35,8 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
total_batch_size = 524288 # total desired batch size, in #tokens
auto_batch_size = True # whether to auto-discover optimal batch size
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
batch_size_cache = True # whether to cache auto-discovered batch size
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
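A quick numeric sanity check on the new batch_size_margin knob (the numbers are illustrative, not taken from the implementation): if probing finds that 40 sequences per device fit, a 0.85 margin caps training at 34, and a divisor-friendly implementation would likely round down further, e.g. to 32, so that the total-batch-size divisibility check later in the script still passes.

probed_max = 40                                      # hypothetical largest batch that fit while probing
capped = int(probed_max * batch_size_margin)         # 40 * 0.85 -> 34
device_batch_size = 1 << (capped.bit_length() - 1)   # round down to a power of two -> 32 (assumed behavior)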
@@ -81,15 +84,6 @@ print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
@@ -99,15 +93,41 @@ with torch.device("meta"):
model.to_empty(device="cuda")
model.init_weights()
# Create batch sample function for auto-discovery
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    def sample_fn(batch_size, seq_len):
        inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        return inputs, targets
    return sample_fn
# Create a batch sample function for auto-discovery
def batch_sample_fn(batch_size):
    """Generate a sample batch for auto-discovery of optimal batch size."""
    return (
        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
    )
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
# Auto-discover optimal batch size if not manually set
if auto_batch_size and device_batch_size is None:
    from nanochat.auto_batch_size import find_optimal_device_batch_size
    device_batch_size = find_optimal_device_batch_size(
        model=model,
        max_seq_len=max_seq_len,
        total_batch_size=total_batch_size,
        ddp_world_size=ddp_world_size,
        data_sample_fn=batch_sample_fn,
        override=device_batch_size,
        safety_margin=batch_size_margin,
        enable_cache=batch_size_cache,
        ddp_rank=ddp_rank,
    )
elif device_batch_size is None:
    device_batch_size = 8
    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
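The heavy lifting happens in nanochat.auto_batch_size.find_optimal_device_batch_size, which is imported above but not shown in this diff. Below is a minimal sketch of what such a routine might do, assuming it probes growing batch sizes with synthetic batches from data_sample_fn and backs off on CUDA OOM; it ignores the caching and DDP coordination that the real call exposes through enable_cache and ddp_rank, and everything here is illustrative rather than the actual implementation.

import torch

def find_optimal_device_batch_size_sketch(model, data_sample_fn, safety_margin=0.85, override=None, limit=4096):
    """Illustrative only: double the batch size until a forward/backward pass OOMs, then apply the margin."""
    if override is not None:
        return override                                # a manually set batch size wins
    best, bs = 0, 1
    while bs <= limit:
        try:
            inputs, targets = data_sample_fn(bs)       # synthetic batch, e.g. batch_sample_fn above
            loss = model(inputs, targets)              # assumes the model returns a scalar loss given targets
            loss.backward()
            best = bs
            bs *= 2
        except torch.cuda.OutOfMemoryError:
            break
        finally:
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
    capped = max(1, int(best * safety_margin))
    return 1 << (capped.bit_length() - 1)              # round down to a power of two (assumption)

In a DDP run the probe would typically execute on one rank and the result be shared so all ranks agree, which is presumably why the real function also takes ddp_rank.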

View File

@@ -36,7 +36,10 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
step = None # step to load the model from (base model or midtrained model)
# compute/precision
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
auto_batch_size = True # whether to auto-discover optimal batch size
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
batch_size_cache = True # whether to cache auto-discovered batch size
# optimization
num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@@ -69,17 +72,36 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
# Create batch sample function for auto-discovery
max_seq_len = model.config.sequence_len
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    def sample_fn(batch_size, seq_len):
        # Use max_seq_len (worst case for variable-length sequences)
        inputs = torch.randint(0, vocab_size, (batch_size, max_seq_len), device=device)
        targets = torch.full((batch_size, max_seq_len), -1, dtype=torch.long, device=device)
        return inputs, targets
    return sample_fn
# Create a batch sample function for auto-discovery
def batch_sample_fn(batch_size):
    """Generate a sample batch for auto-discovery of optimal batch size."""
    vocab_size = tokenizer.get_vocab_size()
    # Chat SFT sequences have variable length, so probe memory with a fixed worst-case length
    max_len = 2048 # fixed max length for discovery
    return (
        torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int32, device="cuda"),
        torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int64, device="cuda")
    )
batch_sample_fn = create_batch_sample_fn(max_seq_len, model.config.vocab_size, device)
# Auto-discover optimal batch size if not manually set
if auto_batch_size and device_batch_size is None:
    from nanochat.auto_batch_size import find_optimal_device_batch_size
    device_batch_size = find_optimal_device_batch_size(
        model=model,
        max_seq_len=2048, # Use fixed max length for discovery
        total_batch_size=target_examples_per_step,
        ddp_world_size=ddp_world_size,
        data_sample_fn=batch_sample_fn,
        override=device_batch_size,
        safety_margin=batch_size_margin,
        enable_cache=batch_size_cache,
        ddp_rank=ddp_rank,
    )
elif device_batch_size is None:
    device_batch_size = 4
    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
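batch_size_cache (passed through as enable_cache) suggests the discovered value is persisted between runs, but the cache format is not visible in this diff. A sketch of one plausible scheme, keyed on quantities that affect memory use; the helper names and file layout are hypothetical.

import json, os

def load_cached_batch_size(cache_path, key):
    """Return a previously discovered batch size for this configuration, if any (hypothetical helper)."""
    if not os.path.exists(cache_path):
        return None
    with open(cache_path) as f:
        return json.load(f).get(key)

def save_cached_batch_size(cache_path, key, batch_size):
    """Persist the discovered batch size so later runs can skip probing (hypothetical helper)."""
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    cache[key] = batch_size
    with open(cache_path, "w") as f:
        json.dump(cache, f)

# A reasonable cache key would combine model shape, sequence length and GPU type, e.g.:
# key = f"d{model_dim}-L{num_layers}-seq{max_seq_len}-{torch.cuda.get_device_name(0)}"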

View File

@@ -34,7 +34,10 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
max_seq_len = 2048
device_batch_size = 32
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
auto_batch_size = True # whether to auto-discover optimal batch size
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
batch_size_cache = True # whether to cache auto-discovered batch size
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
@@ -61,23 +64,39 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
# Create a batch sample function for auto-discovery
def batch_sample_fn(batch_size):
    """Generate a sample batch for auto-discovery of optimal batch size."""
    vocab_size = tokenizer.get_vocab_size()
    return (
        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
    )
# Auto-discover optimal batch size if not manually set
if auto_batch_size and device_batch_size is None:
    from nanochat.auto_batch_size import find_optimal_device_batch_size
    device_batch_size = find_optimal_device_batch_size(
        model=model,
        max_seq_len=max_seq_len,
        total_batch_size=total_batch_size,
        ddp_world_size=ddp_world_size,
        data_sample_fn=batch_sample_fn,
        override=device_batch_size,
        safety_margin=batch_size_margin,
        enable_cache=batch_size_cache,
        ddp_rank=ddp_rank,
    )
elif device_batch_size is None:
    device_batch_size = 8
    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
# Create batch sample function for auto-discovery
vocab_size = tokenizer.get_vocab_size()
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    def sample_fn(batch_size, seq_len):
        inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        return inputs, targets
    return sample_fn
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
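Because the total_batch_size divisibility assert runs after discovery, it is worth checking the arithmetic with the defaults shown in this PR (total_batch_size from the first file, max_seq_len from the third). The numbers below assume an 8-GPU run and a discovered device_batch_size of 16; both are illustrative.

total_batch_size = 524288                                       # desired tokens per optimizer step (config default)
max_seq_len = 2048
ddp_world_size = 8                                              # assumed number of ranks
device_batch_size = 16                                          # assumed auto-discovered value

tokens_per_fwdbwd = device_batch_size * max_seq_len             # 32,768 tokens per rank per micro-step
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 262,144 tokens per micro-step across ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0          # holds only for "friendly" batch sizes
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 2 micro-steps per optimizer step

A discovered value that is not a divisor of total_batch_size / (max_seq_len * ddp_world_size), e.g. 24 here, would trip the assert, which is presumably why rounding the probed maximum down to a divisor-friendly value matters.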