Scale hyperball lr by depth

This commit is contained in:
dangxingyu 2026-02-03 21:29:51 -05:00
parent 924489f582
commit 595a0f460a

View File

@ -300,10 +300,19 @@ model = torch.compile(model, dynamic=False) # the inputs to model will never cha
# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
adam_betas = (args.adam_beta1, args.adam_beta2)
matrix_lr_scaled = args.matrix_lr * batch_lr_scale
# LR depth scaling for Hyperball
if args.matrix_optimizer == "hyperball":
hyperball_depth_scale = 12 / args.depth
matrix_lr_scaled = matrix_lr_scaled * hyperball_depth_scale
if args.depth != 12:
print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for depth {args.depth}")
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale,
matrix_lr=matrix_lr_scaled,
weight_decay=weight_decay_scaled,
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,