mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Add error handling for missing checkpoint files in load_checkpoint function
Enhance the load_checkpoint function to raise FileNotFoundError for missing model, optimizer, and metadata files, improving robustness and user feedback.
This commit is contained in:
parent
67aaca98f5
commit
04722913b3
|
|
@ -42,14 +42,22 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
|
||||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||||
# Load the model state
|
# Load the model state
|
||||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
|
||||||
model_data = torch.load(model_path, map_location=device)
|
model_data = torch.load(model_path, map_location=device)
|
||||||
|
|
||||||
# Load the optimizer state if requested
|
# Load the optimizer state if requested
|
||||||
optimizer_data = None
|
optimizer_data = None
|
||||||
if load_optimizer:
|
if load_optimizer:
|
||||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||||
|
if not os.path.exists(optimizer_path):
|
||||||
|
raise FileNotFoundError(f"Optimizer checkpoint not found at {optimizer_path}")
|
||||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||||
|
|
||||||
# Load the metadata
|
# Load the metadata
|
||||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||||
|
if not os.path.exists(meta_path):
|
||||||
|
raise FileNotFoundError(f"Metadata file not found at {meta_path}")
|
||||||
with open(meta_path, "r") as f:
|
with open(meta_path, "r") as f:
|
||||||
meta_data = json.load(f)
|
meta_data = json.load(f)
|
||||||
return model_data, optimizer_data, meta_data
|
return model_data, optimizer_data, meta_data
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user