mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-14 17:03:13 +00:00
Update warmdown and rename quickrun
This commit is contained in:
parent
e28d4ead22
commit
77de3297ea
|
|
@ -30,8 +30,9 @@ fi
|
|||
# Optimizer
|
||||
MATRIX_OPTIMIZER="${MATRIX_OPTIMIZER:-hyperball}"
|
||||
SCALAR_LR="${SCALAR_LR:-0.5}"
|
||||
MATRIX_LR="$SCALAR_LR" # share with scalar LR
|
||||
WARMDOWN_RATIO="${WARMDOWN_RATIO:-0.3}"
|
||||
MATRIX_LR="${MATRIX_LR:-0.02}"
|
||||
WARMDOWN_RATIO="${WARMDOWN_RATIO:-1.0}"
|
||||
MATRIX_WARMDOWN_RATIO="${MATRIX_WARMDOWN_RATIO:-1.0}"
|
||||
|
||||
# AdamW
|
||||
EMBEDDING_LR="${EMBEDDING_LR:-0.3}"
|
||||
|
|
@ -82,9 +83,9 @@ echo "Num GPUs: $NPROC_PER_NODE"
|
|||
echo "Device batch size: $DEVICE_BATCH_SIZE"
|
||||
echo "Total batch size: $TOTAL_BATCH_SIZE"
|
||||
echo "Matrix optimizer: $MATRIX_OPTIMIZER"
|
||||
echo "Matrix LR: $MATRIX_LR (shared with scalar)"
|
||||
echo "Matrix LR: $MATRIX_LR"
|
||||
echo "Adam LRs: embedding=$EMBEDDING_LR, unembedding=$UNEMBEDDING_LR, scalar=$SCALAR_LR"
|
||||
echo "Warmdown ratio: $WARMDOWN_RATIO"
|
||||
echo "Warmdown ratio: adam=$WARMDOWN_RATIO, matrix=$MATRIX_WARMDOWN_RATIO"
|
||||
echo "Wandb run: $WANDB_RUN"
|
||||
echo "Model tag: $MODEL_TAG"
|
||||
if [ "${FP8:-0}" -eq 1 ]; then
|
||||
|
|
@ -133,6 +134,7 @@ TRAIN_ARGS=(
|
|||
--matrix-optimizer=$MATRIX_OPTIMIZER
|
||||
--matrix-lr=$MATRIX_LR
|
||||
--warmdown-ratio=$WARMDOWN_RATIO
|
||||
--matrix-warmdown-ratio=$MATRIX_WARMDOWN_RATIO
|
||||
--embedding-lr=$EMBEDDING_LR
|
||||
--unembedding-lr=$UNEMBEDDING_LR
|
||||
--scalar-lr=$SCALAR_LR
|
||||
|
|
@ -64,7 +64,8 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate
|
|||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for AdamW LR warmdown")
|
||||
parser.add_argument("--matrix-warmdown-ratio", type=float, default=1.0, help="ratio of iterations for Muon/Hyperball LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
|
|
@ -80,6 +81,8 @@ args = parser.parse_args()
|
|||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
|
|
@ -321,17 +324,18 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
|
|||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
# Learning rate scheduler (warmup + warmdown)
|
||||
def get_lr_multiplier(it, warmup_ratio, warmdown_ratio, final_lr_frac):
|
||||
warmup_iters = round(warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(warmdown_ratio * num_iterations)
|
||||
if warmup_iters > 0 and it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
elif it <= num_iterations - warmdown_iters:
|
||||
if warmdown_iters <= 0:
|
||||
return 1.0
|
||||
else:
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
if it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
|
||||
# Momentum scheduler for matrix optimizer (Muon/Hyperball)
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -463,11 +467,15 @@ while True:
|
|||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
lrm_adam = get_lr_multiplier(step, args.warmup_ratio, args.warmdown_ratio, args.final_lr_frac)
|
||||
lrm_matrix = get_lr_multiplier(step, 0.0, args.matrix_warmdown_ratio, args.final_lr_frac)
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] in {'muon', 'hyperball'}:
|
||||
group["lr"] = group["initial_lr"] * lrm_matrix
|
||||
else:
|
||||
group["lr"] = group["initial_lr"] * lrm_adam
|
||||
if group['kind'] in {'muon', 'hyperball'}:
|
||||
group["momentum"] = muon_momentum
|
||||
if group['kind'] == 'muon':
|
||||
|
|
@ -500,14 +508,15 @@ while True:
|
|||
else:
|
||||
eta_str = ""
|
||||
epoch = dataloader_state_dict["epoch"]
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm(adam)={lrm_adam:.2f}, lrm(matrix)={lrm_matrix:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/lrm_adam": lrm_adam,
|
||||
"train/lrm_matrix": lrm_matrix,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
|
|
@ -548,6 +557,7 @@ get_report().log(section="Base model training", data=[
|
|||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"matrix_warmdown_ratio": args.matrix_warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user