mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-05 15:15:48 +00:00
Merge 23acb17f17 into 2f09686724
This commit is contained in:
commit
0c73e53aaa
|
|
@ -344,7 +344,8 @@ print0(f"Total number of training tokens: {total_tokens:,}")
|
|||
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
# Learning rate schedule (linear warmup, constant, 1-sqrt warmdown)
|
||||
# 1-sqrt cooldown shape from https://arxiv.org/abs/2405.18392
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
|
|
@ -353,8 +354,9 @@ def get_lr_multiplier(it):
|
|||
elif it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
else:
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
decay_frac = 1 - (num_iterations - it) / warmdown_iters
|
||||
lr_mult = 1 - decay_frac ** 0.5
|
||||
return lr_mult * (1 - args.final_lr_frac) + args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
|
||||
def get_muon_momentum(it):
|
||||
|
|
|
|||
|
|
@ -236,10 +236,13 @@ train_loader = sft_data_generator_bos_bestfit("train")
|
|||
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
# Learning rate scheduler (1-sqrt warmdown, https://arxiv.org/abs/2405.18392)
|
||||
def get_lr_multiplier(progress):
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
|
||||
# first 80% of training: no decay, then 1-sqrt ramp down to 0.
|
||||
if progress < 0.8:
|
||||
return 1.0
|
||||
decay_frac = (progress - 0.8) / 0.2
|
||||
return 1 - decay_frac ** 0.5
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user