allowing BF16 with XPU doesnt seem to crash and burn

At least it stops sample-every from crashing during pretraining
This commit is contained in:
suspicious-pineapple 2026-03-30 20:44:50 +02:00 committed by GitHub
parent a445144d39
commit 293609a419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -183,7 +183,7 @@ class Engine:
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
dtype = torch.bfloat16 if (device.type == "cuda" or device.type=="xpu") else torch.float32 # Including XPU seems to work for bf16 inference
rng = torch.Generator(device=device)
rng.manual_seed(seed)