diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..31587587 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -183,6 +183,6 @@ class Engine:
         # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
         # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
         # In particular, the KVCache should allocate its tensors lazily
-        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        dtype = torch.bfloat16 if device.type in ("cuda", "xpu") else torch.float32  # bf16 inference also seems to work on XPU
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)