diff --git a/speedrun.sh b/speedrun.sh index 30a6a3a..1613a0e 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -32,11 +32,6 @@ if command -v nvidia-smi &> /dev/null; then elif [ -e /dev/kfd ]; then echo "AMD GPU detected. Installing ROCm dependencies..." EXTRAS="amd" - # Explicitly uninstall triton if present, as it conflicts with pytorch-triton-rocm - # and can cause "ImportError: cannot import name 'Config' from 'triton'" errors - # if the NVIDIA version of triton (e.g. 3.4.0) is accidentally installed. - source .venv/bin/activate 2>/dev/null || true - uv pip uninstall -q triton || true else echo "No dedicated GPU detected. Installing CPU dependencies..." EXTRAS="cpu" @@ -46,6 +41,13 @@ uv sync --extra $EXTRAS # activate venv so that `python` uses the project's venv instead of system python source .venv/bin/activate +# Explicitly uninstall triton if present, as it conflicts with pytorch-triton-rocm +# and can cause "ImportError: cannot import name 'Config' from 'triton'" errors +# if the NVIDIA version of triton (e.g. 3.4.0) is accidentally installed. +if [ "$EXTRAS" == "amd" ]; then + uv pip uninstall -q triton || true +fi + # ----------------------------------------------------------------------------- # wandb setup # If you wish to use wandb for logging (it's nice!, recommended).