From 4ed203a3ab83bba17e764dfb62e06aca6a545f04 Mon Sep 17 00:00:00 2001
From: Sermet Pekin
Date: Mon, 20 Oct 2025 20:07:14 +0300
Subject: [PATCH] remove transformers from toml. add it to gh Workflow. copy
 common.py from cpu|mps branch to check if gh wf tests are passing

---
 .github/workflows/base.yml |  1 +
 nanochat/common.py         | 48 +++++++++++++++++++++++++++-----------
 pyproject.toml             |  1 -
 3 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index d3f952c..313ee4c 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -36,6 +36,7 @@ jobs:
 
     - name: Install dependencies with uv
       run: |
+        uv pip install transformers>=4.0.0
         uv pip install . --system
 
     - name: Add nanochat to PYTHONPATH
diff --git a/nanochat/common.py b/nanochat/common.py
index 3cbd6b0..49a8b52 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -89,28 +89,50 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
-    """Basic initialization that we keep doing over and over, so make common."""
-    # Check if CUDA is available, otherwise fall back to CPU
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
     if torch.cuda.is_available():
-        device = torch.device("cuda")
-        torch.manual_seed(42)
-        torch.cuda.manual_seed(42)
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
     else:
-        device = torch.device("cpu")
-        torch.manual_seed(42)
-        logger.warning("CUDA is not available. Falling back to CPU.")
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"): # cuda|cpu|mps
+    """Basic initialization that we keep doing over and over, so make common."""
+
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+
+    # Reproducibility
+    torch.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
+    # skipping full reproducibility for now, possibly investigate slowdown later
+    # torch.use_deterministic_algorithms(True)
+
     # Precision
-    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
-    # Distributed setup: Distributed Data Parallel (DDP), optional
+    if device_type == "cuda":
+        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and torch.cuda.is_available():
+    if ddp and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
+    else:
+        device = torch.device(device_type) # mps|cpu
+
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
+
     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
@@ -125,4 +147,4 @@ class DummyWandb:
     def log(self, *args, **kwargs):
         pass
     def finish(self):
-        pass
+        pass
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 578ebd4..ef3833a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,6 @@ dependencies = [
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
     "torch>=2.8.0",
-    "transformers>=4.0.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
 ]
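
For reference, a minimal usage sketch of the revised initialization path. The calling script shown here is hypothetical and not part of the patch; only autodetect_device_type, compute_init, and compute_cleanup come from nanochat/common.py as modified above, and the module is assumed importable as nanochat (per the PYTHONPATH step in the workflow):

    # hypothetical training-script entry point (not part of the patch)
    from nanochat.common import autodetect_device_type, compute_init, compute_cleanup

    device_type = autodetect_device_type()  # returns "cuda" | "mps" | "cpu"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

    # ... build the model and optimizer on `device`, run the training loop ...

    compute_cleanup()  # defined in nanochat/common.py; body not shown in this diff

On CUDA machines with torchrun this takes the DDP + NCCL branch; on a plain CPU or Apple Silicon (MPS) runner, such as the GitHub Actions job this patch targets, compute_init simply returns a torch.device for the detected backend.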