diff --git a/scripts/base_train.py b/scripts/base_train.py
index b691ed4..f3902d8 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -35,7 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
-device_batch_size = 32 # per-device batch size (set to not OOM)
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 total_batch_size = 524288 # total desired batch size, in #tokens
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index a642126..f16fcdf 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -36,7 +36,11 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
 step = None # step to load the model from (base model or midtrained model)
 # compute/precision
 dtype = "bfloat16"
-device_batch_size = 4 # max to avoid OOM
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 # optimization
 num_epochs = 1
 max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@@ -70,11 +74,6 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 orig_model = model # original, uncompiled model
 # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
-# Validation: Log compilation status
-if hasattr(model, '_orig_mod'):
-    print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
-else:
-    print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
 engine = Engine(model, tokenizer) # will be used for inline model evaluation only
 
 # -----------------------------------------------------------------------------
@@ -161,16 +160,10 @@ def get_lr_multiplier(it):
     lrm = 1.0 - it / num_iterations
     return lrm
 
-# Validation: Performance tracking variables
-import time
-step_times = []
-step_tokens = []
-
 # Go!
 step = 0
 train_iter = iter(train_loader)
 for step in range(num_iterations):
-    step_start_time = time.time()
     last_step = step == num_iterations - 1
 
     # evaluate the validation loss
@@ -217,9 +210,6 @@ for step in range(num_iterations):
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
         train_inputs, train_targets = next(train_iter)
-        # Validation: Log batch shapes for first 3 steps to verify fixed padding
-        if step < 3 and micro_step == 0:
-            print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
         train_loss = loss.detach() # for logging
@@ -240,33 +230,15 @@ for step in range(num_iterations):
     opt.step()
     model.zero_grad(set_to_none=True)
 
-    # Validation: Calculate performance metrics
-    step_end_time = time.time()
-    step_time = step_end_time - step_start_time
-
     # logging
     train_loss_item = train_loss.item()
     num_tokens_item = num_tokens.item()
-
-    # Validation: Track performance (skip first 5 warmup iterations)
-    if step >= 5:
-        step_times.append(step_time)
-        step_tokens.append(num_tokens_item)
-
-    # Validation: Calculate and log performance metrics every 10 steps (after warmup)
-    if step >= 5 and step % 10 == 0:
-        avg_step_time = sum(step_times[-10:]) / len(step_times[-10:]) if len(step_times) >= 10 else sum(step_times) / len(step_times)
-        avg_tokens = sum(step_tokens[-10:]) / len(step_tokens[-10:]) if len(step_tokens) >= 10 else sum(step_tokens) / len(step_tokens)
-        tokens_per_sec = avg_tokens / avg_step_time if avg_step_time > 0 else 0
-        print0(f"[VALIDATION] Avg time/step: {avg_step_time:.3f}s | Tokens/sec: {tokens_per_sec:.1f}")
-
-    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,} | time: {step_time:.3f}s")
+    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
     wandb_run.log({
         "step": step,
         "lrm": lrm,
         "train_loss": train_loss_item,
         "num_tokens": num_tokens_item,
-        "step_time": step_time,
     })
 
     step += 1
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 90ab954..2c23ed4 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -34,7 +34,11 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
 max_seq_len = 2048
-device_batch_size = 32
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 unembedding_lr = 0.004
 embedding_lr = 0.2
 matrix_lr = 0.02
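
Note for reviewers: the hunks above only add the configuration knobs; the discovery routine they point at is not part of this diff. As a minimal sketch of what a routine consistent with these flags could look like, assuming exponential OOM probing against a worst-case batch and PyTorch's torch.cuda.OutOfMemoryError, the code below is illustrative only: discover_device_batch_size, make_batch, and the cache path are hypothetical names, not nanochat API.

import os
import json
import torch

def discover_device_batch_size(model, make_batch, margin=0.85,
                               cache_path=None, start=1, limit=4096):
    # make_batch(bs) is assumed to build a worst-case (inputs, targets)
    # pair at batch size bs (e.g. rows of full max_seq_len).
    # Reuse a previously discovered value if caching is enabled.
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path) as f:
            return json.load(f)["device_batch_size"]
    max_ok = 0
    bs = start
    while bs <= limit:
        try:
            inputs, targets = make_batch(bs)
            loss = model(inputs, targets)  # model returns loss given targets
            loss.backward()                # backward is typically the memory peak
            max_ok = bs
            bs *= 2                        # exponential probe upward
        except torch.cuda.OutOfMemoryError:
            break                          # first size that no longer fits
        finally:
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
    # Apply the safety margin, e.g. 85% of the largest size that fit.
    discovered = max(1, int(max_ok * margin))
    if cache_path is not None:
        with open(cache_path, "w") as f:
            json.dump({"device_batch_size": discovered}, f)
    return discovered

# Hypothetical wiring after the model is built, mirroring the new flags:
# if device_batch_size is None and auto_batch_size:
#     cache = ".cache/device_batch_size.json" if batch_size_cache else None
#     device_batch_size = discover_device_batch_size(
#         model, make_batch, margin=batch_size_margin, cache_path=cache)

Two caveats worth flagging in review: doubling overshoots by up to 2x, so a real implementation might binary-search between the last size that fit and the first that failed before applying batch_size_margin; and in multi-GPU runs the ranks need to agree on the result, e.g. by probing on rank 0 and broadcasting it, since the discovered value still has to divide evenly into the total_batch_size gradient-accumulation arithmetic.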