mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 22:25:27 +00:00
chore: clarify LR warmup/warmdown schedule in base_train
This commit is contained in:
parent
0fde31156c
commit
9a2e40eff0
|
|
@ -347,27 +347,33 @@ print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num
|
|||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
def get_lr_multiplier(it):
|
||||
# Note: optimizer steps run for it in [0, num_iterations-1]
|
||||
def get_lr_multiplier(it: int) -> float:
|
||||
# Note: optimizer steps run for it in [0, num_iterations - 1]
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
|
||||
# Warmup (avoid division by zero when warmup_iters == 0)
|
||||
if warmup_iters > 0 and it < warmup_iters:
|
||||
# Clamp to sane ranges
|
||||
warmup_iters = max(0, min(warmup_iters, num_iterations))
|
||||
warmdown_iters = max(0, min(warmdown_iters, num_iterations))
|
||||
|
||||
# Warmup: linear ramp from (1/warmup_iters) .. 1.0 over warmup_iters steps
|
||||
if it < warmup_iters:
|
||||
# safe: if warmup_iters == 0 this branch is unreachable
|
||||
return (it + 1) / warmup_iters
|
||||
|
||||
# Warmdown should cover the last `warmdown_iters` optimizer steps:
|
||||
# Warmdown: apply over the last warmdown_iters optimizer steps:
|
||||
# it in [num_iterations - warmdown_iters, num_iterations - 1]
|
||||
if warmdown_iters > 0:
|
||||
warmdown_start = num_iterations - warmdown_iters
|
||||
# Ensure warmdown doesn't start before warmup ends (prevents overlap weirdness)
|
||||
|
||||
# If warmup overlaps warmdown, start warmdown only after warmup is done
|
||||
warmdown_start = max(warmdown_start, warmup_iters)
|
||||
|
||||
if it >= warmdown_start:
|
||||
# progress: 1.0 at warmdown_start, 0.0 at last optimizer step (num_iterations - 1)
|
||||
span = max(1, (num_iterations - 1) - warmdown_start) # denom >= 1
|
||||
progress = (num_iterations - 1 - it) / span
|
||||
return progress * 1.0 + (1.0 - progress) * args.final_lr_frac
|
||||
# progress goes 1.0 -> 1.0/warmdown_iters (NOT to 0 by default)
|
||||
# This matches the original behavior where final step doesn't hit 0
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return args.final_lr_frac + (1.0 - args.final_lr_frac) * progress
|
||||
|
||||
return 1.0
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user