Add error handling for missing checkpoint files in load_checkpoint function

Enhance the load_checkpoint function to raise FileNotFoundError for missing model, optimizer, and metadata files, improving robustness and user feedback.
This commit is contained in:
Mert Cobanov 2025-10-15 13:47:41 +03:00
parent 67aaca98f5
commit 04722913b3

View File

@@ -42,14 +42,22 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    """Load the model state, optional optimizer state, and metadata for a step.

    The three files are looked up inside ``checkpoint_dir`` under the names
    ``model_{step:06d}.pt``, ``optim_{step:06d}.pt`` and
    ``meta_{step:06d}.json``.

    Args:
        checkpoint_dir: Directory that contains the checkpoint files.
        step: Integer step number; zero-padded to six digits in the file names.
        device: Forwarded to ``torch.load`` as ``map_location``.
        load_optimizer: When true, the optimizer file is required and loaded;
            otherwise it is neither checked nor read.

    Returns:
        A ``(model_data, optimizer_data, meta_data)`` tuple. ``optimizer_data``
        is ``None`` when ``load_optimizer`` is false. ``meta_data`` is the
        parsed JSON content of the metadata file.

    Raises:
        FileNotFoundError: If the model file, the optimizer file (only when
            ``load_optimizer`` is true), or the metadata file does not exist.
    """
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    optimizer_path = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")

    # Validate every required file up front so an incomplete checkpoint is
    # reported before any (potentially large) state is deserialized.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
    if optimizer_path is not None and not os.path.exists(optimizer_path):
        raise FileNotFoundError(f"Optimizer checkpoint not found at {optimizer_path}")
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Metadata file not found at {meta_path}")

    # NOTE(review): torch.load is pickle-based; only load checkpoints that come
    # from a trusted source.
    model_data = torch.load(model_path, map_location=device)

    # Load the optimizer state if requested
    optimizer_data = None
    if optimizer_path is not None:
        optimizer_data = torch.load(optimizer_path, map_location=device)

    # JSON is decoded explicitly as UTF-8 rather than with the locale default.
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)

    return model_data, optimizer_data, meta_data