diff --git a/runs/quickrun_gamma_muonh_d24.sh b/runs/quickrun_muonh_d24.sh similarity index 94% rename from runs/quickrun_gamma_muonh_d24.sh rename to runs/quickrun_muonh_d24.sh index 63ba8cf..98b142f 100755 --- a/runs/quickrun_gamma_muonh_d24.sh +++ b/runs/quickrun_muonh_d24.sh @@ -30,8 +30,9 @@ fi # Optimizer MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}" SCALAR_LR="${SCALAR_LR:-0.5}" -MATRIX_LR="$SCALAR_LR" # share with scalar LR -WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}" +MATRIX_LR="${MATRIX_LR:-0.02}" +WARMDOWN_RATIO="${WARMDOWN_RATIO:-1.0}" +MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}" # AdamW EMBEDDING_LR="${EMBEDDING_LR:-0.3}" @@ -82,9 +83,9 @@ echo "Num GPUs: $NPROC_PER_NODE" echo "Device batch size: $DEVICE_BATCH_SIZE" echo "Total batch size: $TOTAL_BATCH_SIZE" echo "Matrix optimizer: $MATRIX_OPTIMIZER" -echo "Matrix LR: $MATRIX_LR (shared with scalar)" +echo "Matrix LR: $MATRIX_LR" echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR" -echo "Warmdown ratio: $WARMDOWN_RATIO" +echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO" echo "Wandb run: $WANDB_RUN" echo "Model tag: $MODEL_TAG" if [ "${FP8:-0}" -eq 1 ]; then @@ -133,6 +134,7 @@ TRAIN_ARGS=( --matrix-optimizer=$MATRIX_OPTIMIZER --matrix-lr=$MATRIX_LR --warmdown-ratio=$WARMDOWN_RATIO + --matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO --embedding-lr=$EMBEDDING_LR --unembedding-lr=$UNEMBEDDING_LR --scalar-lr=$SCALAR_LR diff --git a/scripts/base_train.py b/scripts/base_train.py index a4de906..3d13bf4 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -64,7 +64,8 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") -parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown") +parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation @@ -80,6 +81,8 @@ args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- + + # Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) @@ -321,17 +324,18 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers -# Learning rate scheduler -def get_lr_multiplier(it): - warmup_iters = round(args.warmup_ratio * num_iterations) - warmdown_iters = round(args.warmdown_ratio * num_iterations) - if it < warmup_iters: +# Learning rate scheduler (warmup + warmdown) +def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac): + warmup_iters = round(warmup_ratio * num_iterations) + warmdown_iters = round(warmdown_ratio * num_iterations) + if warmup_iters > 0 and it < warmup_iters: return (it + 1) / warmup_iters - elif it <= num_iterations - warmdown_iters: + if warmdown_iters <= 0: return 1.0 - else: - progress = (num_iterations - it) / warmdown_iters - return progress * 1.0 + (1 - progress) * args.final_lr_frac + if it <= num_iterations - warmdown_iters: + return 1.0 + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * final_lr_frac # Momentum scheduler for matrix optimizer (Muon/Hyperball) def get_muon_momentum(it): @@ -463,11 +467,15 @@ while True: loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer - lrm = get_lr_multiplier(step) + lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac) + lrm_matrix = get_lr_multiplier(step, 0.0, args.matrix_warmdown_ratio, args.final_lr_frac) muon_momentum = get_muon_momentum(step) muon_weight_decay = get_weight_decay(step) for group in optimizer.param_groups: - group["lr"] = group["initial_lr"] * lrm + if group['kind'] in {'muon', 'hyperball'}: + group["lr"] = group["initial_lr"] * lrm_matrix + else: + group["lr"] = group["initial_lr"] * lrm_adam if group['kind'] in {'muon', 'hyperball'}: group["momentum"] = muon_momentum if group['kind'] == 'muon': @@ -500,14 +508,15 @@ while True: else: eta_str = "" epoch = dataloader_state_dict["epoch"] - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, - "train/lrm": lrm, + "train/lrm_adam": lrm_adam, + "train/lrm_matrix": lrm_matrix, "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, @@ -548,6 +557,7 @@ get_report().log(section="Base model training", data=[ "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio, + "matrix_warmdown_ratio": args.matrix_warmdown_ratio, "final_lr_frac": args.final_lr_frac, }, { # stats about training outcomes