This commit is contained in:
Dipesh Babu 2026-02-20 02:43:16 -05:00 committed by GitHub
commit 50e9c3a78f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -348,15 +348,28 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
# Note: optimizer steps run for it in [0, num_iterations-1]
warmup_iters = int(round(args.warmup_ratio * num_iterations))
warmdown_iters = int(round(args.warmdown_ratio * num_iterations))
# Warmup (avoid division by zero when warmup_iters == 0)
if warmup_iters > 0 and it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Warmdown should cover the last `warmdown_iters` optimizer steps:
# it in [num_iterations - warmdown_iters, num_iterations - 1]
if warmdown_iters > 0:
warmdown_start = num_iterations - warmdown_iters
# Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness)
warmdown_start = max(warmdown_start, warmup_iters)
if it >= warmdown_start:
# progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1)
span = max(1, (num_iterations - 1) - warmdown_start) # denom >= 1
progress = (num_iterations - 1 - it) / span
return progress * 1.0 + (1.0 - progress) * args.final_lr_frac
return 1.0
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):