diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 63f257f..530e6d3 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -3,7 +3,6 @@ Utilities for saving and loading model/optim/state checkpoints. """ import os import re -import glob import json import logging import torch @@ -115,10 +114,10 @@ def find_largest_model(checkpoint_dir): def find_last_step(checkpoint_dir): # Look into checkpoint_dir and find model_.pt with the highest step - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if re.search(r'model_(\d+)\.pt$', f)] if not checkpoint_files: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) + last_step = max(int(f.split("_")[-1].split(".")[0]) for f in checkpoint_files) return last_step # -----------------------------------------------------------------------------