diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index a9327c4..57bf78a 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -42,14 +42,22 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
 def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
     model_data = torch.load(model_path, map_location=device)
+
     # Load the optimizer state if requested
     optimizer_data = None
     if load_optimizer:
         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        if not os.path.exists(optimizer_path):
+            raise FileNotFoundError(f"Optimizer checkpoint not found at {optimizer_path}")
         optimizer_data = torch.load(optimizer_path, map_location=device)
+
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+    if not os.path.exists(meta_path):
+        raise FileNotFoundError(f"Metadata file not found at {meta_path}")
     with open(meta_path, "r") as f:
         meta_data = json.load(f)
     return model_data, optimizer_data, meta_data
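For context, a minimal usage sketch of the patched loader is below. The checkpoint directory, step value, and the try/except handling are illustrative assumptions, not part of this change; the change itself only adds the explicit existence checks so a missing file fails fast with a descriptive FileNotFoundError rather than whatever torch.load or open would raise.

```python
# Usage sketch (assumed paths and step value; not part of this diff).
from nanochat.checkpoint_manager import load_checkpoint

checkpoint_dir = "checkpoints/run_001"  # hypothetical directory
try:
    model_data, optimizer_data, meta_data = load_checkpoint(
        checkpoint_dir, step=500, device="cpu", load_optimizer=True
    )
except FileNotFoundError as e:
    # e.g. "Model checkpoint not found at checkpoints/run_001/model_000500.pt"
    print(f"Checkpoint incomplete or missing: {e}")
```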