diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 63f257f..f689371 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -93,11 +93,11 @@ def build_model(checkpoint_dir, step, device, phase): return model, tokenizer, meta_data -def find_largest_model(checkpoint_dir): +def find_largest_model(checkpoints_dir): # attempt to guess the model tag: take the biggest model available - model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] + model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))] if not model_tags: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}") # 1) normally all model tags are of the form d, try that first: candidates = [] for model_tag in model_tags: @@ -109,7 +109,7 @@ def find_largest_model(checkpoint_dir): candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] # 2) if that failed, take the most recently updated model: - model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True) return model_tags[0]