diff --git a/pyproject.toml b/pyproject.toml index ef3833a..eda6e94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "numpy==1.26.4", "psutil>=7.1.0", "regex>=2025.9.1", + "setuptools>=68", "tiktoken>=0.11.0", "tokenizers>=0.22.0", "torch>=2.8.0", @@ -22,15 +23,38 @@ dependencies = [ requires = ["maturin>=1.7,<2.0"] build-backend = "maturin" -# target torch to cuda 12.8 -[tool.uv.sources] -torch = [ - { index = "pytorch-cu128" }, -] +[project.optional-dependencies] +cpu = [ + "torch>=2.8.0", +] +gpu = [ + "torch>=2.8.0", +] + +[tool.uv] +conflicts = [ + [ + { extra = "cpu" }, + { extra = "gpu" }, + ], +] + -[[tool.uv.index]] -name = "pytorch-cu128" -url = "https://download.pytorch.org/whl/cu128" +# target torch to cuda 12.8 or CPU +[tool.uv.sources] +torch = [ + { index = "pytorch-cpu", extra = "cpu" }, + { index = "pytorch-cu128", extra = "gpu" }, +] + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" explicit = true [tool.maturin] diff --git a/speedrun.sh b/speedrun.sh index d2498ee..d6d812d 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -23,7 +23,9 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh # create a .venv local virtual environment (if it doesn't exist) [ -d ".venv" ] || uv venv # install the repo dependencies -uv sync +# picking the NVIDIA GPU if available, otherwise CPU +choose_extra() { nvidia-smi -L >/dev/null 2>&1 && printf gpu || printf cpu; } +uv sync --extra "$(choose_extra)" # activate venv so that `python` uses the project's venv instead of system python source .venv/bin/activate