mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Fix torch.dtype mismatching when running engine inline test.
This commit is contained in:
commit
f66a780f68
|
|
@ -17,8 +17,9 @@ import signal
|
|||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from collections import deque
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from contextlib import nullcontext
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculator tool helpers
|
||||
|
|
@ -328,6 +329,9 @@ if __name__ == "__main__":
|
|||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
bos_token_id = tokenizer.get_bos_token_id()
|
||||
|
|
@ -340,6 +344,7 @@ if __name__ == "__main__":
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
stream = model.generate(prompt_tokens, **kwargs)
|
||||
with autocast_ctx:
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
|
|
@ -355,6 +360,7 @@ if __name__ == "__main__":
|
|||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user