mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-31 00:55:18 +00:00
Merge pull request #17 from Dianababaei/feat/train-batch-sample-functions-memory-testing
Add batch sampling function factory for auto-discovery across training scripts
This commit is contained in:
commit
fa14cba28e
|
|
@ -35,11 +35,7 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
|
|||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
|
|
@ -102,6 +98,17 @@ with torch.device("meta"):
|
|||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.init_weights()
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
|
|
|
|||
|
|
@ -36,11 +36,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
dtype = "bfloat16"
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
|
|
@ -72,6 +68,19 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
|
|||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
max_seq_len = model.config.sequence_len
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
# Use max_seq_len (worst case for variable-length sequences)
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, max_seq_len), device=device)
|
||||
targets = torch.full((batch_size, max_seq_len), -1, dtype=torch.long, device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, model.config.vocab_size, device)
|
||||
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
|
|
|||
|
|
@ -34,11 +34,7 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
max_seq_len = 2048
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
|
|
@ -70,6 +66,18 @@ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
|||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
|
||||
# Create batch sample function for auto-discovery
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
def create_batch_sample_fn(max_seq_len, vocab_size, device):
|
||||
def sample_fn(batch_size, seq_len):
|
||||
inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
return inputs, targets
|
||||
return sample_fn
|
||||
|
||||
batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
|
||||
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user