commit d6e1154a0b
Author: Arthur Testard
Date:   2026-05-05 14:12:35 +02:00 (committed by GitHub)


@@ -27,17 +27,26 @@ def _patch_missing_config_keys(model_config_kwargs):
         model_config_kwargs["window_pattern"] = "L"
         log0(f"Patching missing window_pattern in model config to 'L'")

-def _patch_missing_keys(model_data, model_config):
+def _patch_missing_keys(model_data, model_config, device):
     """Add default values for new parameters that may be missing in old checkpoints."""
     n_layer = model_config.n_layer
     # resid_lambdas defaults to 1.0 (identity scaling)
     if "resid_lambdas" not in model_data:
-        model_data["resid_lambdas"] = torch.ones(n_layer)
+        model_data["resid_lambdas"] = torch.ones(n_layer).to(device)
         log0(f"Patching missing resid_lambdas in model data to 1.0")
     # x0_lambdas defaults to 0.0 (disabled)
     if "x0_lambdas" not in model_data:
-        model_data["x0_lambdas"] = torch.zeros(n_layer)
+        model_data["x0_lambdas"] = torch.zeros(n_layer).to(device)
         log0(f"Patching missing x0_lambdas in model data to 0.0")
+    if "smear_gate.weight" not in model_data:
+        model_data["smear_gate.weight"] = torch.ones(1, 24).to(device)
+        log0(f"Patching missing smear_gate.weight in model data to 1.0")
+    if "smear_lambda" not in model_data:
+        model_data["smear_lambda"] = torch.zeros(1).to(device)
+        log0(f"Patching missing smear_lambda in model data to 0.0")
+    if "backout_lambda" not in model_data:
+        model_data["backout_lambda"] = 0.2 * torch.ones(1).to(device)
+        log0(f"Patching missing backout_lambda in model data to 0.2")

 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
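
For reference, a minimal sketch of the patching logic introduced above (the helper name, the n_layer value, and the example state dict are illustrative; only the default values and the (1, 24) shape for smear_gate.weight are taken from the diff). Placing each default on device keeps the patched entries on the same device as the tensors already in the checkpoint:

import torch

def patch_missing_keys_sketch(model_data, n_layer, device):
    # Defaults mirror the values patched in by this commit.
    defaults = {
        "resid_lambdas": torch.ones(n_layer),      # identity residual scaling
        "x0_lambdas": torch.zeros(n_layer),        # x0 path disabled
        "smear_gate.weight": torch.ones(1, 24),    # gate shape taken from the diff
        "smear_lambda": torch.zeros(1),            # smearing disabled
        "backout_lambda": 0.2 * torch.ones(1),     # default backout strength
    }
    for key, default in defaults.items():
        if key not in model_data:
            # .to(device) keeps the patched tensor on the same device as the
            # rest of the state dict, avoiding device mismatches at load time.
            model_data[key] = default.to(device)
    return model_data

# Illustrative usage: an old checkpoint that is missing the newer keys.
old_state = {"resid_lambdas": torch.ones(12)}
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_missing_keys_sketch(old_state, n_layer=12, device=device)
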
@@ -96,7 +105,7 @@ def build_model(checkpoint_dir, step, device, phase):
     _patch_missing_config_keys(model_config_kwargs)
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
-    _patch_missing_keys(model_data, model_config)
+    _patch_missing_keys(model_data, model_config, device)
     with torch.device("meta"):
         model = GPT(model_config)
     # Load the model state
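
The hunk ends just before the loading step, so the following is only a sketch of why the device argument matters here: a model built under torch.device("meta") has parameters with no storage, and one common way to materialize it is load_state_dict(..., assign=True), which adopts the checkpoint tensors directly, so those tensors must already sit on the intended device. The TinyModel class and the assign=True call are assumptions for illustration, not code from this repository:

import torch
import torch.nn as nn

# Stand-in model; the real GPT/GPTConfig classes are not shown in this diff.
class TinyModel(nn.Module):
    def __init__(self, n_layer):
        super().__init__()
        self.resid_lambdas = nn.Parameter(torch.empty(n_layer))

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device("meta"):
    model = TinyModel(n_layer=12)  # parameters allocated on the meta device

# With assign=True, the checkpoint tensors become the model's parameters,
# which is why they are placed on the target device before loading.
state = {"resid_lambdas": torch.ones(12).to(device)}
model.load_state_dict(state, assign=True)
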