Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 04:12:13 +00:00
fix bug in learning rate multiplier: it was ramping up instead of ramping down. See Issue #68 for more detail. Also add a --dry_run option, useful for experimentation.
This commit is contained in:
parent 67aaca98f5
commit b8076dd367
@@ -40,10 +40,10 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
 weight_decay = 0.0
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
 eval_every = 150
 eval_tokens = 20*524288
 total_batch_size = 524288
+dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
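Side note (not part of the diff): the new dry_run default above only becomes a usable --dry_run command-line flag because of the exec of nanochat/configurator.py two lines below it. The snippet that follows is a simplified, hypothetical stand-in for that override mechanism, written only to illustrate how --dry_run=1 flips the module-level default; it is not the repo's actual configurator code.

import sys
from ast import literal_eval

dry_run = 0  # default, as declared in the hunk above

# walk --key=value args and override matching module-level config variables
for arg in sys.argv[1:]:
    if arg.startswith("--") and "=" in arg:
        key, val = arg[2:].split("=", 1)
        if key in globals():
            globals()[key] = literal_eval(val)  # e.g. --dry_run=1 -> dry_run == 1

print("dry_run =", dry_run)  # `python override_sketch.py --dry_run=1` prints: dry_run = 1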
@@ -141,7 +141,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
 
 # Learning rate scheduler
 def get_lr_multiplier(progress):
-    return progress * 1.0 + (1 - progress) * final_lr_frac
+    # first 80% of training: no decay, then linearly ramp down to 0.
+    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
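To see why the old line was "ramping up": with the default final_lr_frac = 0.0 from the config hunk above, the old expression reduces to just progress, so the multiplier grows from 0 at the start of the epoch to 1 at the end. The fixed schedule instead holds at 1.0 for the first 80% of training and then ramps linearly to 0. A small self-contained sketch (not part of the diff) comparing the two:

final_lr_frac = 0.0  # config default, removed by this commit

def get_lr_multiplier_old(progress):
    # buggy: interpolates from final_lr_frac at progress=0 up to 1.0 at progress=1
    return progress * 1.0 + (1 - progress) * final_lr_frac

def get_lr_multiplier_new(progress):
    # fixed: hold at 1.0 for the first 80% of training, then ramp linearly to 0
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

for p in (0.0, 0.4, 0.8, 0.9, 1.0):
    print(f"progress={p:.1f}  old={get_lr_multiplier_old(p):.2f}  new={get_lr_multiplier_new(p):.2f}")
# progress=0.0  old=0.00  new=1.00
# progress=0.4  old=0.40  new=1.00
# progress=0.8  old=0.80  new=1.00
# progress=0.9  old=0.90  new=0.50
# progress=1.0  old=1.00  new=0.00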
@@ -185,7 +186,7 @@ while True:
        model.train()
 
    # save checkpoint at the end of the run (only on master process)
-   if master_process and last_step:
+   if master_process and last_step and not dry_run:
        output_dirname = f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
        save_checkpoint(
@@ -272,17 +273,18 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
 # Log to report
-from nanochat.report import get_report
-get_report().log(section="Midtraining", data=[
-    user_config, # CLI args
-    { # stats about the training setup
-        "Number of iterations": step,
-        "DDP world size": ddp_world_size,
-    },
-    { # stats about training outcomes
-        "Minimum validation bpb": min_val_bpb,
-    }
-])
+if not dry_run:
+    from nanochat.report import get_report
+    get_report().log(section="Midtraining", data=[
+        user_config, # CLI args
+        { # stats about the training setup
+            "Number of iterations": step,
+            "DDP world size": ddp_world_size,
+        },
+        { # stats about training outcomes
+            "Minimum validation bpb": min_val_bpb,
+        }
+    ])
 
 # cleanup
 wandb_run.finish() # wandb run finish