introduce lr schedulers and tests

Vilhelm Toivonen 2025-10-13 22:13:37 +03:00
parent 47f7ffa25d
commit 6d51049077
3 changed files with 96 additions and 17 deletions

nanochat/schedules.py (new file, 50 additions)

@@ -0,0 +1,50 @@
"""Learning rate schedule utilities."""


def compute_lr_multiplier(
    step: int,
    total_steps: int,
    *,
    warmup_ratio: float = 0.0,
    warmdown_ratio: float = 0.0,
    final_lr_frac: float = 0.0,
) -> float:
    """Compute LR multiplier with linear warmup and warmdown phases.

    The multiplier ramps linearly from 0 -> 1 during warmup, stays at 1, then
    decays linearly to ``final_lr_frac`` during warmdown. Ratios are expressed
    as fractions of ``total_steps``.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    step = min(step, total_steps)
    warmup_steps = int(round(warmup_ratio * total_steps))
    warmdown_steps = int(round(warmdown_ratio * total_steps))
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps
    if warmdown_steps > 0 and step >= total_steps - warmdown_steps:
        progress = (total_steps - step) / max(1, warmdown_steps)
        return progress + (1 - progress) * final_lr_frac
    return 1.0


def apply_lr_multiplier(
    optimizer,
    multiplier: float,
    *,
    base_key: str = "initial_lr",
) -> float:
    """Apply ``multiplier`` to an optimizer in-place using ``base_key`` as base LR."""
    for group in optimizer.param_groups:
        base_lr = group.get(base_key)
        if base_lr is None:
            base_lr = group["lr"]
            group[base_key] = base_lr
        group["lr"] = base_lr * multiplier
    return multiplier
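
For a feel of the resulting schedule, here is a small standalone trace (illustrative only, not part of the commit); it also shows that apply_lr_multiplier caches the base LR under initial_lr on first use, so repeated calls rescale from that base rather than compounding:

import torch
from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier

param = torch.nn.Parameter(torch.zeros(()))
opt = torch.optim.SGD([param], lr=0.5)

# toy run: 100 steps, 10% linear warmup, 20% linear warmdown to 0
for step in (0, 5, 9, 10, 50, 80, 90, 99):
    m = compute_lr_multiplier(step, 100, warmup_ratio=0.1, warmdown_ratio=0.2)
    apply_lr_multiplier(opt, m)
    print(step, round(m, 3), opt.param_groups[0]["lr"])
# multiplier: 0.1 at step 0, ramps to 1.0 by step 9, holds at 1.0 through
# step 80, then decays linearly toward final_lr_frac (0.0 here)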

scripts/base_train.py (24 additions, 17 deletions)
@@ -21,6 +21,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
+from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier
 from scripts.base_eval import evaluate_model

 print_banner()
@@ -142,19 +143,12 @@ x, y = next(train_loader) # kick off load of the very first batch of data
 # Learning rate scheduler
 # TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
-warmup_ratio = 0.0 # ratio of iterations for LR warmup
+adamw_use_lr_warmup = False
+adamw_warmup_ratio = 0.0
+muon_use_lr_warmup = False
+muon_warmup_ratio = 0.0
 warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
 final_lr_frac = 0.0 # final LR is this fraction of the initial LR
-def get_lr_multiplier(it):
-    warmup_iters = round(warmup_ratio * num_iterations)
-    warmdown_iters = round(warmdown_ratio * num_iterations)
-    if it < warmup_iters:
-        return (it + 1) / warmup_iters
-    elif it <= num_iterations - warmdown_iters:
-        return 1.0
-    else:
-        progress = (num_iterations - it) / warmdown_iters
-        return progress * 1.0 + (1 - progress) * final_lr_frac

 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
@@ -265,10 +259,22 @@ for step in range(num_iterations + 1):
     if grad_clip > 0.0:
         torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
     # step the optimizers
-    lrm = get_lr_multiplier(step)
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lrm
+    adamw_lrm = compute_lr_multiplier(
+        step,
+        num_iterations,
+        warmup_ratio=adamw_warmup_ratio if adamw_use_lr_warmup else 0.0,
+        warmdown_ratio=warmdown_ratio,
+        final_lr_frac=final_lr_frac,
+    )
+    muon_lrm = compute_lr_multiplier(
+        step,
+        num_iterations,
+        warmup_ratio=muon_warmup_ratio if muon_use_lr_warmup else 0.0,
+        warmdown_ratio=warmdown_ratio,
+        final_lr_frac=final_lr_frac,
+    )
+    apply_lr_multiplier(adamw_optimizer, adamw_lrm)
+    apply_lr_multiplier(muon_optimizer, muon_lrm)
     muon_momentum = get_muon_momentum(step)
     for group in muon_optimizer.param_groups:
         group["momentum"] = muon_momentum
@@ -290,14 +296,15 @@ for step in range(num_iterations + 1):
     mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
     if step > 10:
         total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
+    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm_adamw: {adamw_lrm:.2f} | lrm_muon: {muon_lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
     if step % 100 == 0:
         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "train/loss": debiased_smooth_loss,
-            "train/lrm": lrm,
+            "train/adamw_lrm": adamw_lrm,
+            "train/muon_lrm": muon_lrm,
             "train/dt": dt,
             "train/tok_per_sec": tok_per_sec,
             "train/mfu": mfu,

tests/test_schedules.py (new file, 22 additions)

@@ -0,0 +1,22 @@
import torch
import pytest

from nanochat.schedules import compute_lr_multiplier, apply_lr_multiplier


def test_compute_lr_multiplier_handles_warmup():
    multiplier = compute_lr_multiplier(0, 100, warmup_ratio=0.1)
    assert multiplier == pytest.approx(0.1)


def test_compute_lr_multiplier_handles_warmdown():
    multiplier = compute_lr_multiplier(95, 100, warmdown_ratio=0.1, final_lr_frac=0.1)
    # progress = (100 - 95) / 10 = 0.5 -> 0.5 + 0.5 * 0.1
    assert multiplier == pytest.approx(0.55)


def test_apply_lr_multiplier_uses_initial_lr():
    param = torch.nn.Parameter(torch.ones(()))
    opt = torch.optim.SGD([param], lr=0.2)
    apply_lr_multiplier(opt, 0.5)
    assert opt.param_groups[0]["lr"] == pytest.approx(0.1)
    assert opt.param_groups[0]["initial_lr"] == pytest.approx(0.2)
    apply_lr_multiplier(opt, 1.0)
    assert opt.param_groups[0]["lr"] == pytest.approx(0.2)
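
One edge the suite does not yet pin down is the multiplier's range; a property-style addition along these lines (hypothetical, not part of the commit) would catch regressions at the phase boundaries:

def test_compute_lr_multiplier_stays_within_bounds():
    # hypothetical extra test: the multiplier never leaves (0, 1] and never
    # drops below final_lr_frac once warmdown starts
    final = 0.1
    for step in range(101):
        m = compute_lr_multiplier(step, 100, warmup_ratio=0.1,
                                  warmdown_ratio=0.2, final_lr_frac=final)
        assert 0.0 < m <= 1.0
        if step >= 80:
            assert m >= final - 1e-12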