diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index 99f260e..79ba998 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -20,6 +20,16 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
+def _patch_missing_keys(model_data, model_config):
+    """Add default values for new parameters that may be missing in old checkpoints."""
+    n_layer = model_config.n_layer
+    # resid_lambdas defaults to 1.0 (identity scaling)
+    if "resid_lambdas" not in model_data:
+        model_data["resid_lambdas"] = torch.ones(n_layer)
+    # x0_lambdas defaults to 0.0 (disabled)
+    if "x0_lambdas" not in model_data:
+        model_data["x0_lambdas"] = torch.zeros(n_layer)
+
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
         os.makedirs(checkpoint_dir, exist_ok=True)
@@ -76,6 +86,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model_config_kwargs = meta_data["model_config"]
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
+    _patch_missing_keys(model_data, model_config)
     with torch.device("meta"):
         model = GPT(model_config)
     # Load the model state
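
For illustration, here is a minimal, self-contained sketch of what the back-compat patch does to an old state dict. `SimpleNamespace` stands in for the real `GPTConfig` (only `n_layer` matters for this helper), and the `"transformer.wte.weight"` key is a hypothetical placeholder for a pre-existing checkpoint entry:

```python
import torch
from types import SimpleNamespace

# Stand-in for GPTConfig; only n_layer is used by the patch (illustrative, not the real class).
config = SimpleNamespace(n_layer=12)

# A state dict from an old checkpoint that predates the per-layer lambda parameters.
# The weight key here is hypothetical.
old_state = {"transformer.wte.weight": torch.randn(8, 4)}

# Mirror of the patch logic from the diff above.
if "resid_lambdas" not in old_state:
    old_state["resid_lambdas"] = torch.ones(config.n_layer)   # 1.0 = identity residual scaling
if "x0_lambdas" not in old_state:
    old_state["x0_lambdas"] = torch.zeros(config.n_layer)     # 0.0 = x0 branch disabled

# After patching, the state dict matches the new model's expected keys.
assert old_state["resid_lambdas"].shape == (config.n_layer,)
assert torch.all(old_state["x0_lambdas"] == 0)
```

The defaults are chosen so that a patched old checkpoint behaves exactly as it did before the new parameters existed: a residual scale of 1.0 leaves the residual stream unchanged, and an x0 scale of 0.0 keeps the new branch inert.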