diff --git a/scripts/base_train.py b/scripts/base_train.py
index 25426f5..76e9f88 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -268,7 +268,7 @@ target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal
 
 # Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style)
 d12_ref = build_model_meta(12) # creates the model on meta device
-D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically) Shall we use a constant ratio for computing D_REF?
+D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
 B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
 
 # 2) Now that we have the token horizon, we can calculate the optimal batch size
@@ -308,18 +308,20 @@ matrix_lr_scaled = args.matrix_lr * batch_lr_scale
 
 # LR data scaling for Hyperball
 # We keep the same D_REF here
 if args.matrix_optimizer == "hyperball":
-    D_REF_LR = 10.5 * get_scaling_params(d12_ref)
-    matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves
+    D_REF_LR = 10.5 * get_scaling_params(d12_ref)
+    matrix_lr_scaled = matrix_lr_scaled * (D_REF_LR / target_tokens) ** 0.35 # 0.35 is the exponent for the power law fit by ourselves
     print0(f"Scaling hyperball LR from {args.matrix_lr * batch_lr_scale:.6f} to {matrix_lr_scaled:.6f} for token ratio {target_tokens / D_REF:.2f} (T_train = {target_tokens:,} tokens)")
 
 optimizer = model.setup_optimizer(
+    # AdamW hyperparameters
     unembedding_lr=args.unembedding_lr * batch_lr_scale,
     embedding_lr=args.embedding_lr * batch_lr_scale,
+    scalar_lr=args.scalar_lr * batch_lr_scale,
+    adam_betas=(args.adam_beta1, args.adam_beta2),
+    norm_lr=args.norm_lr * batch_lr_scale,
+    # Muon/Hyperball hyperparameters
     matrix_lr=matrix_lr_scaled,
     weight_decay=weight_decay_scaled,
-    adam_betas=(args.adam_beta1, args.adam_beta2),
-    scalar_lr=args.scalar_lr * batch_lr_scale,
-    norm_lr=args.norm_lr * batch_lr_scale,
     matrix_optimizer=args.matrix_optimizer,
 )
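
For context on the LR data scaling the second hunk applies when `matrix_optimizer == "hyperball"`, here is a minimal standalone sketch. Only the formula mirrors the diff (`lr *= (D_REF_LR / target_tokens) ** 0.35` with `D_REF_LR = 10.5 * get_scaling_params(d12_ref)`); the helper name `hyperball_matrix_lr` and the example numbers (base matrix LR 0.02, 124M scaling params) are hypothetical and not taken from `base_train.py`.

```python
# Minimal sketch (not part of the PR) of the hyperball matrix-LR data scaling.
# Assumptions: the function name and example numbers below are made up for
# illustration; only the scaling rule itself follows the diff.

def hyperball_matrix_lr(base_matrix_lr: float, batch_lr_scale: float,
                        target_tokens: int, d12_scaling_params: int) -> float:
    """Apply the batch-size LR scaling, then shrink the LR as the training
    horizon grows past the reference horizon D_REF_LR (power law, exponent 0.35)."""
    d_ref_lr = 10.5 * d12_scaling_params        # reference training horizon in tokens
    lr = base_matrix_lr * batch_lr_scale        # batch-size LR scaling (computed earlier in the script)
    lr *= (d_ref_lr / target_tokens) ** 0.35    # data-horizon power-law scaling
    return lr

# Hypothetical example: training 4x past the reference horizon multiplies the
# matrix LR by (1/4) ** 0.35 ≈ 0.62, i.e. 0.02 -> ~0.0123.
params = 124_000_000
print(hyperball_matrix_lr(0.02, 1.0, int(4 * 10.5 * params), params))
```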