mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-16 01:02:18 +00:00
Fix HIP invalid device ordinal error on multi-GPU setup
The `speedrun.sh` script was hardcoding `NPROC_PER_NODE=8` if any GPU capability was detected, causing crashes on systems with fewer than 8 GPUs. Additionally, `nanochat/common.py` was autodetecting "cuda" even if `torch.cuda.device_count()` was 0 on some ROCm builds, leading to "invalid device ordinal" errors.

Changes:
- `speedrun.sh`: dynamically set `NPROC_PER_NODE` using `torch.cuda.device_count()`.
- `nanochat/common.py`: ensure `autodetect_device_type` only returns "cuda" if devices are actually present.
This commit is contained in:
parent
b92647c580
commit
962deeefb6
|
|
@ -129,7 +129,10 @@ def get_dist_info():
|
||||||
|
|
||||||
def autodetect_device_type():
|
def autodetect_device_type():
|
||||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||||
if torch.cuda.is_available() or (hasattr(torch.version, "hip") and torch.version.hip):
|
# Check for CUDA/ROCm but also ensure we actually have devices (device_count > 0)
|
||||||
|
has_cuda = torch.cuda.is_available()
|
||||||
|
has_rocm = hasattr(torch.version, "hip") and torch.version.hip
|
||||||
|
if (has_cuda or has_rocm) and torch.cuda.device_count() > 0:
|
||||||
device_type = "cuda"
|
device_type = "cuda"
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
device_type = "mps"
|
device_type = "mps"
|
||||||
|
|
|
||||||
10
speedrun.sh
10
speedrun.sh
|
|
@ -110,7 +110,15 @@ wait $DATASET_DOWNLOAD_PID
|
||||||
# Number of processes/GPUs to use
|
# Number of processes/GPUs to use
|
||||||
# Auto-detect if we have GPUs (including ROCm)
|
# Auto-detect if we have GPUs (including ROCm)
|
||||||
if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
|
if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
|
||||||
NPROC_PER_NODE=8
|
# Detected GPU capability. Now get the actual count to avoid launching too many processes.
|
||||||
|
NPROC_PER_NODE=$(python -c "import torch; print(torch.cuda.device_count())")
|
||||||
|
# If for some reason it returns 0 (e.g. weird ROCm state), default to 1 to be safe.
|
||||||
|
if [ "$NPROC_PER_NODE" -eq "0" ]; then
|
||||||
|
echo "GPU detected but torch.cuda.device_count() is 0. Defaulting to NPROC_PER_NODE=1."
|
||||||
|
NPROC_PER_NODE=1
|
||||||
|
else
|
||||||
|
echo "Detected $NPROC_PER_NODE GPUs."
|
||||||
|
fi
|
||||||
else
|
else
|
||||||
echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading."
|
echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading."
|
||||||
NPROC_PER_NODE=1
|
NPROC_PER_NODE=1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user