From 280649f48189cb9928f828976bcf34880a225b14 Mon Sep 17 00:00:00 2001 From: gio Date: Sun, 10 May 2026 13:19:26 -0500 Subject: [PATCH] Add --save-keep-latest=N for checkpoint rotation Pre-save deletion of oldest model_*.pt + optim_*.pt + meta_*.json so disk peak stays bounded. Default -1 = keep all (no behavior change). Used by the d26 1-week dense run to bound live checkpoints across many save-every cycles. Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/base_train.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/scripts/base_train.py b/scripts/base_train.py index 724cb8d3..d3cf587e 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -76,10 +76,30 @@ parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluat parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") +parser.add_argument("--save-keep-latest", type=int, default=-1, help="keep only the N most recent checkpoints in the run dir (deletes older model_*.pt + optim_*.pt + meta_*.json before writing the next save). -1 = keep all.") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging + +def _prune_old_checkpoints(checkpoint_dir, keep_latest, ddp_world_size): + # Delete the oldest model/optim/meta sets so that after the next save we have at most keep_latest on disk. + # Runs pre-save on rank 0 to free disk for the upcoming write. + import glob, re + if keep_latest <= 0: + return + paths = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) + steps = sorted(int(re.search(r"model_(\d+)\.pt$", p).group(1)) for p in paths if re.search(r"model_(\d+)\.pt$", p)) + while len(steps) >= keep_latest: + old = steps.pop(0) + s = f"{old:06d}" + for f in ( + os.path.join(checkpoint_dir, f"model_{s}.pt"), + os.path.join(checkpoint_dir, f"meta_{s}.json"), + *[os.path.join(checkpoint_dir, f"optim_{s}_rank{r:d}.pt") for r in range(ddp_world_size)], + ): + if os.path.exists(f): + os.remove(f) # ----------------------------------------------------------------------------- # Compute init and wandb logging @@ -477,6 +497,8 @@ while True: # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): + if ddp_rank == 0 and args.save_keep_latest > 0: + _prune_old_checkpoints(checkpoint_dir, args.save_keep_latest, ddp_world_size) save_checkpoint( checkpoint_dir, step,