Keep speedrun changes functional

Codex 2026-05-07 12:24:51 +00:00
parent d6a169b329
commit 4680e799fb
2 changed files with 67 additions and 108 deletions

speedrun.sh

@@ -70,9 +70,19 @@ echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID
 # d22 Muon+/row-eq + hashed bigram recipe.
-# scripts/base_train defaults are the submission defaults: fixed 11,600
-# optimizer steps, eval every 250, and one in-training CORE pass halfway through.
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
+    --run=$WANDB_RUN \
+    --fp8 \
+    --depth=22 \
+    --num-iterations=11600 \
+    --target-param-data-ratio=11 \
+    --total-batch-size=524288 \
+    --scalar-lr=0.3 \
+    --bigram-embed-factor=5 \
+    --muon-plus \
+    --muon-eq=row \
+    --core-metric-every=5800 \
+    --sample-every=-1
 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
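Note: with the recipe spelled out as flags, the training horizon is easy to audit. A quick back-of-envelope check of what the invocation above commits to, using nothing but the flag values (the tokens:params reading of the ratio follows the help string's "Chinchilla=20" framing):

    num_iterations = 11_600
    total_batch_size = 524_288                        # tokens per optimizer step
    total_tokens = num_iterations * total_batch_size  # 6,081,740,800 (~6.08B tokens)
    # --target-param-data-ratio=11 read as tokens:params puts this horizon at
    # roughly total_tokens / 11 ~= 553M parameters
    print(f"{total_tokens:,}")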

scripts/base_train.py

@@ -8,7 +8,7 @@ or distributed as:
 torchrun --nproc_per_node=8 -m scripts.base_train
 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
 """
 import os
@@ -41,39 +41,32 @@ print_banner()
 parser = argparse.ArgumentParser(description="Pretrain base model")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
-parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # FP8 training
-parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)")
-parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training")
+parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
 parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
-parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
 # Model architecture
-parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model")
+parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
 parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
 parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
-parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual")
-parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor")
-parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr")
-parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling")
+parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual")
 # Training horizon (only one used, in order of precedence)
-parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
 # Optimization
 parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
-parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
+parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
 parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
 parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization")
-parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization")
-parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
+parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization")
+parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
 parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
@@ -81,24 +74,16 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
 parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs")
-parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0")
-parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
 parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
 parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
 args = parser.parse_args()
 user_config = vars(args).copy() # for logging
-if args.train_log_every <= 0:
-    parser.error("--train-log-every must be positive")
 if args.bigram_embed_factor < 0:
     parser.error("--bigram-embed-factor must be non-negative")
-if args.bigram_lambda_lr < 0:
-    parser.error("--bigram-lambda-lr must be non-negative")
-if args.bigram_embedding_lr_mult <= 0:
-    parser.error("--bigram-embedding-lr-mult must be positive")
 # -----------------------------------------------------------------------------
 # Compute init and wandb logging
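Note: the surviving guard uses argparse's parser.error, which prints the usage string to stderr and exits with status 2. The same mechanism in isolation:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--bigram-embed-factor", type=int, default=0)
    args = parser.parse_args([])  # empty argv for a deterministic demo
    if args.bigram_embed_factor < 0:
        parser.error("--bigram-embed-factor must be non-negative")  # usage + exit(2)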
@@ -158,7 +143,6 @@ def build_model_meta(depth):
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
         window_pattern=args.window_pattern,
         bigram_embed_factor=args.bigram_embed_factor,
-        bigram_lambda_init=args.bigram_lambda_init,
     )
     with torch.device("meta"):
         model_meta = GPT(config)
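Note: constructing GPT(config) under torch.device("meta") allocates no real storage; parameters exist only as shapes and dtypes, which is enough for parameter counting and FLOPs accounting before any memory is committed. The same pattern with a generic module:

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        m = nn.Linear(4096, 4096)                  # shapes/dtypes only, no weight memory
    print(m.weight.device)                         # meta
    print(sum(p.numel() for p in m.parameters()))  # counting still works
    m = m.to_empty(device="cpu")                   # materialize uninitialized storage later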
@@ -265,10 +249,7 @@ def disable_fp8(model):
 # Compile the model
 orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
-compile_kwargs = {"dynamic": False}
-if args.compile_mode:
-    compile_kwargs["mode"] = args.compile_mode
-model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe
+model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 # -----------------------------------------------------------------------------
 # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
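Note: dynamic=False lets the compiler specialize on the first-seen input shapes instead of inserting dynamic-shape guards, which is safe here because training batches never change shape. Illustratively, with a toy module:

    import torch

    net = torch.nn.Linear(512, 512)
    net_c = torch.compile(net, dynamic=False)  # specialize kernels on exact shapes
    x = torch.randn(32, 512)
    y = net_c(x)  # later calls should reuse (32, 512), or a recompile is triggered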
@@ -336,8 +317,7 @@ optimizer = model.setup_optimizer(
     # AdamW hyperparameters
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
-    bigram_embedding_lr_mult=args.bigram_embedding_lr_mult,
-    bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale,
+    bigram_lambda_lr=0.004 * batch_lr_scale,
     scalar_lr=args.scalar_lr * batch_lr_scale,
     # Muon hyperparameters
     matrix_lr=args.matrix_lr * batch_lr_scale,
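Note: every learning rate above is multiplied by a shared batch_lr_scale computed in the scaling-laws section, which this diff does not show. A common convention, assumed here purely for illustration, is square-root scaling against a reference batch size:

    # assumption: sqrt LR scaling vs. a reference batch; the script's actual rule
    # lives in the scaling-laws section, not in this hunk
    reference_batch_size = 524_288
    total_batch_size = 262_144
    batch_lr_scale = (total_batch_size / reference_batch_size) ** 0.5  # ~0.707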
@@ -442,11 +422,6 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
-train_log_every = args.train_log_every
-batched_train_timing = train_log_every > 1
-train_timing_interval_start = None
-train_timing_interval_first_step = step
-train_log_count = 0
 # Go!
 while True:
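Note: the batch bookkeeping printed above is plain integer arithmetic. With the speedrun values (32 sequences/device, 2048-token context, 8 ranks), one forward/backward already covers the full 524,288-token batch, so no accumulation is needed; halving the device batch doubles grad_accum_steps instead:

    device_batch_size, max_seq_len, world_size = 32, 2048, 8
    tokens_per_fwdbwd = device_batch_size * max_seq_len        # 65,536 per rank
    world_tokens_per_fwdbwd = tokens_per_fwdbwd * world_size   # 524,288
    grad_accum_steps = 524_288 // world_tokens_per_fwdbwd      # 1
    # at --device-batch-size=16: 16 * 2048 * 8 = 262,144 -> grad_accum_steps = 2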
@@ -454,7 +429,7 @@
     flops_so_far = num_flops_per_token * total_batch_size * step
     # once in a while: evaluate the val bpb (all ranks participate)
-    if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))):
+    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
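Note: with the defaults shown earlier, the eval_steps division works out to exactly 80 validation batches:

    eval_tokens = 80 * 524_288            # 41,943,040 (the --eval-tokens default)
    per_step = 32 * 2048 * 8              # device batch * seq len * ranks
    eval_steps = eval_tokens // per_step  # 80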
@@ -541,14 +516,8 @@
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    if batched_train_timing:
-        if train_timing_interval_start is None:
-            synchronize()
-            train_timing_interval_start = time.time()
-            train_timing_interval_first_step = step
-    else:
-        synchronize()
-        t0 = time.time()
+    synchronize()
+    t0 = time.time()
     for micro_step in range(grad_accum_steps):
         loss = model(x, y)
         train_loss = loss.detach() # for logging
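Note: the synchronize()/time.time() pair only measures real work because CUDA launches are asynchronous; without the flush, the timer would clock kernel launches rather than execution. The pattern in isolation (assumes a CUDA device is present):

    import time
    import torch

    def timed(fn):
        torch.cuda.synchronize()  # flush queued kernels before starting the clock
        t0 = time.time()
        fn()
        torch.cuda.synchronize()  # wait for the work itself, not just the launches
        return time.time() - t0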
@@ -580,66 +549,46 @@
     else:
         optimizer.step()
     model.zero_grad(set_to_none=True)
-    should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations
-    if batched_train_timing:
-        if should_log_train:
-            synchronize()
-            t1 = time.time()
-            interval_steps = step - train_timing_interval_first_step + 1
-            interval_dt = t1 - train_timing_interval_start
-            dt = interval_dt / interval_steps
-            counted_start = max(train_timing_interval_first_step, 11)
-            counted_steps = max(0, step - counted_start + 1)
-            if counted_steps > 0:
-                total_training_time += interval_dt * counted_steps / interval_steps
-            train_loss_f = train_loss.item()
-            train_timing_interval_start = None
-        else:
-            dt = None
-            train_loss_f = None
-    else:
-        train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
-        synchronize()
-        t1 = time.time()
-        dt = t1 - t0
-        if step > 10:
-            total_training_time += dt # only count the time after the first 10 steps
+    train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
+    synchronize()
+    t1 = time.time()
+    dt = t1 - t0
     # -------------------------------------------------------------------------
     # logging (CPU action only)
-    if should_log_train:
-        ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
-        smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
-        train_log_count += 1
-        debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA
-        pct_done = 100 * step / num_iterations
-        tok_per_sec = int(total_batch_size / dt)
-        flops_per_sec = num_flops_per_token * total_batch_size / dt
-        mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
-        # Calculate ETA based on average time per step (excluding first 10 steps)
-        steps_done = step - 10
-        if steps_done > 0:
-            avg_time_per_step = total_training_time / steps_done
-            remaining_steps = num_iterations - step
-            eta_seconds = remaining_steps * avg_time_per_step
-            eta_str = f" | eta: {eta_seconds/60:.1f}m"
-        else:
-            eta_str = ""
-        epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
-        print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
-        if step % 100 == 0 or (step + 1) % 100 == 0:
-            log_data = {
-                "step": step,
-                "total_training_flops": flops_so_far,
-                "total_training_time": total_training_time,
-                "train/loss": debiased_smooth_loss,
-                "train/lrm": lrm,
-                "train/dt": dt,
-                "train/tok_per_sec": tok_per_sec,
-                "train/mfu": mfu,
-                "train/epoch": epoch,
-            }
-            wandb_run.log(log_data)
+    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
+    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
+    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
+    pct_done = 100 * step / num_iterations
+    tok_per_sec = int(total_batch_size / dt)
+    flops_per_sec = num_flops_per_token * total_batch_size / dt
+    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
+    if step > 10:
+        total_training_time += dt # only count the time after the first 10 steps
+    # Calculate ETA based on average time per step (excluding first 10 steps)
+    steps_done = step - 10
+    if steps_done > 0:
+        avg_time_per_step = total_training_time / steps_done
+        remaining_steps = num_iterations - step
+        eta_seconds = remaining_steps * avg_time_per_step
+        eta_str = f" | eta: {eta_seconds/60:.1f}m"
+    else:
+        eta_str = ""
+    epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
+    if step % 100 == 0:
+        log_data = {
+            "step": step,
+            "total_training_flops": flops_so_far,
+            "total_training_time": total_training_time,
+            "train/loss": debiased_smooth_loss,
+            "train/lrm": lrm,
+            "train/dt": dt,
+            "train/tok_per_sec": tok_per_sec,
+            "train/mfu": mfu,
+            "train/epoch": epoch,
+        }
+        wandb_run.log(log_data)
     # state update
     first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
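Note: the (1 - ema_beta**(step + 1)) denominator is the same bias correction Adam applies to its moment estimates; a zero-initialized EMA underestimates the running level by exactly that factor. A standalone check of the arithmetic:

    beta = 0.9
    ema = 0.0
    for t, x in enumerate([2.0, 2.0, 2.0], start=1):
        ema = beta * ema + (1 - beta) * x
        print(ema / (1 - beta ** t))  # 2.0 every time: debiasing recovers the level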