This commit is contained in:
suspicious-pineapple 2026-04-14 08:56:09 +08:00 committed by GitHub
commit bdb28520de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 2 deletions

View File

@ -27,6 +27,10 @@ def _detect_compute_dtype():
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
if torch.xpu.is_available():
if torch.xpu.is_bf16_supported():
return torch.bfloat16, "auto-detected: Intel GPU with bf16 support"
return torch.float32, "auto-detected: Intel GPU without bf16 support, using fp32"
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
@ -165,6 +169,8 @@ def autodetect_device_type():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
elif torch.xpu.is_available():
device_type = "xpu"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
@ -173,11 +179,13 @@ def autodetect_device_type():
def compute_init(device_type="cuda"): # cuda|cpu|mps|xpu
"""Basic initialization that we keep doing over and over, so make common."""
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
assert device_type in ["cuda", "mps", "xpu", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
if device_type == "xpu":
assert torch.xpu.is_available(), "Your PyTorch installation is not configured for XPU but device_type is 'xpu'"
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.

View File

@ -183,7 +183,7 @@ class Engine:
# As a quick hack, we're making the generate() function inherit and know about this repo-wide assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
dtype = torch.bfloat16 if (device.type == "cuda" or device.type=="xpu") else torch.float32 # Including XPU seems to work for bf16 inference
rng = torch.Generator(device=device)
rng.manual_seed(seed)