mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-28 23:32:21 +00:00
Merge pull request #18 from Dianababaei/feat/auto-batch-size-discovery-integration
Add batch sampling functionality for auto-discovery of optimal batch sizes across training scripts
This commit is contained in:
commit
04e66eacfa
|
|
@ -35,8 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
|
|||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
auto_batch_size = True # whether to auto-discover optimal batch size
|
||||
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
|
||||
batch_size_cache = True # whether to cache auto-discovered batch size
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
|
|
@ -81,15 +84,6 @@ print0(f"model_dim: {model_dim}")
|
|||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
|
|
@ -99,15 +93,41 @@ with torch.device("meta"):
|
|||
model.to_empty(device="cuda")
|
||||
model.init_weights()
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
# Create a batch sample function for auto-discovery
|
||||
def batch_sample_fn(batch_size):
|
||||
"""Generate a sample batch for auto-discovery of optimal batch size."""
|
||||
return (
|
||||
torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
|
||||
torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
|
||||
)
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
|
||||
# Auto-discover optimal batch size if not manually set
|
||||
if auto_batch_size and device_batch_size is None:
|
||||
from nanochat.auto_batch_size import find_optimal_device_batch_size
|
||||
device_batch_size = find_optimal_device_batch_size(
|
||||
model=model,
|
||||
max_seq_len=max_seq_len,
|
||||
total_batch_size=total_batch_size,
|
||||
ddp_world_size=ddp_world_size,
|
||||
data_sample_fn=batch_sample_fn,
|
||||
override=device_batch_size,
|
||||
safety_margin=batch_size_margin,
|
||||
enable_cache=batch_size_cache,
|
||||
ddp_rank=ddp_rank,
|
||||
)
|
||||
elif device_batch_size is None:
|
||||
device_batch_size = 8
|
||||
print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
|
|
|||
|
|
@ -36,7 +36,10 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
|
||||
auto_batch_size = True # whether to auto-discover optimal batch size
|
||||
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
|
||||
batch_size_cache = True # whether to cache auto-discovered batch size
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
|
|
@ -69,17 +72,36 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
|
|||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
max_seq_len = model.config.sequence_len
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
# Use max_seq_len (worst case for variable-length sequences)
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, max_seq_len), device=device)
|
||||
targets = torch.full((batch_size, max_seq_len), -1, dtype=torch.long, device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
# Create a batch sample function for auto-discovery
|
||||
def batch_sample_fn(batch_size):
|
||||
"""Generate a sample batch for auto-discovery of optimal batch size."""
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
# For chat SFT, we use variable length sequences, so we use max_seq_len for discovery
|
||||
# Note: We need to define max_seq_len before this point, which is already done in config
|
||||
pad_token_id = tokenizer.encode_special("<|assistant_end|>")
|
||||
max_len = 2048 # Use a fixed max length for discovery
|
||||
return (
|
||||
torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int32, device="cuda"),
|
||||
torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int64, device="cuda")
|
||||
)
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, model.config.vocab_size, device)
|
||||
# Auto-discover optimal batch size if not manually set
|
||||
if auto_batch_size and device_batch_size is None:
|
||||
from nanochat.auto_batch_size import find_optimal_device_batch_size
|
||||
device_batch_size = find_optimal_device_batch_size(
|
||||
model=model,
|
||||
max_seq_len=2048, # Use fixed max length for discovery
|
||||
total_batch_size=target_examples_per_step,
|
||||
ddp_world_size=ddp_world_size,
|
||||
data_sample_fn=batch_sample_fn,
|
||||
override=device_batch_size,
|
||||
safety_margin=batch_size_margin,
|
||||
enable_cache=batch_size_cache,
|
||||
ddp_rank=ddp_rank,
|
||||
)
|
||||
elif device_batch_size is None:
|
||||
device_batch_size = 4
|
||||
print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
|
||||
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
|
|
|
|||
|
|
@ -34,7 +34,10 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
|
||||
auto_batch_size = True # whether to auto-discover optimal batch size
|
||||
batch_size_margin = 0.85 # safety margin for auto-discovered batch size
|
||||
batch_size_cache = True # whether to cache auto-discovered batch size
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
|
|
@ -61,23 +64,39 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
|
|||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
|
||||
# Create a batch sample function for auto-discovery
|
||||
def batch_sample_fn(batch_size):
|
||||
"""Generate a sample batch for auto-discovery of optimal batch size."""
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
return (
|
||||
torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
|
||||
torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
|
||||
)
|
||||
|
||||
# Auto-discover optimal batch size if not manually set
|
||||
if auto_batch_size and device_batch_size is None:
|
||||
from nanochat.auto_batch_size import find_optimal_device_batch_size
|
||||
device_batch_size = find_optimal_device_batch_size(
|
||||
model=model,
|
||||
max_seq_len=max_seq_len,
|
||||
total_batch_size=total_batch_size,
|
||||
ddp_world_size=ddp_world_size,
|
||||
data_sample_fn=batch_sample_fn,
|
||||
override=device_batch_size,
|
||||
safety_margin=batch_size_margin,
|
||||
enable_cache=batch_size_cache,
|
||||
ddp_rank=ddp_rank,
|
||||
)
|
||||
elif device_batch_size is None:
|
||||
device_batch_size = 8
|
||||
print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
|
||||
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
|
||||
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user