mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Add error handling for missing checkpoint files in load_checkpoint function
Enhance the load_checkpoint function to raise FileNotFoundError for missing model, optimizer, and metadata files, improving robustness and user feedback.
This commit is contained in:
parent
67aaca98f5
commit
04722913b3
|
|
@ -42,14 +42,22 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
|
|||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
if not os.path.exists(optimizer_path):
|
||||
raise FileNotFoundError(f"Optimizer checkpoint not found at {optimizer_path}")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
if not os.path.exists(meta_path):
|
||||
raise FileNotFoundError(f"Metadata file not found at {meta_path}")
|
||||
with open(meta_path, "r") as f:
|
||||
meta_data = json.load(f)
|
||||
return model_data, optimizer_data, meta_data
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user