Add Torch CPU package index support

Makes torch index an extra in pyproject.toml,  speedrun.sh selects GPU index if supported, with CPU fallback
This commit is contained in:
Luke Stanley 2025-10-14 00:32:50 +00:00
parent 8aca98777a
commit e27d2da3d9
2 changed files with 35 additions and 9 deletions

View File

@ -11,6 +11,7 @@ dependencies = [
"numpy==1.26.4",
"psutil>=7.1.0",
"regex>=2025.9.1",
"setuptools>=68",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
@ -22,15 +23,38 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
[project.optional-dependencies]
cpu = [
"torch>=2.8.0",
]
gpu = [
"torch>=2.8.0",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
# target torch to cuda 12.8 or CPU
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "gpu" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.maturin]

View File

@ -23,7 +23,9 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync
# picking the NVIDIA GPU if available, otherwise CPU
choose_extra() { nvidia-smi -L >/dev/null 2>&1 && printf gpu || printf cpu; }
uv sync --extra "$(choose_extra)"
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate