Merge pull request #16 from Dianababaei/feat/auto-batch-size-discovery-config

Refactor training scripts: update base model training, chat SFT implementation, and intermediate stage parameters
commit 38801c983d
Dianababaei 2025-11-05 20:18:26 +03:30, committed by GitHub
No known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
3 changed files with 16 additions and 36 deletions

View File

@@ -35,7 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
-device_batch_size = 32 # per-device batch size (set to not OOM)
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 total_batch_size = 524288 # total desired batch size, in #tokens
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
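
The discovery routine itself is not part of this diff; the commit only adds the configuration surface. A minimal sketch of what discovery could look like, assuming the semantics the comments describe (grow the per-device batch until OOM, then keep 85% of the working maximum); find_device_batch_size and make_batch are hypothetical names, not from the repository:

import torch

def find_device_batch_size(model, make_batch, margin=0.85, start=1, limit=4096):
    # Double the per-device batch size until a forward/backward pass runs
    # out of GPU memory, then back off by the configured safety margin.
    best = start
    bs = start
    while bs <= limit:
        try:
            inputs, targets = make_batch(bs)  # synthetic batch of size bs
            loss = model(inputs, targets)
            loss.backward()
            model.zero_grad(set_to_none=True)
            best = bs
            bs *= 2
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            break
    return max(1, int(best * margin))

Under this reading, device_batch_size = None triggers one such probe at startup, while an explicit value bypasses discovery entirely, matching the "If None, auto-discover; if set, use that value" comment.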

View File

@@ -36,7 +36,11 @@ model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 # compute/precision
 dtype = "bfloat16"
-device_batch_size = 4 # max to avoid OOM
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 # optimization
 num_epochs = 1
 max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
@@ -70,11 +74,6 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 orig_model = model # original, uncompiled model
 # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
-# Validation: Log compilation status
-if hasattr(model, '_orig_mod'):
-    print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
-else:
-    print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
 engine = Engine(model, tokenizer) # will be used for inline model evaluation only
 # -----------------------------------------------------------------------------
@@ -161,16 +160,10 @@ def get_lr_multiplier(it):
     lrm = 1.0 - it / num_iterations
     return lrm
-# Validation: Performance tracking variables
-import time
-step_times = []
-step_tokens = []
 # Go!
 step = 0
 train_iter = iter(train_loader)
 for step in range(num_iterations):
-    step_start_time = time.time()
     last_step = step == num_iterations - 1
     # evaluate the validation loss
@@ -217,9 +210,6 @@ for step in range(num_iterations):
     num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
     for micro_step in range(grad_accum_steps):
         train_inputs, train_targets = next(train_iter)
-        # Validation: Log batch shapes for first 3 steps to verify fixed padding
-        if step < 3 and micro_step == 0:
-            print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
         with autocast_ctx:
             loss = model(train_inputs, train_targets)
             train_loss = loss.detach() # for logging
@@ -240,33 +230,15 @@ for step in range(num_iterations):
     opt.step()
     model.zero_grad(set_to_none=True)
-    # Validation: Calculate performance metrics
-    step_end_time = time.time()
-    step_time = step_end_time - step_start_time
     # logging
     train_loss_item = train_loss.item()
    num_tokens_item = num_tokens.item()
-    # Validation: Track performance (skip first 5 warmup iterations)
-    if step >= 5:
-        step_times.append(step_time)
-        step_tokens.append(num_tokens_item)
-    # Validation: Calculate and log performance metrics every 10 steps (after warmup)
-    if step >= 5 and step % 10 == 0:
-        avg_step_time = sum(step_times[-10:]) / len(step_times[-10:]) if len(step_times) >= 10 else sum(step_times) / len(step_times)
-        avg_tokens = sum(step_tokens[-10:]) / len(step_tokens[-10:]) if len(step_tokens) >= 10 else sum(step_tokens) / len(step_tokens)
-        tokens_per_sec = avg_tokens / avg_step_time if avg_step_time > 0 else 0
-        print0(f"[VALIDATION] Avg time/step: {avg_step_time:.3f}s | Tokens/sec: {tokens_per_sec:.1f}")
-    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,} | time: {step_time:.3f}s")
+    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
     wandb_run.log({
         "step": step,
         "lrm": lrm,
         "train_loss": train_loss_item,
         "num_tokens": num_tokens_item,
-        "step_time": step_time,
     })
     step += 1
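
A side effect of discovering device_batch_size at runtime: total_batch_size stays fixed at 524,288 tokens (first file), so grad_accum_steps, which drives the micro-step loop above, has to be derived from whatever value discovery returns. A sketch of that accounting, where world_size = 8 and the discovered value 24 are illustrative assumptions:

total_batch_size = 524288   # desired tokens per optimizer step, from the config above
max_seq_len = 2048          # from the config above
world_size = 8              # assumption: number of data-parallel ranks
device_batch_size = 24      # example: a value returned by auto-discovery

# round down to a power of two so the per-step token count divides the budget evenly
while device_batch_size & (device_batch_size - 1):
    device_batch_size -= 1   # 24 -> 16

tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size  # 16 * 2048 * 8 = 262144
grad_accum_steps = total_batch_size // tokens_per_fwdbwd          # 524288 // 262144 = 2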

View File

@@ -34,7 +34,11 @@ model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
 max_seq_len = 2048
-device_batch_size = 32
+# Auto batch size discovery
+auto_batch_size = True # Enable/disable auto-discovery
+batch_size_margin = 0.85 # Safety margin (85% of max)
+batch_size_cache = False # Enable result caching
+device_batch_size = None # If None, auto-discover; if set, use that value
 unembedding_lr = 0.004
 embedding_lr = 0.2
 matrix_lr = 0.02
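
batch_size_cache = False suggests the probe can be skipped on repeat runs by persisting its result; how that persistence works is not shown in this diff. A minimal sketch, where cached_batch_size, the JSON file path, and the cache-key layout are all hypothetical:

import json, os
import torch

def cached_batch_size(key, discover, path="batch_size_cache.json"):
    # Return a previously discovered batch size for this key, or run
    # discovery once and persist the result for later runs.
    cache = {}
    if os.path.exists(path):
        with open(path) as f:
            cache = json.load(f)
    if key not in cache:
        cache[key] = discover()
        with open(path, "w") as f:
            json.dump(cache, f, indent=2)
    return cache[key]

# key on everything that affects memory use: device, dtype, sequence length
key = f"{torch.cuda.get_device_name(0)}|bfloat16|2048"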