diff --git a/scripts/base_train.py b/scripts/base_train.py
index 1e025de..6e34911 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -35,8 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
-device_batch_size = 32 # per-device batch size (set to not OOM)
+device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
 total_batch_size = 524288 # total desired batch size, in #tokens
+auto_batch_size = True # whether to auto-discover optimal batch size
+batch_size_margin = 0.85 # safety margin for auto-discovered batch size
+batch_size_cache = True # whether to cache auto-discovered batch size
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
 weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
@@ -81,15 +84,6 @@ print0(f"model_dim: {model_dim}")
 print0(f"num_heads: {num_heads}")
 print0(f"num_kv_heads: {num_kv_heads}")
 
-# Optimizer / data / training length related hyperparameters
-# figure out the needed gradient accumulation to reach the desired total batch size
-tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
-world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
-assert total_batch_size % world_tokens_per_fwdbwd == 0
-grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
-print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
-print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
-print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 # -----------------------------------------------------------------------------
 # Initialize the Model
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
@@ -99,15 +93,41 @@ with torch.device("meta"):
 model.to_empty(device="cuda")
 model.init_weights()
 
-# Create batch sample function for auto-discovery
-def create_batch_sample_fn(max_seq_len, vocab_size, device):
-    def sample_fn(batch_size, seq_len):
-        inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
-        targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
-        return inputs, targets
-    return sample_fn
+# Create a batch sample function for auto-discovery
+def batch_sample_fn(batch_size):
+    """Generate a sample batch for auto-discovery of optimal batch size."""
+    return (
+        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
+        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
+    )
 
-batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
+# Auto-discover optimal batch size if not manually set
+if auto_batch_size and device_batch_size is None:
+    from nanochat.auto_batch_size import find_optimal_device_batch_size
+    device_batch_size = find_optimal_device_batch_size(
+        model=model,
+        max_seq_len=max_seq_len,
+        total_batch_size=total_batch_size,
+        ddp_world_size=ddp_world_size,
+        data_sample_fn=batch_sample_fn,
+        override=device_batch_size,
+        safety_margin=batch_size_margin,
+        enable_cache=batch_size_cache,
+        ddp_rank=ddp_rank,
+    )
+elif device_batch_size is None:
+    device_batch_size = 8
+    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
+
+# Optimizer / data / training length related hyperparameters
+# figure out the needed gradient accumulation to reach the desired total batch size
+tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
+world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
+assert total_batch_size % world_tokens_per_fwdbwd == 0
+grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
+print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
+print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
+print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 
 orig_model = model # original, uncompiled model, for saving raw model state_dict
 model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index f0a6294..f46fe2f 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -36,7 +36,10 @@ model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 # compute/precision
 dtype = "bfloat16"
-device_batch_size = 4 # max to avoid OOM
+device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
+auto_batch_size = True # whether to auto-discover optimal batch size
+batch_size_margin = 0.85 # safety margin for auto-discovered batch size
+batch_size_cache = True # whether to cache auto-discovered batch size
 # optimization
 num_epochs = 1
 max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@@ -69,17 +72,35 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 # Load the model and tokenizer
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 
-# Create batch sample function for auto-discovery
-max_seq_len = model.config.sequence_len
-def create_batch_sample_fn(max_seq_len, vocab_size, device):
-    def sample_fn(batch_size, seq_len):
-        # Use max_seq_len (worst case for variable-length sequences)
-        inputs = torch.randint(0, vocab_size, (batch_size, max_seq_len), device=device)
-        targets = torch.full((batch_size, max_seq_len), -1, dtype=torch.long, device=device)
-        return inputs, targets
-    return sample_fn
+# Create a batch sample function for auto-discovery
+def batch_sample_fn(batch_size):
+    """Generate a sample batch for auto-discovery of optimal batch size."""
+    vocab_size = tokenizer.get_vocab_size()
+    # Chat SFT trains on variable-length sequences, so probe with the worst case:
+    # a fully packed batch at a fixed maximum length
+    max_len = 2048 # fixed max length for discovery
+    return (
+        torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int32, device="cuda"),
+        torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int64, device="cuda")
+    )
 
-batch_sample_fn = create_batch_sample_fn(max_seq_len, model.config.vocab_size, device)
+# Auto-discover optimal batch size if not manually set
+if auto_batch_size and device_batch_size is None:
+    from nanochat.auto_batch_size import find_optimal_device_batch_size
+    device_batch_size = find_optimal_device_batch_size(
+        model=model,
+        max_seq_len=2048, # use fixed max length for discovery
+        total_batch_size=target_examples_per_step,
+        ddp_world_size=ddp_world_size,
+        data_sample_fn=batch_sample_fn,
+        override=device_batch_size,
+        safety_margin=batch_size_margin,
+        enable_cache=batch_size_cache,
+        ddp_rank=ddp_rank,
+    )
+elif device_batch_size is None:
+    device_batch_size = 4
+    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
 
 orig_model = model # original, uncompiled model
 # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 86a4819..435ce7f 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -34,7 +34,10 @@ model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
 max_seq_len = 2048
-device_batch_size = 32
+device_batch_size = None # per-device batch size (set to not OOM), None = auto-discover
+auto_batch_size = True # whether to auto-discover optimal batch size
+batch_size_margin = 0.85 # safety margin for auto-discovered batch size
+batch_size_cache = True # whether to cache auto-discovered batch size
 unembedding_lr = 0.004
 embedding_lr = 0.2
 matrix_lr = 0.02
@@ -61,23 +64,39 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
+
+# Create a batch sample function for auto-discovery
+def batch_sample_fn(batch_size):
+    """Generate a sample batch for auto-discovery of optimal batch size."""
+    vocab_size = tokenizer.get_vocab_size()
+    return (
+        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device="cuda"),
+        torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device="cuda")
+    )
+
+# Auto-discover optimal batch size if not manually set
+if auto_batch_size and device_batch_size is None:
+    from nanochat.auto_batch_size import find_optimal_device_batch_size
+    device_batch_size = find_optimal_device_batch_size(
+        model=model,
+        max_seq_len=max_seq_len,
+        total_batch_size=total_batch_size,
+        ddp_world_size=ddp_world_size,
+        data_sample_fn=batch_sample_fn,
+        override=device_batch_size,
+        safety_margin=batch_size_margin,
+        enable_cache=batch_size_cache,
+        ddp_rank=ddp_rank,
+    )
+elif device_batch_size is None:
+    device_batch_size = 8
+    print0(f"Auto-discovery disabled, using default device_batch_size={device_batch_size}")
+
 pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
 
 orig_model = model
 model = torch.compile(model, dynamic=False)
-
-# Create batch sample function for auto-discovery
-vocab_size = tokenizer.get_vocab_size()
-def create_batch_sample_fn(max_seq_len, vocab_size, device):
-    def sample_fn(batch_size, seq_len):
-        inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
-        targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
-        return inputs, targets
-    return sample_fn
-
-batch_sample_fn = create_batch_sample_fn(max_seq_len, vocab_size, device)
-
 depth = model.config.n_layer
 num_flops_per_token = model.estimate_flops()
 tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
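
All three scripts import `find_optimal_device_batch_size` from `nanochat/auto_batch_size.py`, which is not included in this diff. For orientation, below is a minimal sketch of what such a helper could look like, inferred only from the call sites above: the keyword arguments (`model`, `max_seq_len`, `total_batch_size`, `ddp_world_size`, `data_sample_fn`, `override`, `safety_margin`, `enable_cache`, `ddp_rank`) come from the calls in the diff, while the doubling OOM probe, cache location, divisibility shrink, and broadcast logic are assumptions, not the actual implementation.

```python
# Hypothetical sketch of nanochat/auto_batch_size.py, inferred from the call sites
# in the diff above. Only the keyword arguments are taken from that code; everything
# else (probe loop, cache path, broadcast, divisibility rule) is an assumption.
import json
import os

import torch
import torch.distributed as dist

CACHE_PATH = os.path.expanduser("~/.cache/nanochat/auto_batch_size.json")  # assumed location

def find_optimal_device_batch_size(
    model,
    max_seq_len,
    total_batch_size,
    ddp_world_size,
    data_sample_fn,
    override=None,
    safety_margin=0.85,
    enable_cache=True,
    ddp_rank=0,
):
    """Return a per-device batch size that survives a fwd/bwd pass at max_seq_len."""
    if override is not None:
        return override  # a manually supplied batch size always wins

    # Cache on model size + sequence length so repeated runs skip the probe.
    key = f"{sum(p.numel() for p in model.parameters())}-{max_seq_len}"
    if enable_cache and os.path.exists(CACHE_PATH):
        with open(CACHE_PATH) as f:
            cached = json.load(f).get(key)
        if cached is not None:
            return cached

    # Rank 0 doubles the candidate batch size until a forward/backward pass OOMs.
    found = 1
    if ddp_rank == 0:
        candidate = 1
        while True:
            try:
                inputs, targets = data_sample_fn(candidate)
                loss = model(inputs, targets)  # assumes the model returns the loss when targets are given
                loss.backward()
                model.zero_grad(set_to_none=True)
                found = candidate
                candidate *= 2
            except torch.cuda.OutOfMemoryError:
                break
            finally:
                torch.cuda.empty_cache()
        found = max(1, int(found * safety_margin))  # back off from the OOM edge

    # All ranks must agree on the result.
    if ddp_world_size > 1 and dist.is_initialized():
        t = torch.tensor([found], device="cuda")
        dist.broadcast(t, src=0)
        found = int(t.item())

    # Shrink until the batch size divides the global budget evenly, mirroring the
    # `assert total_batch_size % world_tokens_per_fwdbwd == 0` in base_train.py.
    # (chat_sft.py counts examples rather than tokens, so the real rule may differ.)
    while found > 1 and total_batch_size % (found * max_seq_len * ddp_world_size) != 0:
        found -= 1

    if enable_cache and ddp_rank == 0:
        os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
        cached = {}
        if os.path.exists(CACHE_PATH):
            with open(CACHE_PATH) as f:
                cached = json.load(f)
        cached[key] = found
        with open(CACHE_PATH, "w") as f:
            json.dump(cached, f)

    return found
```

The `safety_margin` back-off and the divisibility shrink are design choices suggested by the `batch_size_margin = 0.85` knob and the gradient-accumulation assert in the diff; the real module may select candidates or handle caching quite differently.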
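As a concrete check of the gradient-accumulation arithmetic that base_train.py recomputes after discovery, here is a worked example assuming an 8-GPU run (the `ddp_world_size=8` value is an assumption; `total_batch_size=524288` and `max_seq_len=2048` are the script defaults):

```python
# Worked example of the grad-accum arithmetic (assumes ddp_world_size=8).
# Suppose auto-discovery settles on device_batch_size=16:
tokens_per_fwdbwd = 16 * 2048          # 32,768 tokens per rank per micro-batch
world_tokens_per_fwdbwd = 32768 * 8    # 262,144 tokens per micro-batch across ranks
grad_accum_steps = 524288 // 262144    # 2 micro-batches accumulated per optimizer step
```

This is also why the discovered value cannot be arbitrary: a batch size of 24 would give 24 * 2048 * 8 = 393,216 tokens per micro-batch, which does not divide 524,288, so the `assert total_batch_size % world_tokens_per_fwdbwd == 0` in base_train.py would fail.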