Merge pull request #12 from LokiMetaSmith/fix-cpu-ddp-init

Fix ROCm/APU detection and CPU DDP OOM crash
Lawrence R Kincheloe III authored 2025-11-22 03:19:05 -06:00, committed by GitHub
commit 36df08a5a9
3 changed files with 9 additions and 5 deletions


@@ -129,7 +129,7 @@ def get_dist_info():
 def autodetect_device_type():
     # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or (hasattr(torch.version, "hip") and torch.version.hip):
         device_type = "cuda"
     elif torch.backends.mps.is_available():
         device_type = "mps"
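
Background on the check above: ROCm builds of PyTorch expose HIP devices through the regular torch.cuda API. On those builds torch.version.hip is a version string (it is None on CUDA builds), while torch.cuda.is_available() can still return False on APUs whose gfx target ROCm does not officially support. A condensed sketch of the patched detection logic (the real function assigns device_type and falls through rather than returning):

import torch

def autodetect_device_type():
    # On ROCm builds torch.version.hip is a version string; on CUDA builds it is None.
    is_rocm = hasattr(torch.version, "hip") and bool(torch.version.hip)
    if torch.cuda.is_available() or is_rocm:
        return "cuda"  # HIP devices are addressed as "cuda" in PyTorch
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
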
@@ -143,7 +143,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
     if device_type == "cuda":
-        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+        if not torch.cuda.is_available():
+            if hasattr(torch.version, "hip") and torch.version.hip:
+                print("WARNING: CUDA not available but ROCm detected. Proceeding with 'cuda' device. You might need to set HSA_OVERRIDE_GFX_VERSION if on an APU.")
+            else:
+                assert False, "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
     if device_type == "mps":
         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
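
The new warning path points at HSA_OVERRIDE_GFX_VERSION: on APUs whose gfx target lacks official ROCm support, overriding the reported target to a nearby supported one is the usual workaround, and it must happen before the HIP runtime initializes, i.e. before torch is imported. A sketch, assuming a gfx1103-class APU where "11.0.0" is the commonly used value (the right string depends on the hardware; rocminfo reports the actual gfx target):

import os

# Assumption: gfx1103-class APU (e.g. Radeon 780M). gfx103x parts typically
# use "10.3.0" instead. Must be set before the HIP runtime starts.
os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", "11.0.0")

import torch

print(torch.version.hip)          # HIP version string on ROCm builds
print(torch.cuda.is_available())  # ideally True once the override matches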


@@ -17,7 +17,7 @@ from contextlib import nullcontext
 import wandb
 import torch
-if torch.cuda.is_available():
+if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 from nanochat.gpt import GPT, GPTConfig
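
PYTORCH_CUDA_ALLOC_CONF is read lazily, when the caching allocator first initializes rather than at import time, so setting it just after import torch (as this hunk does) still takes effect as long as no tensor has touched the device yet. Expandable segments let the allocator grow existing memory segments instead of reserving new fixed-size blocks, which reduces fragmentation-driven OOMs; recent ROCm builds are expected to honor the same knob (assumption). A minimal sketch of the ordering constraint:

import os

# Read when the allocator first runs, so setting it here is still in time.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

if torch.cuda.is_available():
    x = torch.empty(1 << 20, device="cuda")  # first allocation initializes the allocator
    print(torch.cuda.memory_stats()["reserved_bytes.all.current"])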


@@ -84,8 +84,8 @@ echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID
 # Number of processes/GPUs to use
-# Auto-detect if we have GPUs
-if python -c "import torch; exit(0) if torch.cuda.is_available() else exit(1)"; then
+# Auto-detect if we have GPUs (including ROCm)
+if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
     NPROC_PER_NODE=8
 else
     echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading."