diff --git a/scripts/base_train.py b/scripts/base_train.py index 1a472f6..343d2d9 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -349,8 +349,8 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # Learning rate schedule (linear warmup, constant, linear warmdown) def get_lr_multiplier(it): # Note: optimizer steps run for it in [0, num_iterations-1] - warmup_iters = int(round(args.warmup_ratio * num_iterations)) - warmdown_iters = int(round(args.warmdown_ratio * num_iterations)) + warmup_iters = round(args.warmup_ratio * num_iterations) + warmdown_iters = round(args.warmdown_ratio * num_iterations) # Warmup (avoid division by zero when warmup_iters == 0) if warmup_iters > 0 and it < warmup_iters: