diff --git a/scripts/base_train.py b/scripts/base_train.py
index 3d13bf4..f8920ea 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -300,10 +300,19 @@ model = torch.compile(model, dynamic=False) # the inputs to model will never cha
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
 adam_betas = (args.adam_beta1, args.adam_beta2)
+matrix_lr_scaled = args.matrix_lr * batch_lr_scale
+
+# LR depth scaling for Hyperball
+if args.matrix_optimizer == "hyperball":
+    hyperball_depth_scale = 12 / args.depth
+    matrix_lr_scaled = matrix_lr_scaled * hyperball_depth_scale
+    if args.depth != 12:
+        print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for depth {args.depth}")
+
 optimizer = model.setup_optimizer(
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
-    matrix_lr=args.matrix_lr * batch_lr_scale,
+    matrix_lr=matrix_lr_scaled,
     weight_decay=weight_decay_scaled,
     adam_betas=adam_betas,
     scalar_lr=args.scalar_lr * batch_lr_scale,