mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-21 02:44:13 +00:00
renamed find_largest_model() argument checkpoint_dir to checkpoints_dir for clarity
This commit is contained in:
parent
4a87a0d19f
commit
a4a0959c73
|
|
@ -93,11 +93,11 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
def find_largest_model(checkpoint_dir):
|
||||
def find_largest_model(checkpoints_dir):
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
|
||||
model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
|
||||
if not model_tags:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
|
||||
# 1) normally all model tags are of the form d<number>, try that first:
|
||||
candidates = []
|
||||
for model_tag in model_tags:
|
||||
|
|
@ -109,7 +109,7 @@ def find_largest_model(checkpoint_dir):
|
|||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
# 2) if that failed, take the most recently updated model:
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
|
||||
return model_tags[0]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user