"""
|
|
Train model. Run as:
|
|
|
|
python base_train.py
|
|
|
|
or distributed as:
|
|
|
|
torchrun --nproc_per_node=8 base_train.py
|
|
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # reduce CUDA allocator fragmentation
import time
import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.profiling import ProfilingManager
from scripts.base_eval import evaluate_model
print_banner()

# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Profiling configuration (output files will be placed in ~/.cache/nanochat/profile_traces/<timestamp>/ by default)
# Master switch: enables both PyTorch profiler (traces) and CUDA memory profiler (snapshots)
enable_profiling = False
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# Get base directory early for profiling setup
base_dir = get_base_dir()

# Initialize profiling manager
profiler = ProfilingManager(
    base_dir=base_dir,
    ddp_local_rank=ddp_local_rank,
    master_process=master_process,
    enable_profiling=enable_profiling,
    print_fn=print0,
)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)

# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # 1:1 MQA ratio
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
# Start profiling model initialization
if enable_profiling:
    profiler.start_cuda_memory_recording("model_init")
    profiler.start_torch_profiler("model_init", warmup=0, active=1)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device="cuda") # allocate real (uninitialized) storage on the GPU
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: think through dynamic=True vs. False
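# Note on the initialization pattern above: the meta device builds the module
# structure without allocating storage, to_empty() then creates uninitialized
# tensors directly on the GPU, and init_weights() fills them in, avoiding a
# throwaway CPU-side initialization. dynamic=False is safe here because every
# training micro-batch has the same shape [device_batch_size, max_seq_len].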
# Complete profiling model initialization
if enable_profiling:
    profiler.step_torch_profiler()
    profiler.dump_cuda_memory_snapshot("model_init")
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
x, y = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers

# Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
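
# Shape of the schedule above under the defaults (warmup_ratio=0.0,
# warmdown_ratio=0.2, final_lr_frac=0.0): the multiplier is 1.0 for the first
# 80% of training, then decays linearly to 0.0 over the last 20%. E.g. with
# num_iterations=10,000: lrm(8000)=1.0, lrm(9000)=0.5, lrm(10000)=0.0.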

# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
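
# This ramps Muon's momentum linearly from 0.85 at step 0 to 0.95 at step 300
# and holds it there, e.g. get_muon_momentum(150) = 0.5*0.85 + 0.5*0.95 = 0.90.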

# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
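        # Worked example with the defaults on 8 GPUs: eval_tokens = 20*524288
        # = 10,485,760 and each eval step covers 32*2048*8 = 524,288 tokens,
        # so eval_steps = 20.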
        with autocast_ctx:
            # Pass profiler for first evaluation if profiling is enabled
            prof_arg = profiler if (enable_profiling and step == 0) else None
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes, profiler=prof_arg)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    if last_step or (step > 0 and step % core_metric_every == 0):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer) # uncompiled model, per the comment above
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step:
        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
            {
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
            }
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    t0 = time.time()

    # Profile micro-steps if enabled (first training step only; 1 warmup + 10 active micro-steps)
    profile_ctx = None
    if enable_profiling and step == 0:
        profile_ctx = profiler.profile_section("training_microsteps", warmup=1, active=10)
        profile_ctx.__enter__()

    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
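        # Why the division by grad_accum_steps above is correct: .backward()
        # accumulates (sums) gradients across micro-steps, so after K
        # micro-steps the gradient is (1/K) * sum_i g_i, i.e. the mean
        # gradient over the full batch, as if it were processed in one pass.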
        if profile_ctx is not None:
            profile_ctx.step()

    # Close profiling context if it was opened
    if profile_ctx is not None:
        profile_ctx.__exit__(None, None, None)

    # Start optimizer step profiling if enabled
    optimizer_profile_ctx = None
    if enable_profiling and step == 0:
        optimizer_profile_ctx = profiler.profile_section("optimizer_step", warmup=0, active=1)
        optimizer_profile_ctx.__enter__()

    # gradient clipping (TODO: possibly experiment with this)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
    # step the optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()

    # Step and close optimizer profiling if active
    if optimizer_profile_ctx is not None:
        optimizer_profile_ctx.step()
        optimizer_profile_ctx.__exit__(None, None, None)
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
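    # Debiasing note: with smooth_train_loss initialized to 0, a constant loss
    # x would give a raw EMA of (1 - ema_beta**(step+1)) * x after step+1
    # updates; dividing by that factor recovers x. At step 0 the debiased
    # value equals the first measured loss exactly.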
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt) # dt spans all grad accum micro-steps of this iteration
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # peak bfloat16 FLOPS of one H100 SXM (without 2:4 sparsity)
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
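    # Illustrative MFU arithmetic (made-up throughput): if a run sustained
    # flops_per_sec = 3.5e15 on 8 GPUs, then mfu = 100 * 3.5e15 / (989e12 * 8)
    # ≈ 44%.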
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 100 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": warmup_ratio,
        "warmdown_ratio": warmdown_ratio,
        "final_lr_frac": final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results["core_metric"],
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()