From b399e431681d61dcced768c062b13a9089c0c21c Mon Sep 17 00:00:00 2001
From: "howardgao@outlook.com" <howardgao@outlook.com>
Date: Thu, 6 Nov 2025 08:56:45 +0800
Subject: [PATCH] fix engine test bug

---
 nanochat/engine.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index 916a9cf..da85085 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -17,8 +17,9 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
+from contextlib import nullcontext
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -327,8 +328,11 @@ if __name__ == "__main__":
     import time
     # init compute
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    device_type = autodetect_device_type()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+
     # load the model and tokenizer
-    model, tokenizer, meta = load_model("base", device, phase="eval")
+    model, tokenizer, meta = load_model("sft", device, phase="eval")
     bos_token_id = tokenizer.get_bos_token_id()
     # common hyperparameters
     kwargs = dict(max_tokens=64, temperature=0.0)
@@ -339,10 +343,11 @@ if __name__ == "__main__":
     torch.cuda.synchronize()
     t0 = time.time()
     stream = model.generate(prompt_tokens, **kwargs)
-    for token in stream:
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token in stream:
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
@@ -354,11 +359,12 @@ if __name__ == "__main__":
     stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
     torch.cuda.synchronize()
     t0 = time.time()
-    for token_column, token_masks in stream:
-        token = token_column[0] # only print out the first row
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token_column, token_masks in stream:
+            token = token_column[0] # only print out the first row
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
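
For reference, a minimal standalone sketch of the pattern the patch introduces: build a bfloat16 autocast context when running on CUDA and fall back to a do-nothing context everywhere else, so the same `with` block can wrap the decode loops on any backend. It assumes PyTorch's torch.amp.autocast as used in the diff above; autodetect_device_type belongs to nanochat.common and is stubbed here purely for illustration.

from contextlib import nullcontext

import torch

def autodetect_device_type() -> str:
    # Stand-in for nanochat.common.autodetect_device_type (stubbed for this sketch).
    return "cuda" if torch.cuda.is_available() else "cpu"

device_type = autodetect_device_type()
# bfloat16 autocast on CUDA; a no-op context elsewhere, so the `with`
# statement below stays valid on CPU (or other) backends.
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    if device_type == "cuda"
    else nullcontext()
)

with autocast_ctx:
    # The generation loops from the patch would run here: on CUDA the
    # forward passes execute under bfloat16 autocast; on CPU this block
    # is an ordinary full-precision region.
    pass

Wrapping only the generation loops rather than model loading keeps the checkpoint weights in their original dtype while letting the forward passes downcast on the fly; the nullcontext() fallback avoids forcing bfloat16 on backends where it may be unsupported or slow.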