diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index f71524ed..d72de2fe 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -27,17 +27,26 @@ def _patch_missing_config_keys(model_config_kwargs):
         model_config_kwargs["window_pattern"] = "L"
         log0(f"Patching missing window_pattern in model config to 'L'")
 
-def _patch_missing_keys(model_data, model_config):
+def _patch_missing_keys(model_data, model_config, device):
     """Add default values for new parameters that may be missing in old checkpoints."""
     n_layer = model_config.n_layer
     # resid_lambdas defaults to 1.0 (identity scaling)
     if "resid_lambdas" not in model_data:
-        model_data["resid_lambdas"] = torch.ones(n_layer)
+        model_data["resid_lambdas"] = torch.ones(n_layer).to(device)
         log0(f"Patching missing resid_lambdas in model data to 1.0")
     # x0_lambdas defaults to 0.0 (disabled)
     if "x0_lambdas" not in model_data:
-        model_data["x0_lambdas"] = torch.zeros(n_layer)
+        model_data["x0_lambdas"] = torch.zeros(n_layer).to(device)
         log0(f"Patching missing x0_lambdas in model data to 0.0")
+    if "smear_gate.weight" not in model_data:
+        model_data["smear_gate.weight"] = torch.ones(1, 24).to(device)
+        log0(f"Patching missing smear_gate.weight in model data to 1.0")
+    if "smear_lambda" not in model_data:
+        model_data["smear_lambda"] = torch.zeros(1).to(device)
+        log0(f"Patching missing smear_lambda in model data to 0.0")
+    if "backout_lambda" not in model_data:
+        model_data["backout_lambda"] = 0.2 * torch.ones(1).to(device)
+        log0(f"Patching missing backout_lambda in model data to 0.2")
 
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
@@ -96,7 +105,7 @@ def build_model(checkpoint_dir, step, device, phase):
     _patch_missing_config_keys(model_config_kwargs)
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
-    _patch_missing_keys(model_data, model_config)
+    _patch_missing_keys(model_data, model_config, device)
     with torch.device("meta"):
         model = GPT(model_config)
         # Load the model state