diff --git a/scripts/base_train.py b/scripts/base_train.py index 724cb8d3..d3cf587e 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -76,10 +76,30 @@ parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluat parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") +parser.add_argument("--save-keep-latest", type=int, default=-1, help="keep only the N most recent checkpoints in the run dir (deletes older model_*.pt + optim_*.pt + meta_*.json before writing the next save). -1 = keep all.") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging + +def _prune_old_checkpoints(checkpoint_dir, keep_latest, ddp_world_size): + # Delete the oldest model/optim/meta sets so that after the next save we have at most keep_latest on disk. + # Runs pre-save on rank 0 to free disk for the upcoming write. + import glob, re + if keep_latest <= 0: + return + paths = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) + steps = sorted(int(re.search(r"model_(\d+)\.pt$", p).group(1)) for p in paths if re.search(r"model_(\d+)\.pt$", p)) + while len(steps) >= keep_latest: + old = steps.pop(0) + s = f"{old:06d}" + for f in ( + os.path.join(checkpoint_dir, f"model_{s}.pt"), + os.path.join(checkpoint_dir, f"meta_{s}.json"), + *[os.path.join(checkpoint_dir, f"optim_{s}_rank{r:d}.pt") for r in range(ddp_world_size)], + ): + if os.path.exists(f): + os.remove(f) # ----------------------------------------------------------------------------- # Compute init and wandb logging @@ -477,6 +497,8 @@ while True: # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): + if ddp_rank == 0 and args.save_keep_latest > 0: + _prune_old_checkpoints(checkpoint_dir, args.save_keep_latest, ddp_world_size) save_checkpoint( checkpoint_dir, step,