Remove transformers from pyproject.toml and add it to the GitHub workflow instead; copy common.py from the cpu|mps branch to check whether the GitHub workflow tests pass

Sermet Pekin 2025-10-20 20:07:14 +03:00
parent 0768f67290
commit 4ed203a3ab
3 changed files with 36 additions and 14 deletions

GitHub Actions workflow

@@ -36,6 +36,7 @@ jobs:
       - name: Install dependencies with uv
         run: |
+          uv pip install transformers>=4.0.0
           uv pip install . --system
       - name: Add nanochat to PYTHONPATH
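
Since transformers is now installed by the workflow rather than by uv pip install . (see the pyproject.toml change below), local environments that skip that step will not have it. A minimal sketch of how a test that depends on it could degrade gracefully, assuming the suite runs under pytest (the test name below is hypothetical):

    import pytest

    # Skip this whole test module when the optional transformers dependency is absent.
    transformers = pytest.importorskip("transformers", minversion="4.0.0")

    def test_transformers_is_usable():
        # Trivial smoke check that the imported module exposes a familiar entry point.
        assert hasattr(transformers, "AutoTokenizer")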

common.py

@@ -89,28 +89,50 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
-    """Basic initialization that we keep doing over and over, so make common."""
-    # Check if CUDA is available, otherwise fall back to CPU
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
     if torch.cuda.is_available():
-        device = torch.device("cuda")
-        torch.manual_seed(42)
-        torch.cuda.manual_seed(42)
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
     else:
-        device = torch.device("cpu")
-        torch.manual_seed(42)
-        logger.warning("CUDA is not available. Falling back to CPU.")
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"): # cuda|cpu|mps
+    """Basic initialization that we keep doing over and over, so make common."""
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+    # Reproducibility
+    torch.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
+    # skipping full reproducibility for now, possibly investigate slowdown later
+    # torch.use_deterministic_algorithms(True)
     # Precision
-    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
-    # Distributed setup: Distributed Data Parallel (DDP), optional
+    if device_type == "cuda":
+        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and torch.cuda.is_available():
+    if ddp and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
+    else:
+        device = torch.device(device_type) # mps|cpu
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 
 def compute_cleanup():
@@ -125,4 +147,4 @@ class DummyWandb:
     def log(self, *args, **kwargs):
         pass
     def finish(self):
-        pass
+        pass
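
For reference, a minimal sketch of how a training or eval script is expected to drive the two helpers after this change; the nanochat.common import path is assumed from the repository layout, and the model/training code itself is elided:

    from nanochat.common import autodetect_device_type, compute_init, compute_cleanup

    device_type = autodetect_device_type()  # "cuda", "mps", or "cpu"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # ... construct the model and run training/evaluation on `device` ...
    compute_cleanup()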

pyproject.toml

@@ -14,7 +14,6 @@ dependencies = [
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
     "torch>=2.8.0",
-    "transformers>=4.0.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
 ]
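
Because transformers is no longer pulled in by pip install ., any code path in the package that still imports it will now fail for users who install from pyproject.toml alone. One way to keep that failure actionable is a guarded import along these lines (the error wording is illustrative):

    try:
        import transformers  # optional dependency since this commit; installed only in CI
    except ImportError as err:
        raise ImportError(
            "transformers is not installed; it was removed from the package dependencies. "
            "Install it manually with: pip install 'transformers>=4.0.0'"
        ) from err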