Refactor find_last_step to use os.listdir with regex filtering

Replace glob.glob() with os.listdir() + regex filtering as suggested by reviewer.
This filters invalid checkpoint files (like model_000200_backup.pt) at the source
instead of globbing then filtering, making the code simpler and more efficient.
This commit is contained in:
Pyry Takala 2025-11-21 19:21:01 +00:00
parent 01f5f10122
commit 3e2a0668b2

View File

@ -3,7 +3,6 @@ Utilities for saving and loading model/optim/state checkpoints.
""" """
import os import os
import re import re
import glob
import json import json
import logging import logging
import torch import torch
@ -115,19 +114,10 @@ def find_largest_model(checkpoint_dir):
def find_last_step(checkpoint_dir):
    """Return the highest step among the ``model_<step>.pt`` files in a directory.

    Only files whose *entire* name is ``model_<digits>.pt`` count as
    checkpoints. Names such as ``model_000200_backup.pt`` or
    ``old_model_000200.pt`` are ignored.

    Args:
        checkpoint_dir: Directory to scan for checkpoint files.

    Returns:
        The largest step number found, as an ``int``.

    Raises:
        FileNotFoundError: If no valid checkpoint file is present (also raised
            by ``os.listdir`` itself when the directory does not exist).
    """
    # fullmatch anchors both ends, so prefixed or suffixed names are rejected.
    checkpoint_pattern = re.compile(r"model_(\d+)\.pt")
    # Convert to int *before* taking the max: comparing the digit strings would
    # be lexicographic and rank "999999" above "1000000".
    steps = [
        int(match.group(1))
        for match in map(checkpoint_pattern.fullmatch, os.listdir(checkpoint_dir))
        if match
    ]
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    return max(steps)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------