This commit is contained in:
Mert Cobanov 2025-11-04 12:40:33 +01:00 committed by GitHub
commit a37fd2d37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,14 +42,22 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
model_data = torch.load(model_path, map_location=device)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
if not os.path.exists(optimizer_path):
raise FileNotFoundError(f"Optimizer checkpoint not found at {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
if not os.path.exists(meta_path):
raise FileNotFoundError(f"Metadata file not found at {meta_path}")
with open(meta_path, "r") as f:
meta_data = json.load(f)
return model_data, optimizer_data, meta_data