Mirror of https://github.com/karpathy/nanochat.git, synced 2026-03-19 19:33:15 +00:00
Scale hyperball lr by depth
commit 595a0f460a
parent 924489f582
@@ -300,10 +300,19 @@ model = torch.compile(model, dynamic=False) # the inputs to model will never change
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
 adam_betas = (args.adam_beta1, args.adam_beta2)
+matrix_lr_scaled = args.matrix_lr * batch_lr_scale
+
+# LR depth scaling for Hyperball
+if args.matrix_optimizer == "hyperball":
+    hyperball_depth_scale = 12 / args.depth
+    matrix_lr_scaled = matrix_lr_scaled * hyperball_depth_scale
+    if args.depth != 12:
+        print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for depth {args.depth}")
+
 optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
-    matrix_lr=args.matrix_lr * batch_lr_scale,
+    matrix_lr=matrix_lr_scaled,
     weight_decay=weight_decay_scaled,
     adam_betas=adam_betas,
     scalar_lr=args.scalar_lr * batch_lr_scale,
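The commit pins the tuned matrix learning rate at the reference depth of 12 and shrinks it in proportion to 1/depth for deeper models. A minimal standalone sketch of that rule, outside the training script; the helper name hyperball_matrix_lr and the 0.02 base LR are illustrative assumptions, not values taken from the repo:

# Hypothetical helper mirroring the commit's 12/depth rule.
# base_lr=0.02 below is an assumed example value, not nanochat's default.
def hyperball_matrix_lr(base_lr: float, depth: int, reference_depth: int = 12) -> float:
    # Identity at the reference depth; deeper models get proportionally smaller steps.
    return base_lr * (reference_depth / depth)

if __name__ == "__main__":
    for depth in (12, 20, 26, 32):
        print(f"depth={depth:2d} -> matrix_lr={hyperball_matrix_lr(0.02, depth):.6f}")

The 1/depth shape presumably keeps the effective per-layer update magnitude roughly constant as the network deepens, so a learning rate tuned once at depth 12 can transfer to larger models without retuning.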