diff --git a/nanochat/common.py b/nanochat/common.py
index 9090097..62da9e4 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -129,7 +129,7 @@ def get_dist_info():
 
 def autodetect_device_type():
     # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or (hasattr(torch.version, "hip") and torch.version.hip):
         device_type = "cuda"
     elif torch.backends.mps.is_available():
         device_type = "mps"
@@ -143,7 +143,11 @@
 
 def compute_init(device_type="cuda"): # cuda|cpu|mps
     assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
     if device_type == "cuda":
-        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+        if not torch.cuda.is_available():
+            if hasattr(torch.version, "hip") and torch.version.hip:
+                print("WARNING: CUDA not available but ROCm detected. Proceeding with 'cuda' device. You might need to set HSA_OVERRIDE_GFX_VERSION if on an APU.")
+            else:
+                assert False, "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
     if device_type == "mps":
         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 65c7d32..74068ed 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -17,7 +17,7 @@
 from contextlib import nullcontext
 
 import wandb
 import torch
-if torch.cuda.is_available():
+if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 from nanochat.gpt import GPT, GPTConfig
diff --git a/speedrun.sh b/speedrun.sh
index de35fdb..4029413 100644
--- a/speedrun.sh
+++ b/speedrun.sh
@@ -84,8 +84,8 @@
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID
 # Number of processes/GPUs to use
-# Auto-detect if we have GPUs
-if python -c "import torch; exit(0) if torch.cuda.is_available() else exit(1)"; then
+# Auto-detect if we have GPUs (including ROCm)
+if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
     NPROC_PER_NODE=8
 else
     echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM and using multi-threading."
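
For reference, a minimal standalone sketch of the ROCm detection logic this patch repeats in each file. The helper names `is_rocm_build` and `has_gpu` are illustrative only and are not part of the patch:

```python
import torch

def is_rocm_build() -> bool:
    # On ROCm builds of PyTorch, torch.version.hip holds the HIP/ROCm version
    # string; on CUDA/CPU builds it is None. The hasattr guard covers old
    # builds where torch.version has no 'hip' attribute at all.
    return hasattr(torch.version, "hip") and bool(torch.version.hip)

def has_gpu() -> bool:
    # ROCm exposes GPUs through the regular "cuda" device type, but
    # torch.cuda.is_available() can return False on some APUs/iGPUs until
    # HSA_OVERRIDE_GFX_VERSION is set, hence the extra ROCm check.
    return torch.cuda.is_available() or is_rocm_build()

if __name__ == "__main__":
    print(f"ROCm build: {is_rocm_build()}; GPU usable: {has_gpu()}")
```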