diff --git a/pyproject.toml b/pyproject.toml index 8d2c8f0..864e055 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,15 @@ dependencies = [ "wandb>=0.21.3", ] +[project.optional-dependencies] +# Optional groups to control PyTorch source selection +cpu = [ + "torch>=2.8.0", +] +cuda = [ + "torch>=2.8.0", +] + [build-system] requires = ["maturin>=1.7,<2.0"] build-backend = "maturin" @@ -34,6 +43,31 @@ dev = [ "maturin>=1.9.4", "pytest>=8.0.0", ] +cuda = [ + "cuda", # refers to the above optional dependency group +] + +[tool.uv] +default-groups = ["cuda"] + +[tool.uv.sources] +torch = [ + { index = "pytorch-cpu", marker = "platform_system == 'Darwin'"}, + { index = "pytorch-cpu", extra = "cpu" }, + { index = "pytorch-cu128", extra = "cuda"}, +] + +# CPU-only index +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +# CUDA 12.8 index +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true [tool.pytest.ini_options] markers = [