fix engine test bug

This commit is contained in:
howardgao@outlook.com 2025-11-06 08:56:45 +08:00
parent c6b7ab7440
commit b399e43168

View File

@ -17,8 +17,9 @@ import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -327,8 +328,11 @@ if __name__ == "__main__":
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval")
model, tokenizer, meta = load_model("sft", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id()
# common hyperparameters
kwargs = dict(max_tokens=64, temperature=0.0)
@ -339,10 +343,11 @@ if __name__ == "__main__":
torch.cuda.synchronize()
t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs)
for token in stream:
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
with autocast_ctx:
for token in stream:
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
t1 = time.time()
@ -354,11 +359,12 @@ if __name__ == "__main__":
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
t0 = time.time()
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
with autocast_ctx:
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
t1 = time.time()