mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-07 00:09:50 +00:00
Merge 17ff2593f3 into dc54a1a307
This commit is contained in:
commit
d6e1154a0b
|
|
@ -27,17 +27,26 @@ def _patch_missing_config_keys(model_config_kwargs):
|
|||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
def _patch_missing_keys(model_data, model_config, device):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
n_layer = model_config.n_layer
|
||||
# resid_lambdas defaults to 1.0 (identity scaling)
|
||||
if "resid_lambdas" not in model_data:
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer)
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer).to(device)
|
||||
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
||||
# x0_lambdas defaults to 0.0 (disabled)
|
||||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer).to(device)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
if "smear_gate.weight" not in model_data:
|
||||
model_data["smear_gate.weight"] = torch.ones(1, 24).to(device)
|
||||
log0(f"Patching missing smear_gate.weight in model data to 1.0")
|
||||
if "smear_lambda" not in model_data:
|
||||
model_data["smear_lambda"] = torch.zeros(1).to(device)
|
||||
log0(f"Patching missing smear_lambda in model data to 0.0")
|
||||
if "backout_lambda" not in model_data:
|
||||
model_data["backout_lambda"] = 0.2 * torch.ones(1).to(device)
|
||||
log0(f"Patching missing backout_lambda in model data to 0.2")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
|
|
@ -96,7 +105,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
_patch_missing_config_keys(model_config_kwargs)
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
_patch_missing_keys(model_data, model_config, device)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user