mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-31 00:55:18 +00:00
gh wf multi platform
This commit is contained in:
parent
116b0e75fc
commit
8d75e112b6
37
.github/workflows/base.yml
vendored
37
.github/workflows/base.yml
vendored
|
|
@ -12,7 +12,12 @@ on:
|
|||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
python-version: ['3.10', '3.11']
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
|
@ -21,40 +26,26 @@ jobs:
|
|||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install uv
|
||||
|
||||
- name: Create virtual environment with uv
|
||||
run: |
|
||||
uv venv .venv
|
||||
|
||||
- name: Activate virtual environment
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
|
||||
- name: Install dependencies with uv
|
||||
run: |
|
||||
uv pip install . --system
|
||||
uv pip install .
|
||||
|
||||
- name: Add nanochat to PYTHONPATH
|
||||
- name: Set PYTHONPATH (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
echo "PYTHONPATH=$(pwd):$PYTHONPATH" >> $GITHUB_ENV
|
||||
|
||||
- name: Install pytest
|
||||
- name: Set PYTHONPATH (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
python -m pip install pytest
|
||||
echo "PYTHONPATH=$PWD;$env:PYTHONPATH" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Run pytest
|
||||
run: |
|
||||
python -m pytest tests/ --maxfail=5 --disable-warnings
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
PYTHONPATH=$PYTHONPATH uv run pytest tests/ --maxfail=5 --disable-warnings
|
||||
|
|
|
|||
23
README.md
23
README.md
|
|
@ -109,6 +109,29 @@ This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` fo
|
|||
|
||||
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions about this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
|
||||
|
||||
## Installation
|
||||
|
||||
To install nanochat for development or experimentation:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/karpathy/nanochat.git
|
||||
cd nanochat
|
||||
|
||||
# Install dependencies (requires uv)
|
||||
uv pip install .
|
||||
```
|
||||
|
||||
**For GPU users:** The default installation includes CPU-only PyTorch. If you want GPU acceleration, install the appropriate PyTorch build once the base install has finished:
|
||||
|
||||
```bash
|
||||
# For CUDA (Linux/Windows)
|
||||
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For ROCm (AMD GPUs)
|
||||
uv pip install torch --index-url https://download.pytorch.org/whl/rocm6.0
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
I haven't invested too much here, but some tests exist, especially for the tokenizer. Run them e.g. as:
|
||||
|
|
|
|||
|
|
@ -89,28 +89,50 @@ def get_dist_info():
|
|||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
def compute_init():
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
# Check if CUDA is available, otherwise fall back to CPU
|
||||
def autodetect_device_type():
|
||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
device_type = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device_type = "mps"
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
torch.manual_seed(42)
|
||||
logger.warning("CUDA is not available. Falling back to CPU.")
|
||||
device_type = "cpu"
|
||||
print0(f"Autodetected device type: {device_type}")
|
||||
return device_type
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
|
||||
# Precision
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp and torch.cuda.is_available():
|
||||
if ddp and device_type == "cuda":
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
||||
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
|
||||
|
||||
def compute_cleanup():
|
||||
|
|
@ -125,4 +147,4 @@ class DummyWandb:
|
|||
def log(self, *args, **kwargs):
|
||||
pass
|
||||
def finish(self):
|
||||
pass
|
||||
pass
|
||||
|
|
@ -14,7 +14,6 @@ dependencies = [
|
|||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.0.0",
|
||||
"transformers>=4.0.0",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
]
|
||||
|
|
@ -23,17 +22,6 @@ dependencies = [
|
|||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
# target torch to cpu/mps for macOS
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user