diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 262ff97..a9327c4 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -72,7 +72,7 @@ def build_model(checkpoint_dir, step, device, phase): for k, v in model_data.items() } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. - model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} + model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs)