renamed find_largest_model() argument checkpoint_dir to checkpoints_dir for clarity

This commit is contained in:
Eric Silberstein 2025-11-19 15:33:36 -05:00
parent 4a87a0d19f
commit a4a0959c73

View File

@ -93,11 +93,11 @@ def build_model(checkpoint_dir, step, device, phase):
return model, tokenizer, meta_data
def find_largest_model(checkpoint_dir):
def find_largest_model(checkpoints_dir):
# attempt to guess the model tag: take the biggest model available
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
if not model_tags:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
# 1) normally all model tags are of the form d<number>, try that first:
candidates = []
for model_tag in model_tags:
@ -109,7 +109,7 @@ def find_largest_model(checkpoint_dir):
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates[0][1]
# 2) if that failed, take the most recently updated model:
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
return model_tags[0]