From 46812c4160ba1e5771c2599c9a0c86501139bf63 Mon Sep 17 00:00:00 2001
From: Arthur Testard
Date: Sat, 2 May 2026 19:07:19 +0200
Subject: [PATCH] Fix smear patches when ckpt loading

---
Review notes (this section is commentary and is ignored by `git am`):

* Purpose: when a checkpoint's model data lacks the smear/backout entries,
  _patch_missing_keys() now inserts default tensors for them and logs each
  insertion via log0(), in the same way it already does for x0_lambdas.
* The gate key is spelled "smear_gate.weight" (singular). A (1, 24) tensor
  has the shape of the weight of a Linear(24, 1), and torch.nn.Linear
  registers that parameter under the name `weight`. This assumes smear_gate
  is an nn.Linear-like module -- TODO confirm against the model definition.
* NOTE(review): backout_lambda is patched to 0.2 rather than a neutral 0.0.
  Presumably this mirrors the model's init value -- confirm that this is the
  intended behaviour for checkpoints that predate the parameter.
* NOTE(review): the width 24 is hard-coded here; presumably it must match the
  gate's input width in the model -- confirm.
* The post-image blob id on the `index` line was not recomputed after the key
  rename; the pre-image id is unchanged.

 nanochat/checkpoint_manager.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index f71524ed..fcd8ae3e 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -38,6 +38,15 @@ def _patch_missing_keys(model_data, model_config):
     if "x0_lambdas" not in model_data:
         model_data["x0_lambdas"] = torch.zeros(n_layer)
         log0(f"Patching missing x0_lambdas in model data to 0.0")
+    if "smear_gate.weight" not in model_data:
+        model_data["smear_gate.weight"] = torch.ones(1, 24)
+        log0(f"Patching missing smear_gate.weight in model data to 1.0")
+    if "smear_lambda" not in model_data:
+        model_data["smear_lambda"] = torch.zeros(1)
+        log0(f"Patching missing smear_lambda in model data to 0.0")
+    if "backout_lambda" not in model_data:
+        model_data["backout_lambda"] = 0.2 * torch.ones(1)
+        log0(f"Patching missing backout_lambda in model data to 0.2")
 
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0: