Merge branch 'master' into muonh-submit

Resolved conflicts in scripts/base_train.py by keeping muonh-submit features
(hyperball optimizer support, norm_lr parameter, matrix warmup ratio) while
incorporating latest master improvements.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kaiyue Wen 2026-02-12 20:14:24 -08:00
commit 25ec1e6c43

View File

@ -268,7 +268,7 @@ target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal
# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically) Shall we use a constant ratio for computing D_REF?
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
# 2) Now that we have the token horizon, we can calculate the optimal batch size
@ -308,18 +308,20 @@ matrix_lr_scaled = args.matrix_lr * batch_lr_scale
# LR data scaling for Hyperball
# We keep the same D_REF here
if args.matrix_optimizer == "hyperball":
D_REF_LR = 10.5 * get_scaling_params(d12_ref)
matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves
D_REF_LR = 10.5 * get_scaling_params(d12_ref)
matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves
print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for token ratio {target_tokens / D_REF:.2f} (T_train = {target_tokens:,} tokens)")
optimizer = model.setup_optimizer(
# AdamW hyperparameters
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
scalar_lr=args.scalar_lr * batch_lr_scale,
adam_betas=(args.adam_beta1, args.adam_beta2),
norm_lr=args.norm_lr * batch_lr_scale,
# Muon/Hyperball hyperparameters
matrix_lr=matrix_lr_scaled,
weight_decay=weight_decay_scaled,
adam_betas=(args.adam_beta1, args.adam_beta2),
scalar_lr=args.scalar_lr * batch_lr_scale,
norm_lr=args.norm_lr * batch_lr_scale,
matrix_optimizer=args.matrix_optimizer,
)