Update warmdown and rename quickrun

This commit is contained in:
dangxingyu 2026-02-03 20:25:16 -05:00
parent e28d4ead22
commit 77de3297ea
2 changed files with 30 additions and 18 deletions

View File

@ -30,8 +30,9 @@ fi
# Optimizer
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
SCALAR_LR="${SCALAR_LR:-0.5}"
MATRIX_LR="$SCALAR_LR" # share with scalar LR
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"
MATRIX_LR="${MATRIX_LR:-0.02}"
WARMDOWN_RATIO="${WARMDOWN_RATIO:-1.0}"
MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}"
# AdamW
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
@ -82,9 +83,9 @@ echo "Num GPUs: $NPROC_PER_NODE"
echo "Device batch size: $DEVICE_BATCH_SIZE"
echo "Total batch size: $TOTAL_BATCH_SIZE"
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
echo "Matrix LR: $MATRIX_LR (shared with scalar)"
echo "Matrix LR: $MATRIX_LR"
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
echo "Warmdown ratio: $WARMDOWN_RATIO"
echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO"
echo "Wandb run: $WANDB_RUN"
echo "Model tag: $MODEL_TAG"
if [ "${FP8:-0}" -eq 1 ]; then
@ -133,6 +134,7 @@ TRAIN_ARGS=(
--matrix-optimizer=$MATRIX_OPTIMIZER
--matrix-lr=$MATRIX_LR
--warmdown-ratio=$WARMDOWN_RATIO
--matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO
--embedding-lr=$EMBEDDING_LR
--unembedding-lr=$UNEMBEDDING_LR
--scalar-lr=$SCALAR_LR

View File

@ -64,7 +64,8 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown")
parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
@ -80,6 +81,8 @@ args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
@ -321,17 +324,18 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
# Learning rate scheduler (warmup + warmdown)
def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
if warmup_iters > 0 and it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
if warmdown_iters <= 0:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
if it <= num_iterations - warmdown_iters:
return 1.0
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for matrix optimizer (Muon/Hyperball)
def get_muon_momentum(it):
@ -463,11 +467,15 @@ while True:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer
lrm = get_lr_multiplier(step)
lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac)
lrm_matrix = get_lr_multiplier(step, 0.0, args.matrix_warmdown_ratio, args.final_lr_frac)
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] in {'muon', 'hyperball'}:
group["lr"] = group["initial_lr"] * lrm_matrix
else:
group["lr"] = group["initial_lr"] * lrm_adam
if group['kind'] in {'muon', 'hyperball'}:
group["momentum"] = muon_momentum
if group['kind'] == 'muon':
@ -500,14 +508,15 @@ while True:
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/lrm_adam": lrm_adam,
"train/lrm_matrix": lrm_matrix,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
@ -548,6 +557,7 @@ get_report().log(section="Base model training", data=[
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"matrix_warmdown_ratio": args.matrix_warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes