From 46812c4160ba1e5771c2599c9a0c86501139bf63 Mon Sep 17 00:00:00 2001
From: Arthur Testard
Date: Sat, 2 May 2026 19:07:19 +0200
Subject: [PATCH 1/2] Fix smear patches when loading checkpoints

---
 nanochat/checkpoint_manager.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index f71524ed..fcd8ae3e 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -38,6 +38,15 @@ def _patch_missing_keys(model_data, model_config):
     if "x0_lambdas" not in model_data:
         model_data["x0_lambdas"] = torch.zeros(n_layer)
         log0(f"Patching missing x0_lambdas in model data to 0.0")
+    if "smear_gate.weights" not in model_data:
+        model_data["smear_gate.weights"] = torch.ones(1, 24)
+        log0(f"Patching missing smear_gate.weights in model data to 1.0")
+    if "smear_lambda" not in model_data:
+        model_data["smear_lambda"] = torch.zeros(1)
+        log0(f"Patching missing smear_lambda in model data to 0.0")
+    if "backout_lambda" not in model_data:
+        model_data["backout_lambda"] = 0.2 * torch.ones(1)
+        log0(f"Patching missing backout_lambda in model data to 0.2")
 
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:

From 17ff2593f35a701a4649825fa58c6ac35a331524 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Tue, 5 May 2026 12:12:28 +0000
Subject: [PATCH 2/2] fix weight and device

---
 nanochat/checkpoint_manager.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index fcd8ae3e..d72de2fe 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -27,25 +27,25 @@ def _patch_missing_config_keys(model_config_kwargs):
         model_config_kwargs["window_pattern"] = "L"
         log0(f"Patching missing window_pattern in model config to 'L'")
-def _patch_missing_keys(model_data, model_config):
+def _patch_missing_keys(model_data, model_config, device):
     """Add default values for new parameters that may be missing in old checkpoints."""
     n_layer = model_config.n_layer
     # resid_lambdas defaults to 1.0 (identity scaling)
     if "resid_lambdas" not in model_data:
-        model_data["resid_lambdas"] = torch.ones(n_layer)
+        model_data["resid_lambdas"] = torch.ones(n_layer).to(device)
         log0(f"Patching missing resid_lambdas in model data to 1.0")
     # x0_lambdas defaults to 0.0 (disabled)
     if "x0_lambdas" not in model_data:
-        model_data["x0_lambdas"] = torch.zeros(n_layer)
+        model_data["x0_lambdas"] = torch.zeros(n_layer).to(device)
         log0(f"Patching missing x0_lambdas in model data to 0.0")
-    if "smear_gate.weights" not in model_data:
-        model_data["smear_gate.weights"] = torch.ones(1, 24)
-        log0(f"Patching missing smear_gate.weights in model data to 1.0")
+    if "smear_gate.weight" not in model_data:
+        model_data["smear_gate.weight"] = torch.ones(1, 24).to(device)
+        log0(f"Patching missing smear_gate.weight in model data to 1.0")
     if "smear_lambda" not in model_data:
-        model_data["smear_lambda"] = torch.zeros(1)
+        model_data["smear_lambda"] = torch.zeros(1).to(device)
         log0(f"Patching missing smear_lambda in model data to 0.0")
     if "backout_lambda" not in model_data:
-        model_data["backout_lambda"] = 0.2 * torch.ones(1)
+        model_data["backout_lambda"] = 0.2 * torch.ones(1).to(device)
         log0(f"Patching missing backout_lambda in model data to 0.2")
 
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
@@ -105,7 +105,7 @@ def build_model(checkpoint_dir, step, device, phase):
     _patch_missing_config_keys(model_config_kwargs)
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
-    _patch_missing_keys(model_data, model_config)
+    _patch_missing_keys(model_data, model_config, device)
     with torch.device("meta"):
         model = GPT(model_config)
     # Load the model state
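
---

A minimal sketch of the behavior the device fix in PATCH 2/2 guarantees, using a
bare dict in place of a real nanochat checkpoint. The key names, shapes, and
default values below come from the diff; the 12-layer size and the standalone
patch loop are illustrative stand-ins, not the repo's actual code path.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# An "old" checkpoint state dict that predates the smear/backout parameters.
old_state = {"resid_lambdas": torch.ones(12, device=device)}

# Fill in the same defaults _patch_missing_keys adds, created on `device`
# (the PATCH 2/2 change) so they match the tensors already in the checkpoint.
defaults = {
    "x0_lambdas": torch.zeros(12),
    "smear_gate.weight": torch.ones(1, 24),
    "smear_lambda": torch.zeros(1),
    "backout_lambda": 0.2 * torch.ones(1),
}
for key, value in defaults.items():
    if key not in old_state:
        old_state[key] = value.to(device)

# All tensors now share one device, so a later load_state_dict call will not
# mix CPU-resident defaults into a GPU-resident model.
assert len({t.device for t in old_state.values()}) == 1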