Fix torch.dtype mismatching when running engine inline test.

This commit is contained in:
Andrej 2025-11-14 07:28:29 -08:00 committed by GitHub
commit f66a780f68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,8 +17,9 @@ import signal
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from collections import deque from collections import deque
from nanochat.common import compute_init from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Calculator tool helpers # Calculator tool helpers
@ -328,6 +329,9 @@ if __name__ == "__main__":
import time import time
# init compute # init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer # load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval") model, tokenizer, meta = load_model("base", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id() bos_token_id = tokenizer.get_bos_token_id()
@ -340,6 +344,7 @@ if __name__ == "__main__":
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs) stream = model.generate(prompt_tokens, **kwargs)
with autocast_ctx:
for token in stream: for token in stream:
generated_tokens.append(token) generated_tokens.append(token)
chunk = tokenizer.decode([token]) chunk = tokenizer.decode([token])
@ -355,6 +360,7 @@ if __name__ == "__main__":
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
with autocast_ctx:
for token_column, token_masks in stream: for token_column, token_masks in stream:
token = token_column[0] # only print out the first row token = token_column[0] # only print out the first row
generated_tokens.append(token) generated_tokens.append(token)