From b8076dd367a6ba8378f1e7d32afb545b30fe15f8 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Wed, 15 Oct 2025 16:35:04 +0000
Subject: [PATCH] fix bug in learning rate multiplier, it was ramping up instead of ramping down. see more in Issue #68. also add --dry_run option useful for experimentation

---
 scripts/mid_train.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 202682d..90ab954 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -40,10 +40,10 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
 weight_decay = 0.0
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
 eval_every = 150
 eval_tokens = 20*524288
 total_batch_size = 524288
+dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@@ -141,7 +141,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
 
 # Learning rate scheduler
 def get_lr_multiplier(progress):
-    return progress * 1.0 + (1 - progress) * final_lr_frac
+    # first 80% of training: no decay, then linearly ramp down to 0.
+    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
 
 # Momentum scheduler for Muon optimizer
 def get_muon_momentum(it):
@@ -185,7 +186,7 @@ while True:
         model.train()
 
     # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step:
+    if master_process and last_step and not dry_run:
         output_dirname = f"d{depth}" # e.g. d12
         checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
         save_checkpoint(
@@ -272,17 +273,18 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
 # Log to report
-from nanochat.report import get_report
-get_report().log(section="Midtraining", data=[
-    user_config, # CLI args
-    { # stats about the training setup
-        "Number of iterations": step,
-        "DDP world size": ddp_world_size,
-    },
-    { # stats about training outcomes
-        "Minimum validation bpb": min_val_bpb,
-    }
-])
+if not dry_run:
+    from nanochat.report import get_report
+    get_report().log(section="Midtraining", data=[
+        user_config, # CLI args
+        { # stats about the training setup
+            "Number of iterations": step,
+            "DDP world size": ddp_world_size,
+        },
+        { # stats about training outcomes
+            "Minimum validation bpb": min_val_bpb,
+        }
+    ])
 
 # cleanup
 wandb_run.finish() # wandb run finish