diff --git a/nanochat/engine.py b/nanochat/engine.py
index da85085..295d889 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -332,7 +332,7 @@ if __name__ == "__main__":
     autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 
     # load the model and tokenizer
-    model, tokenizer, meta = load_model("sft", device, phase="eval")
+    model, tokenizer, meta = load_model("base", device, phase="eval")
     bos_token_id = tokenizer.get_bos_token_id()
     # common hyperparameters
     kwargs = dict(max_tokens=64, temperature=0.0)