This commit is contained in:
Sitananda Prasad 2026-02-08 20:43:17 +02:00 committed by GitHub
commit 79aa242302
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View File

@ -342,7 +342,8 @@ print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Learning rate schedule (linear warmup, constant, linear warmdown)
# Learning rate schedule (linear warmup, constant, 1-sqrt warmdown)
# 1-sqrt cooldown shape from https://arxiv.org/abs/2405.18392
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
@ -351,8 +352,9 @@ def get_lr_multiplier(it):
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
decay_frac = 1 - (num_iterations - it) / warmdown_iters
lr_mult = 1 - decay_frac ** 0.5
return lr_mult * (1 - args.final_lr_frac) + args.final_lr_frac
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):

View File

@ -236,10 +236,13 @@ train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
# Learning rate scheduler (1-sqrt warmdown, https://arxiv.org/abs/2405.18392)
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# first 80% of training: no decay, then 1-sqrt ramp down to 0.
if progress < 0.8:
return 1.0
decay_frac = (progress - 0.8) / 0.2
return 1 - decay_frac ** 0.5
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):