Merge pull request #20 from LokiMetaSmith/fix-amd-triton-reinstall

Fix HIP invalid device ordinal error on multi-GPU setup
This commit is contained in:
Lawrence R Kincheloe III 2025-11-22 23:34:51 -06:00 committed by GitHub
commit c1fc4400b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 2 deletions

View File

@ -129,7 +129,10 @@ def get_dist_info():
def autodetect_device_type(): def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available() or (hasattr(torch.version, "hip") and torch.version.hip): # Check for CUDA/ROCm but also ensure we actually have devices (device_count > 0)
has_cuda = torch.cuda.is_available()
has_rocm = hasattr(torch.version, "hip") and torch.version.hip
if (has_cuda or has_rocm) and torch.cuda.device_count() > 0:
device_type = "cuda" device_type = "cuda"
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
device_type = "mps" device_type = "mps"

View File

@ -110,7 +110,15 @@ wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use # Number of processes/GPUs to use
# Auto-detect if we have GPUs (including ROCm) # Auto-detect if we have GPUs (including ROCm)
if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
NPROC_PER_NODE=8 # Detected GPU capability. Now get the actual count to avoid launching too many processes.
NPROC_PER_NODE=$(python -c "import torch; print(torch.cuda.device_count())")
# If for some reason it returns 0 (e.g. weird ROCm state), default to 1 to be safe.
if [ "$NPROC_PER_NODE" -eq "0" ]; then
echo "GPU detected but torch.cuda.device_count() is 0. Defaulting to NPROC_PER_NODE=1."
NPROC_PER_NODE=1
else
echo "Detected $NPROC_PER_NODE GPUs."
fi
else else
echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading." echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading."
NPROC_PER_NODE=1 NPROC_PER_NODE=1