mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Merge pull request #16 from Dianababaei/feat/auto-batch-size-discovery-config
Refactor training scripts: update base model training, chat SFT implementation, and intermediate stage parameters
This commit is contained in:
commit
38801c983d
|
|
@ -35,7 +35,11 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
|
|||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,11 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
|
|
@ -70,11 +74,6 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
|
|||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
# Validation: Log compilation status
|
||||
if hasattr(model, '_orig_mod'):
|
||||
print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
|
||||
else:
|
||||
print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -161,16 +160,10 @@ def get_lr_multiplier(it):
|
|||
lrm = 1.0 - it / num_iterations
|
||||
return lrm
|
||||
|
||||
# Validation: Performance tracking variables
|
||||
import time
|
||||
step_times = []
|
||||
step_tokens = []
|
||||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
for step in range(num_iterations):
|
||||
step_start_time = time.time()
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
|
|
@ -217,9 +210,6 @@ for step in range(num_iterations):
|
|||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
# Validation: Log batch shapes for first 3 steps to verify fixed padding
|
||||
if step < 3 and micro_step == 0:
|
||||
print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
|
|
@ -240,33 +230,15 @@ for step in range(num_iterations):
|
|||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
# Validation: Calculate performance metrics
|
||||
step_end_time = time.time()
|
||||
step_time = step_end_time - step_start_time
|
||||
|
||||
# logging
|
||||
train_loss_item = train_loss.item()
|
||||
num_tokens_item = num_tokens.item()
|
||||
|
||||
# Validation: Track performance (skip first 5 warmup iterations)
|
||||
if step >= 5:
|
||||
step_times.append(step_time)
|
||||
step_tokens.append(num_tokens_item)
|
||||
|
||||
# Validation: Calculate and log performance metrics every 10 steps (after warmup)
|
||||
if step >= 5 and step % 10 == 0:
|
||||
avg_step_time = sum(step_times[-10:]) / len(step_times[-10:]) if len(step_times) >= 10 else sum(step_times) / len(step_times)
|
||||
avg_tokens = sum(step_tokens[-10:]) / len(step_tokens[-10:]) if len(step_tokens) >= 10 else sum(step_tokens) / len(step_tokens)
|
||||
tokens_per_sec = avg_tokens / avg_step_time if avg_step_time > 0 else 0
|
||||
print0(f"[VALIDATION] Avg time/step: {avg_step_time:.3f}s | Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,} | time: {step_time:.3f}s")
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
"step_time": step_time,
|
||||
})
|
||||
step += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,11 @@ model_tag = None # model tag to load the model from (base model or midtrained mo
|
|||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
# Auto batch size discovery
|
||||
auto_batch_size = True # Enable/disable auto-discovery
|
||||
batch_size_margin = 0.85 # Safety margin (85% of max)
|
||||
batch_size_cache = False # Enable result caching
|
||||
device_batch_size = None # If None, auto-discover; if set, use that value
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user