From 293609a419314112fd9c3c520906c0cb7189162c Mon Sep 17 00:00:00 2001 From: suspicious-pineapple <90796271+suspicious-pineapple@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:44:50 +0200 Subject: [PATCH 1/2] allowing BF16 with XPU doesn't seem to crash and burn At least it stops sample-every from crashing during pretraining --- nanochat/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index aa2e6a98..31587587 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -183,7 +183,7 @@ class Engine: # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption. # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase. # In particular, the KVCache should allocate its tensors lazily - dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + dtype = torch.bfloat16 if (device.type == "cuda" or device.type=="xpu") else torch.float32 # Including XPU seems to work for bf16 inference rng = torch.Generator(device=device) rng.manual_seed(seed) From f29051afc6ae66fd114218690ac0e31761f01d66 Mon Sep 17 00:00:00 2001 From: suspicious-pineapple <90796271+suspicious-pineapple@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:56:28 +0200 Subject: [PATCH 2/2] support Intel GPU, auto-detect BF16 for XPU --- nanochat/common.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index bd14fd24..ae0674d5 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -27,6 +27,10 @@ def _detect_compute_dtype(): # fp16 training requires GradScaler (not yet implemented), so fall back to fp32. # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing. 
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)" + if torch.xpu.is_available(): + if torch.xpu.is_bf16_supported(): + return torch.bfloat16, "auto-detected: Intel GPU with bf16 support" + return torch.float32, "auto-detected: Intel GPU without bf16 support, using fp32" return torch.float32, "auto-detected: no CUDA (CPU/MPS)" COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() @@ -165,6 +169,8 @@ def autodetect_device_type(): device_type = "cuda" elif torch.backends.mps.is_available(): device_type = "mps" + elif torch.xpu.is_available(): + device_type = "xpu" else: device_type = "cpu" print0(f"Autodetected device type: {device_type}") @@ -173,11 +179,13 @@ def autodetect_device_type(): def compute_init(device_type="cuda"): # cuda|cpu|mps """Basic initialization that we keep doing over and over, so make common.""" - assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" + assert device_type in ["cuda", "mps", "xpu", "cpu"], "Invalid device type atm" if device_type == "cuda": assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" if device_type == "mps": assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" + if device_type == "xpu": + assert torch.xpu.is_available(), "Your PyTorch installation is not configured for XPU but device_type is 'xpu'" # Reproducibility # Note that we set the global seeds here, but most of the code uses explicit rng objects.