Add --save-keep-latest=N for checkpoint rotation

Pre-save deletion of oldest model_*.pt + optim_*.pt + meta_*.json so disk
peak stays bounded. Default -1 = keep all (no behavior change).

Used by the d26 1-week dense run to bound live checkpoints across many
save-every cycles.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
gio 2026-05-10 13:19:26 -05:00
parent 1e7810ddaa
commit 280649f481

View File

@ -76,10 +76,30 @@ parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluat
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
# Checkpoint cadence: the training loop saves every N steps and also at the last step of the run.
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Checkpoint rotation: when > 0, rank 0 calls _prune_old_checkpoints right before each save,
# so older checkpoints are removed *before* the new one is written.
parser.add_argument("--save-keep-latest", type=int, default=-1, help="keep only the N most recent checkpoints in the run dir (deletes older model_*.pt + optim_*.pt + meta_*.json before writing the next save). -1 = keep all.")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
# Shallow copy of the parsed CLI arguments as a plain dict.
user_config = vars(args).copy() # for logging
def _prune_old_checkpoints(checkpoint_dir, keep_latest, ddp_world_size):
    """Delete the oldest checkpoint sets so the next save leaves at most ``keep_latest``.

    A checkpoint set is ``model_<step>.pt`` plus its ``meta_<step>.json`` and every
    ``optim_<step>_rank<r>.pt`` shard. This is meant to run *before* the next
    checkpoint is written, so it trims the directory down to ``keep_latest - 1``
    sets to make room for the one about to be saved.

    Args:
        checkpoint_dir: Directory holding the checkpoint files. A missing
            directory is treated as "nothing to prune".
        keep_latest: Maximum number of checkpoint sets that should exist after
            the upcoming save. Values <= 0 disable pruning.
        ddp_world_size: Accepted for call compatibility. Optimizer shards are
            discovered on disk rather than enumerated from this value, so
            shards written by a run with a different world size are removed too.
    """
    import glob
    import re

    if keep_latest <= 0:
        return

    # Escape the directory so glob metacharacters in the path are taken literally.
    base = glob.escape(checkpoint_dir)
    model_re = re.compile(r"model_(\d+)\.pt")
    shard_re = re.compile(r"optim_\d+_rank\d+\.pt")

    # Keep each step token exactly as it appears on disk, so the sibling file
    # names are rebuilt correctly whatever zero-padding was used when saving.
    tokens = []
    for path in glob.glob(os.path.join(base, "model_*.pt")):
        match = model_re.fullmatch(os.path.basename(path))
        if match:
            tokens.append(match.group(1))
    tokens.sort(key=int)

    # Leave room for the checkpoint that is about to be written.
    num_to_drop = len(tokens) - keep_latest + 1
    if num_to_drop <= 0:
        return

    for token in tokens[:num_to_drop]:
        shards = [
            path
            for path in glob.glob(os.path.join(base, f"optim_{token}_rank*.pt"))
            if shard_re.fullmatch(os.path.basename(path))
        ]
        # The model file goes last: discovery keys off model_*.pt, so if pruning
        # is interrupted the leftovers of this set are found again next time.
        doomed = [
            *shards,
            os.path.join(checkpoint_dir, f"meta_{token}.json"),
            os.path.join(checkpoint_dir, f"model_{token}.pt"),
        ]
        for path in doomed:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone (never written, or removed concurrently).
                pass
# -----------------------------------------------------------------------------
# Compute init and wandb logging
@ -477,6 +497,8 @@ while True:
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
if ddp_rank == 0 and args.save_keep_latest > 0:
_prune_old_checkpoints(checkpoint_dir, args.save_keep_latest, ddp_world_size)
save_checkpoint(
checkpoint_dir,
step,