From 9a2e40eff0d47a78b3cd803d0e08bfc9f7e73d11 Mon Sep 17 00:00:00 2001 From: Dipesh Babu Date: Fri, 20 Feb 2026 09:52:10 -0500 Subject: [PATCH] chore: clarify LR warmup/warmdown schedule in base_train --- scripts/base_train.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 343d2d9..8ef7aae 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -347,27 +347,33 @@ print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # Learning rate schedule (linear warmup, constant, linear warmdown) -def get_lr_multiplier(it): - # Note: optimizer steps run for it in [0, num_iterations-1] +def get_lr_multiplier(it: int) -> float: + # Note: optimizer steps run for it in [0, num_iterations - 1] warmup_iters = round(args.warmup_ratio * num_iterations) warmdown_iters = round(args.warmdown_ratio * num_iterations) - # Warmup (avoid division by zero when warmup_iters == 0) - if warmup_iters > 0 and it < warmup_iters: + # Clamp to sane ranges + warmup_iters = max(0, min(warmup_iters, num_iterations)) + warmdown_iters = max(0, min(warmdown_iters, num_iterations)) + + # Warmup: linear ramp from (1/warmup_iters) .. 1.0 over warmup_iters steps + if it < warmup_iters: + # safe: if warmup_iters == 0 this branch is unreachable return (it + 1) / warmup_iters - # Warmdown should cover the last `warmdown_iters` optimizer steps: + # Warmdown: apply over the last warmdown_iters optimizer steps: # it in [num_iterations - warmdown_iters, num_iterations - 1] if warmdown_iters > 0: warmdown_start = num_iterations - warmdown_iters - # Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness) + + # If warmup overlaps warmdown, start warmdown only after warmup is done warmdown_start = max(warmdown_start, warmup_iters) if it >= warmdown_start: - # progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1) - span = max(1, (num_iterations - 1) - warmdown_start) # denom >= 1 - progress = (num_iterations - 1 - it) / span - return progress * 1.0 + (1.0 - progress) * args.final_lr_frac + # progress goes 1.0 -> 1.0/warmdown_iters (NOT to 0 by default) + # This matches the original behavior where final step doesn't hit 0 + progress = (num_iterations - it) / warmdown_iters + return args.final_lr_frac + (1.0 - args.final_lr_frac) * progress return 1.0