From 23acb17f173fa22ae42435a90f4195b1220852f0 Mon Sep 17 00:00:00 2001
From: spjosyula
Date: Sun, 8 Feb 2026 12:21:18 +0530
Subject: [PATCH] use 1-sqrt warmdown shape for LR schedule

---
 scripts/base_train.py | 8 +++++---
 scripts/chat_sft.py   | 9 ++++++---
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index ccf35e6..bb58757 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -342,7 +342,8 @@ print0(f"Total number of training tokens: {total_tokens:,}")
 print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 
-# Learning rate schedule (linear warmup, constant, linear warmdown)
+# Learning rate schedule (linear warmup, constant, 1-sqrt warmdown)
+# 1-sqrt cooldown shape from https://arxiv.org/abs/2405.18392
 def get_lr_multiplier(it):
     warmup_iters = round(args.warmup_ratio * num_iterations)
     warmdown_iters = round(args.warmdown_ratio * num_iterations)
@@ -351,8 +352,9 @@ def get_lr_multiplier(it):
     elif it <= num_iterations - warmdown_iters:
         return 1.0
     else:
-        progress = (num_iterations - it) / warmdown_iters
-        return progress * 1.0 + (1 - progress) * args.final_lr_frac
+        decay_frac = 1 - (num_iterations - it) / warmdown_iters
+        lr_mult = 1 - decay_frac ** 0.5
+        return lr_mult * (1 - args.final_lr_frac) + args.final_lr_frac
 
 # Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
 def get_muon_momentum(it):
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 4c81f06..404a6b7 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -236,10 +236,13 @@ train_loader = sft_data_generator_bos_bestfit("train")
 build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
 progress = 0 # will go from 0 to 1 over the course of the epoch
 
-# Learning rate scheduler
+# Learning rate scheduler (1-sqrt warmdown, https://arxiv.org/abs/2405.18392)
 def get_lr_multiplier(progress):
-    # first 80% of training: no decay, then linearly ramp down to 0.
-    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
+    # first 80% of training: no decay, then 1-sqrt ramp down to 0.
+    if progress < 0.8:
+        return 1.0
+    decay_frac = (progress - 0.8) / 0.2
+    return 1 - decay_frac ** 0.5
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):