diff --git a/scripts/base_train.py b/scripts/base_train.py
index 7ed6330..cc858d4 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -238,15 +238,28 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
 
 # Learning rate scheduler
 def get_lr_multiplier(it):
-    warmup_iters = round(args.warmup_ratio * num_iterations)
-    warmdown_iters = round(args.warmdown_ratio * num_iterations)
-    if it < warmup_iters:
+    # Note: optimizer steps run for it in [0, num_iterations-1]
+    warmup_iters = int(round(args.warmup_ratio * num_iterations))
+    warmdown_iters = int(round(args.warmdown_ratio * num_iterations))
+
+    # Warmup (avoid division by zero when warmup_iters == 0)
+    if warmup_iters > 0 and it < warmup_iters:
         return (it + 1) / warmup_iters
-    elif it <= num_iterations - warmdown_iters:
-        return 1.0
-    else:
-        progress = (num_iterations - it) / warmdown_iters
-        return progress * 1.0 + (1 - progress) * args.final_lr_frac
+
+    # Warmdown should cover the last `warmdown_iters` optimizer steps:
+    # it in [num_iterations - warmdown_iters, num_iterations - 1]
+    if warmdown_iters > 0:
+        warmdown_start = num_iterations - warmdown_iters
+        # Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness)
+        warmdown_start = max(warmdown_start, warmup_iters)
+
+        if it >= warmdown_start:
+            # progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1)
+            span = max(1, (num_iterations - 1) - warmdown_start)  # denom >= 1
+            progress = (num_iterations - 1 - it) / span
+            return progress * 1.0 + (1.0 - progress) * args.final_lr_frac
+
+    return 1.0
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
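
As a quick sanity check of the new schedule, here is a minimal standalone sketch of the logic from the diff with the `args.*` lookups replaced by plain module-level variables. The names `warmup_ratio`, `warmdown_ratio`, and `final_lr_frac` mirror the arguments used above, but the concrete values (0.0, 0.2, 0.1) and the step count of 100 are assumptions chosen only to exercise the zero-warmup guard and the warmdown endpoints.

```python
# Standalone sketch of the new LR multiplier schedule (assumed hyperparameter values).
num_iterations = 100
warmup_ratio = 0.0      # assumed: no warmup, to exercise the division-by-zero guard
warmdown_ratio = 0.2    # assumed: decay over the last 20% of optimizer steps
final_lr_frac = 0.1     # assumed: decay down to 10% of the base LR

def get_lr_multiplier(it):
    warmup_iters = int(round(warmup_ratio * num_iterations))
    warmdown_iters = int(round(warmdown_ratio * num_iterations))
    # Warmup (skipped entirely when warmup_iters == 0)
    if warmup_iters > 0 and it < warmup_iters:
        return (it + 1) / warmup_iters
    # Warmdown over the last warmdown_iters steps, ending at num_iterations - 1
    if warmdown_iters > 0:
        warmdown_start = max(num_iterations - warmdown_iters, warmup_iters)
        if it >= warmdown_start:
            span = max(1, (num_iterations - 1) - warmdown_start)
            progress = (num_iterations - 1 - it) / span
            return progress * 1.0 + (1.0 - progress) * final_lr_frac
    return 1.0

# Expect 1.0 through step 80 (warmdown start), then a linear ramp down to 0.1 at step 99.
for it in (0, 1, 79, 80, 90, 99):
    print(f"step {it:3d}: lr multiplier = {get_lr_multiplier(it):.3f}")
```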