Fix AMD Triton conflict in speedrun.sh

Explicitly uninstall `triton` when AMD GPU is detected.
The standard `triton` package (often pulled by NVIDIA dependencies or accident)
conflicts with `pytorch-triton-rocm` on AMD systems, causing
`ImportError: cannot import name 'Config' from 'triton'`.
This change ensures a clean ROCm environment by removing the conflicting package.
Also retains the `uv run --extra $EXTRAS` fix from the previous step.
This commit is contained in:
google-labs-jules[bot] 2025-11-23 03:38:56 +00:00
parent 83bb650b49
commit 8881ea84bf

View File

@ -32,6 +32,11 @@ if command -v nvidia-smi &> /dev/null; then
elif [ -e /dev/kfd ]; then
echo "AMD GPU detected. Installing ROCm dependencies..."
EXTRAS="amd"
# Explicitly uninstall triton if present, as it conflicts with pytorch-triton-rocm
# and can cause "ImportError: cannot import name 'Config' from 'triton'" errors
# if the NVIDIA version of triton (e.g. 3.4.0) is accidentally installed.
source .venv/bin/activate 2>/dev/null || true
uv pip uninstall -q triton || true
else
echo "No dedicated GPU detected. Installing CPU dependencies..."
EXTRAS="cpu"