mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Merge f29051afc6 into b9b6ce137b
This commit is contained in:
commit
bdb28520de
|
|
@ -27,6 +27,10 @@ def _detect_compute_dtype():
|
|||
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
||||
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
||||
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
||||
if torch.xpu.is_available():
|
||||
if torch.xpu.is_bf16_supported():
|
||||
return torch.bfloat16, "auto-detected: Intel GPU with bf16 support"
|
||||
return torch.float32, "auto-detected: Intel GPU without bf16 support, using fp32"
|
||||
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
||||
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
||||
|
||||
|
|
@ -165,6 +169,8 @@ def autodetect_device_type():
|
|||
device_type = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device_type = "mps"
|
||||
elif torch.xpu.is_available():
|
||||
device_type = "xpu"
|
||||
else:
|
||||
device_type = "cpu"
|
||||
print0(f"Autodetected device type: {device_type}")
|
||||
|
|
@ -173,11 +179,13 @@ def autodetect_device_type():
|
|||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
assert device_type in ["cuda", "mps", "xpu", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
if device_type == "xpu":
|
||||
assert torch.xpu.is_available(), "Your PyTorch installation is not configured for XPU but device_type is 'xpu'"
|
||||
|
||||
# Reproducibility
|
||||
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ class Engine:
|
|||
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
|
||||
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
|
||||
# In particular, the KVCache should allocate its tensors lazily
|
||||
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||
dtype = torch.bfloat16 if (device.type == "cuda" or device.type=="xpu") else torch.float32 # Including XPU seems to work for bf16 inference
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user