diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 9a4c3977..20e62488 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,9 +70,19 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d22 Muon+/row-eq + hashed bigram recipe. -# scripts/base_train defaults are the submission defaults: fixed 11,600 -# optimizer steps, eval every 250, and one in-training CORE pass halfway through. -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --run=$WANDB_RUN \ + --fp8 \ + --depth=22 \ + --num-iterations=11600 \ + --target-param-data-ratio=11 \ + --total-batch-size=524288 \ + --scalar-lr=0.3 \ + --bigram-embed-factor=5 \ + --muon-plus \ + --muon-eq=row \ + --core-metric-every=5800 \ + --sample-every=-1 # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index d0e3780c..b415005c 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -8,7 +8,7 @@ or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: -python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0 +python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ import os @@ -41,39 +41,32 @@ print_banner() parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") -parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)") -parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") -parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode") # Model architecture -parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model") +parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") -parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual") -parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor") -parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr") -parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling") +parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual") # Training horizon (only one used, in order of precedence) -parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") +parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)") -parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization") -parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization") -parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") +parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization") +parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") @@ -81,24 +74,16 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") -parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs") -parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0") -parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)") +parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") -parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)") +parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging -if args.train_log_every <= 0: - parser.error("--train-log-every must be positive") if args.bigram_embed_factor < 0: parser.error("--bigram-embed-factor must be non-negative") -if args.bigram_lambda_lr < 0: - parser.error("--bigram-lambda-lr must be non-negative") -if args.bigram_embedding_lr_mult <= 0: - parser.error("--bigram-embedding-lr-mult must be positive") # ----------------------------------------------------------------------------- # Compute init and wandb logging @@ -158,7 +143,6 @@ def build_model_meta(depth): n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, bigram_embed_factor=args.bigram_embed_factor, - bigram_lambda_init=args.bigram_lambda_init, ) with torch.device("meta"): model_meta = GPT(config) @@ -265,10 +249,7 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -compile_kwargs = {"dynamic": False} -if args.compile_mode: - compile_kwargs["mode"] = args.compile_mode -model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. @@ -336,8 +317,7 @@ optimizer = model.setup_optimizer( # AdamW hyperparameters unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, - bigram_embedding_lr_mult=args.bigram_embedding_lr_mult, - bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale, + bigram_lambda_lr=0.004 * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale, # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, @@ -442,11 +422,6 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") -train_log_every = args.train_log_every -batched_train_timing = train_log_every > 1 -train_timing_interval_start = None -train_timing_interval_first_step = step -train_log_count = 0 # Go! while True: @@ -454,7 +429,7 @@ while True: flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) - if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))): + if args.eval_every > 0 and (last_step or step % args.eval_every == 0): model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) @@ -541,14 +516,8 @@ while True: # ------------------------------------------------------------------------- # single training step # evaluate the gradient - if batched_train_timing: - if train_timing_interval_start is None: - synchronize() - train_timing_interval_start = time.time() - train_timing_interval_first_step = step - else: - synchronize() - t0 = time.time() + synchronize() + t0 = time.time() for micro_step in range(grad_accum_steps): loss = model(x, y) train_loss = loss.detach() # for logging @@ -580,66 +549,46 @@ while True: else: optimizer.step() model.zero_grad(set_to_none=True) - should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations - if batched_train_timing: - if should_log_train: - synchronize() - t1 = time.time() - interval_steps = step - train_timing_interval_first_step + 1 - interval_dt = t1 - train_timing_interval_start - dt = interval_dt / interval_steps - counted_start = max(train_timing_interval_first_step, 11) - counted_steps = max(0, step - counted_start + 1) - if counted_steps > 0: - total_training_time += interval_dt * counted_steps / interval_steps - train_loss_f = train_loss.item() - train_timing_interval_start = None - else: - dt = None - train_loss_f = None - else: - train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point - synchronize() - t1 = time.time() - dt = t1 - t0 - if step > 10: - total_training_time += dt # only count the time after the first 10 steps + train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point + synchronize() + t1 = time.time() + dt = t1 - t0 # ------------------------------------------------------------------------- # logging (CPU action only) - if should_log_train: - ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss - train_log_count += 1 - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA - pct_done = 100 * step / num_iterations - tok_per_sec = int(total_batch_size / dt) - flops_per_sec = num_flops_per_token * total_batch_size / dt - mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) - # Calculate ETA based on average time per step (excluding first 10 steps) - steps_done = step - 10 - if steps_done > 0: - avg_time_per_step = total_training_time / steps_done - remaining_steps = num_iterations - step - eta_seconds = remaining_steps * avg_time_per_step - eta_str = f" | eta: {eta_seconds/60:.1f}m" - else: - eta_str = "" - epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") - if step % 100 == 0 or (step + 1) % 100 == 0: - log_data = { - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "train/loss": debiased_smooth_loss, - "train/lrm": lrm, - "train/dt": dt, - "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, - "train/epoch": epoch, - } - wandb_run.log(log_data) + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * step / num_iterations + tok_per_sec = int(total_batch_size / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + # Calculate ETA based on average time per step (excluding first 10 steps) + steps_done = step - 10 + if steps_done > 0: + avg_time_per_step = total_training_time / steps_done + remaining_steps = num_iterations - step + eta_seconds = remaining_steps * avg_time_per_step + eta_str = f" | eta: {eta_seconds/60:.1f}m" + else: + eta_str = "" + epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + if step % 100 == 0: + log_data = { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": epoch, + } + wandb_run.log(log_data) # state update first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)