mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 01:40:17 +00:00
Keep speedrun changes functional
This commit is contained in:
parent
d6a169b329
commit
4680e799fb
|
|
@ -70,9 +70,19 @@ echo "Waiting for dataset download to complete..."
|
|||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# d22 Muon+/row-eq + hashed bigram recipe.
|
||||
# scripts/base_train defaults are the submission defaults: fixed 11,600
|
||||
# optimizer steps, eval every 250, and one in-training CORE pass halfway through.
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
|
||||
--run=$WANDB_RUN \
|
||||
--fp8 \
|
||||
--depth=22 \
|
||||
--num-iterations=11600 \
|
||||
--target-param-data-ratio=11 \
|
||||
--total-batch-size=524288 \
|
||||
--scalar-lr=0.3 \
|
||||
--bigram-embed-factor=5 \
|
||||
--muon-plus \
|
||||
--muon-eq=row \
|
||||
--core-metric-every=5800 \
|
||||
--sample-every=-1
|
||||
# evaluate the model: CORE metric, BPB on train/val, and draw samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ or distributed as:
|
|||
torchrun --nproc_per_node=8 -m scripts.base_train
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -41,39 +41,32 @@ print_banner()
|
|||
parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training")
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model")
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual")
|
||||
parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor")
|
||||
parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr")
|
||||
parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling")
|
||||
parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
|
||||
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization")
|
||||
parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization")
|
||||
parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization")
|
||||
parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
|
||||
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
|
||||
|
|
@ -81,24 +74,16 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra
|
|||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs")
|
||||
parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0")
|
||||
parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
if args.train_log_every <= 0:
|
||||
parser.error("--train-log-every must be positive")
|
||||
if args.bigram_embed_factor < 0:
|
||||
parser.error("--bigram-embed-factor must be non-negative")
|
||||
if args.bigram_lambda_lr < 0:
|
||||
parser.error("--bigram-lambda-lr must be non-negative")
|
||||
if args.bigram_embedding_lr_mult <= 0:
|
||||
parser.error("--bigram-embedding-lr-mult must be positive")
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compute init and wandb logging
|
||||
|
||||
|
|
@ -158,7 +143,6 @@ def build_model_meta(depth):
|
|||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
bigram_embed_factor=args.bigram_embed_factor,
|
||||
bigram_lambda_init=args.bigram_lambda_init,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
@ -265,10 +249,7 @@ def disable_fp8(model):
|
|||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
compile_kwargs = {"dynamic": False}
|
||||
if args.compile_mode:
|
||||
compile_kwargs["mode"] = args.compile_mode
|
||||
model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
|
||||
|
|
@ -336,8 +317,7 @@ optimizer = model.setup_optimizer(
|
|||
# AdamW hyperparameters
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
bigram_embedding_lr_mult=args.bigram_embedding_lr_mult,
|
||||
bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale,
|
||||
bigram_lambda_lr=0.004 * batch_lr_scale,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
|
|
@ -442,11 +422,6 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
|||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
train_log_every = args.train_log_every
|
||||
batched_train_timing = train_log_every > 1
|
||||
train_timing_interval_start = None
|
||||
train_timing_interval_first_step = step
|
||||
train_log_count = 0
|
||||
|
||||
# Go!
|
||||
while True:
|
||||
|
|
@ -454,7 +429,7 @@ while True:
|
|||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))):
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
|
|
@ -541,14 +516,8 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
if batched_train_timing:
|
||||
if train_timing_interval_start is None:
|
||||
synchronize()
|
||||
train_timing_interval_start = time.time()
|
||||
train_timing_interval_first_step = step
|
||||
else:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
|
|
@ -580,66 +549,46 @@ while True:
|
|||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations
|
||||
if batched_train_timing:
|
||||
if should_log_train:
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
interval_steps = step - train_timing_interval_first_step + 1
|
||||
interval_dt = t1 - train_timing_interval_start
|
||||
dt = interval_dt / interval_steps
|
||||
counted_start = max(train_timing_interval_first_step, 11)
|
||||
counted_steps = max(0, step - counted_start + 1)
|
||||
if counted_steps > 0:
|
||||
total_training_time += interval_dt * counted_steps / interval_steps
|
||||
train_loss_f = train_loss.item()
|
||||
train_timing_interval_start = None
|
||||
else:
|
||||
dt = None
|
||||
train_loss_f = None
|
||||
else:
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging (CPU action only)
|
||||
if should_log_train:
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
train_log_count += 1
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0 or (step + 1) % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user