feat(train): add batch sample functions for memory testing in auto-discovery

Add create_batch_sample_fn closures to base_train.py, mid_train.py, and chat_sft.py that generate realistic test batches matching training data formats for accurate memory profiling during auto batch-size discovery.
This commit is contained in:
Artemis Git Integration 2025-11-05 16:48:55 +00:00
parent 38801c983d
commit a8aad26041
3 changed files with 39 additions and 15 deletions

View File

@ -35,11 +35,7 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
# Auto batch size discovery
auto_batch_size = True # Enable/disable auto-discovery
batch_size_margin = 0.85 # Safety margin (85% of max)
batch_size_cache = False # Enable result caching
device_batch_size = None # If None, auto-discover; if set, use that value
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
@ -102,6 +98,17 @@ with torch.device("meta"):
model = GPT(model_config)
model.to_empty(device="cuda")
model.init_weights()
# Create batch sample function for auto-discovery
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    """Build a sampler that fabricates random-token (inputs, targets) batches.

    Used by auto batch-size discovery to probe memory at candidate batch
    shapes without touching real data. Both returned tensors have shape
    (batch_size, seq_len) with token ids drawn uniformly from [0, vocab_size).
    """
    def sample_fn(batch_size, seq_len):
        dims = (batch_size, seq_len)
        fake_inputs = torch.randint(low=0, high=vocab_size, size=dims, device=device)
        fake_targets = torch.randint(low=0, high=vocab_size, size=dims, device=device)
        return fake_inputs, fake_targets
    return sample_fn
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())

View File

@ -36,11 +36,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
step = None # step to load the model from (base model or midtrained model)
# compute/precision
dtype = "bfloat16"
# Auto batch size discovery
auto_batch_size = True # Enable/disable auto-discovery
batch_size_margin = 0.85 # Safety margin (85% of max)
batch_size_cache = False # Enable result caching
device_batch_size = None # If None, auto-discover; if set, use that value
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@ -72,6 +68,19 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
# Create batch sample function for auto-discovery
max_seq_len = model.config.sequence_len
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    """Build a sampler of synthetic batches for auto batch-size discovery.

    Every batch is max_seq_len tokens long — the worst case for the
    variable-length sequences seen in this training phase — regardless of
    the seq_len argument the discovery loop passes in. Targets are filled
    with -1 (the ignore index), matching the real data format.
    """
    def sample_fn(batch_size, seq_len):
        # seq_len is intentionally ignored: probe the worst-case length.
        dims = (batch_size, max_seq_len)
        fake_inputs = torch.randint(low=0, high=vocab_size, size=dims, device=device)
        fake_targets = torch.full(size=dims, fill_value=-1, dtype=torch.long, device=device)
        return fake_inputs, fake_targets
    return sample_fn
batch_sample_fn = create_batch_sample_fn(max_seq_len, model.config.vocab_size, device)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only

View File

@ -34,11 +34,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
max_seq_len = 2048
# Auto batch size discovery
auto_batch_size = True # Enable/disable auto-discovery
batch_size_margin = 0.85 # Safety margin (85% of max)
batch_size_cache = False # Enable result caching
device_batch_size = None # If None, auto-discover; if set, use that value
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
@ -70,6 +66,18 @@ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
# Create batch sample function for auto-discovery
vocab_size = tokenizer.get_vocab_size()
def create_batch_sample_fn(max_seq_len, vocab_size, device):
    """Return a closure that generates throwaway training batches.

    Auto batch-size discovery calls the closure with candidate
    (batch_size, seq_len) pairs to measure peak memory; token ids are
    sampled uniformly from [0, vocab_size) for both inputs and targets.
    """
    def sample_fn(batch_size, seq_len):
        def rand_tokens():
            return torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        return rand_tokens(), rand_tokens()
    return sample_fn
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank