mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-13 15:03:42 +00:00
This change adds support for ROCm and makes the codebase device-agnostic, allowing it to run on different hardware backends, including ROCm, CUDA, and CPU.

The key changes are:

- Modified `pyproject.toml` to use ROCm-compatible PyTorch wheels and added the `pytorch-triton-rocm` dependency.
- Refactored `nanochat/common.py` to dynamically detect the available hardware and set the device and distributed backend accordingly.
- Updated all training, evaluation, and inference scripts to be device-agnostic, removing hardcoded CUDA references.
- Adapted `speedrun.sh` for single-device execution by replacing `torchrun` with `python`.
- Updated `nanochat/report.py` to provide more generic GPU information.
60 lines
1.2 KiB
TOML
60 lines
1.2 KiB
TOML
[project]
# Package identity and metadata.
name = "nanochat"
version = "0.1.0"
description = "the minimal full-stack ChatGPT clone"
readme = "README.md"
requires-python = ">=3.10"

# Runtime requirements. numpy and pytorch-triton-rocm are exact pins; every
# other entry is a lower bound.
dependencies = [
    "datasets>=4.0.0",
    "fastapi>=0.117.1",
    "files-to-prompt>=0.6",
    "numpy==1.26.4",
    "psutil>=7.1.0",
    "regex>=2025.9.1",
    "tiktoken>=0.11.0",
    "tokenizers>=0.22.0",
    "torch>=2.8.0",
    # The environment marker limits this requirement to Linux on x86_64.
    "pytorch-triton-rocm==3.4.0; platform_machine == 'x86_64' and sys_platform == 'linux'",
    "uvicorn>=0.36.0",
    "wandb>=0.21.3",
]
|
|
|
|
[build-system]
# The package is built with maturin; the requirement is capped below 2.0.
build-backend = "maturin"
requires = [
    "maturin>=1.7,<2.0",
]
|
|
|
|
# Target torch at the ROCm 6.3 wheel index, but only on Linux x86_64 (the same
# environment marker the pytorch-triton-rocm requirement in
# [project].dependencies carries). Without a marker, every platform is forced
# onto the ROCm index; with it, other platforms resolve torch from the default
# index instead.
[tool.uv.sources]
torch = [
    { index = "pytorch-rocm63", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
]
pytorch-triton-rocm = [
    { index = "pytorch-rocm63", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
]
|
|
|
|
[[tool.uv.index]]
# Index that the [tool.uv.sources] entries refer to by name.
url = "https://download.pytorch.org/whl/rocm6.3"
name = "pytorch-rocm63"
# NOTE(review): per uv's index documentation, an explicit index should only be
# used for packages that name it in [tool.uv.sources] — confirm.
explicit = true
|
|
|
|
[tool.maturin]
# Rust side: the crate manifest under rustbpe/ and the binding flavour.
manifest-path = "rustbpe/Cargo.toml"
bindings = "pyo3"
# Python side: name of the built module and the Python source root.
module-name = "rustbpe"
python-source = "."
|
|
|
|
[project.optional-dependencies]
# "dev" extra: build and test tooling.
dev = [
    "maturin>=1.9.4",
    "pytest>=8.0.0",
]
|
|
|
|
[tool.pytest.ini_options]
# Discovery: the test directory and the name patterns for files, classes and
# functions.
testpaths = [
    "tests",
]
python_files = [
    "test_*.py",
]
python_classes = [
    "Test*",
]
python_functions = [
    "test_*",
]

# Custom markers declared for this project.
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
|