From 447567634c89e1e183249feaa613c546343780b0 Mon Sep 17 00:00:00 2001 From: Kirk Lin Date: Tue, 14 Oct 2025 12:08:44 +0800 Subject: [PATCH 1/4] feat: cross-platform support for CPU and GPU environments --- .gitignore | 1 + nanochat/common.py | 34 ++++++++++++++++++++++++---- nanochat/dataloader.py | 9 ++++---- nanochat/engine.py | 12 ++++++---- pyproject.toml | 26 ++++++++++++++++++++-- scripts/base_eval.py | 4 ++-- scripts/base_loss.py | 4 ++-- scripts/base_train.py | 17 ++++++++------ scripts/chat_cli.py | 4 ++-- scripts/chat_eval.py | 4 ++-- scripts/chat_rl.py | 4 ++-- scripts/chat_sft.py | 4 ++-- scripts/chat_web.py | 4 ++-- scripts/mid_train.py | 13 ++++++----- speedrun.sh | 50 ++++++++++++++++++++++++++++++++---------- 15 files changed, 139 insertions(+), 51 deletions(-) diff --git a/.gitignore b/.gitignore index b14ecde..69c2ee3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__/ *.pyc rustbpe/target/ dev-ignore/ +.idea diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..d48350f 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -79,6 +79,18 @@ def is_ddp(): # TODO is there a proper way return int(os.environ.get('RANK', -1)) != -1 +def is_macos(): + """Check if running on macOS.""" + import platform + return platform.system() == "Darwin" + +def get_device_type(): + """Get the device type string for autocast: 'cuda' or 'cpu'.""" + # Use CPU if on macOS or if CUDA is not available + if is_macos() or not torch.cuda.is_available(): + return "cpu" + return "cuda" + def get_dist_info(): if is_ddp(): assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) @@ -92,12 +104,14 @@ def get_dist_info(): def compute_init(): """Basic initialization that we keep doing over and over, so make common.""" - # CUDA is currently required - assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm" + # Check if CUDA is available + has_cuda = torch.cuda.is_available() + on_macos = is_macos() # Reproducibility torch.manual_seed(42) - torch.cuda.manual_seed(42) + if has_cuda: + torch.cuda.manual_seed(42) # skipping full reproducibility for now, possibly investigate slowdown later # torch.use_deterministic_algorithms(True) # torch.backends.cudnn.deterministic = True @@ -108,13 +122,25 @@ def compute_init(): # Distributed setup: Distributed Data Parallel (DDP), optional ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - if ddp: + + # Determine device + if on_macos or not has_cuda: + device = torch.device("cpu") + if on_macos: + logger.info("Running on macOS with CPU") + else: + logger.info("Running on CPU (CUDA not available)") + if ddp: + logger.warning("DDP requested but will run on CPU") + elif ddp: device = torch.device("cuda", ddp_local_rank) torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() + logger.info(f"Running on CUDA with DDP (rank {ddp_rank}/{ddp_world_size})") else: device = torch.device("cuda") + logger.info("Running on CUDA (single GPU)") if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index c1636b1..1201ec7 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -2,7 +2,7 @@ from collections import deque import torch -from nanochat.common import get_dist_info +from nanochat.common import get_dist_info, get_device_type from nanochat.dataset import parquets_iter_batched from nanochat.tokenizer import get_tokenizer 
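For orientation while reading the script diffs that follow: the helpers added to nanochat/common.py above are meant to compose as below. A minimal sketch under the assumptions of this patch (the tensor math is illustrative only, not code from the patch):

```python
# Reviewer illustration: how compute_init() and get_device_type() from the
# hunk above are expected to compose in a training/eval script.
import torch
from nanochat.common import compute_init, get_device_type

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
# autocast wants the device *type* string ("cuda" or "cpu"), not a torch.device
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)

x = torch.randn(4, 8, device=device)
with autocast_ctx:
    y = x @ x.T  # matmuls run in bfloat16 under autocast on both CUDA and CPU
```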
@@ -43,7 +43,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz # Create the inputs/targets as 1D tensors inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] - # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True) + # Reshape to 2D and move to device async + device_type = get_device_type() + inputs = inputs_cpu.view(B, T).to(device=device_type, dtype=torch.int32, non_blocking=True) + targets = targets_cpu.view(B, T).to(device=device_type, dtype=torch.int64, non_blocking=True) yield inputs, targets diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..b2d9cdf 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -308,7 +308,8 @@ if __name__ == "__main__": prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id) # generate the reference sequence using the model.generate() function generated_tokens = [] - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) for token in stream: @@ -316,7 +317,8 @@ if __name__ == "__main__": chunk = tokenizer.decode([token]) print(chunk, end="", flush=True) print() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.time() print(f"Reference time: {t1 - t0:.2f}s") reference_ids = generated_tokens @@ -324,7 +326,8 @@ if __name__ == "__main__": generated_tokens = [] engine = Engine(model, tokenizer) stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t0 = time.time() for token_column, token_masks in stream: token = token_column[0] # only print out the first row @@ -332,7 +335,8 @@ if __name__ == "__main__": chunk = tokenizer.decode([token]) print(chunk, end="", flush=True) print() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.time() print(f"Engine time: {t1 - t0:.2f}s") # compare the two sequences diff --git a/pyproject.toml b/pyproject.toml index ef3833a..2782d85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,34 @@ dependencies = [ requires = ["maturin>=1.7,<2.0"] build-backend = "maturin" -# target torch to cuda 12.8 +# target torch to cuda 12.8 or CPU +[project.optional-dependencies] +cpu = [ + "torch>=2.8.0", +] +gpu = [ + "torch>=2.8.0", +] + +[tool.uv] +conflicts = [ + [ + { extra = "cpu" }, + { extra = "gpu" }, + ], +] + [tool.uv.sources] torch = [ - { index = "pytorch-cu128" }, + { index = "pytorch-cpu", extra = "cpu" }, + { index = "pytorch-cu128", extra = "gpu" }, ] +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + [[tool.uv.index]] name = "pytorch-cu128" url = "https://download.pytorch.org/whl/cu128" diff --git a/scripts/base_eval.py b/scripts/base_eval.py index a566d49..73e82f2 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -19,7 +19,7 @@ import yaml import pandas as pd import torch -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import 
evaluate_task @@ -122,7 +122,7 @@ def main(): # distributed / precision setup ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) # Load model and tokenizer from command line or from file system if len(sys.argv) >= 2: diff --git a/scripts/base_loss.py b/scripts/base_loss.py index ba3876d..b75ee58 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss import os import torch from nanochat.checkpoint_manager import load_model -from nanochat.common import compute_init, print0, compute_cleanup +from nanochat.common import compute_init, print0, compute_cleanup, get_device_type from nanochat.dataloader import tokenizing_distributed_data_loader from nanochat.tokenizer import get_token_bytes from nanochat.loss_eval import evaluate_bpb @@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really # Set up the precision we'll run with -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) # Evaluate the loss on each split tokens_per_step = device_batch_size * sequence_len * ddp_world_size diff --git a/scripts/base_train.py b/scripts/base_train.py index b691ed4..ead1c09 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -16,7 +16,7 @@ import torch from nanochat.gpt import GPT, GPTConfig from nanochat.dataloader import tokenizing_distributed_data_loader -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin # Compute init ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process @@ -96,7 +96,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) -model.to_empty(device="cuda") +model.to_empty(device=device) model.init_weights() orig_model = model # original, uncompiled model, for saving raw model state_dict model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through @@ -252,7 +252,8 @@ for step in range(num_iterations + 1): # ------------------------------------------------------------------------- # single training step # evaluate the gradient - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): with autocast_ctx: @@ -275,7 +276,8 @@ for step in range(num_iterations + 1): for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.time() dt = t1 - t0 # ------------------------------------------------------------------------- @@ -304,7 +306,8 @@ for step in range(num_iterations + 1): }) # print a few more stats -print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB") +peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0 +print0(f"Peak memory usage: {peak_mem:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") @@ -330,7 +333,7 @@ get_report().log(section="Base model training", data=[ "MFU %": f"{mfu:.2f}%", "Total training flops": f"{flops_so_far:e}", "Total training time": f"{total_training_time/60:.2f}m", - "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB", + "Peak memory usage": f"{peak_mem:.2f}MiB", } ]) diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index 3a38147..e90d084 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid """ import argparse import torch -from nanochat.common import compute_init +from nanochat.common import compute_init, get_device_type from nanochat.engine import Engine from nanochat.checkpoint_manager import load_model @@ -21,7 +21,7 @@ args = parser.parse_args() # Init the model and tokenizer ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) # Special tokens for the chat state machine diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index df6a01a..1e80655 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -14,7 +14,7 @@ from functools import partial import torch import torch.distributed as dist -from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0 +from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, get_device_type from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -195,7 +195,7 @@ if __name__ == "__main__": 
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) + autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=ptdtype) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index af70bda..45c4daa 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -23,7 +23,7 @@ import wandb import torch import torch.distributed as dist -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, get_device_type from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.engine import Engine from tasks.gsm8k import GSM8K @@ -57,7 +57,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 8389deb..f022e97 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -17,7 +17,7 @@ import wandb import torch import torch.distributed as dist -from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb +from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, get_device_type from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import save_checkpoint from nanochat.engine import Engine @@ -63,7 +63,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 1a4cfe2..412ccd6 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from pydantic import BaseModel from typing import List, Optional, AsyncGenerator -from nanochat.common import compute_init +from nanochat.common import compute_init, get_device_type from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th args = parser.parse_args() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) class ChatMessage(BaseModel): role: str diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 202682d..07797c9 
100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -16,7 +16,7 @@ import time import wandb import torch -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, get_device_type from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -53,7 +53,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process @@ -214,7 +214,8 @@ while True: # ------------------------------------------------------------------------- # single training step # evaluate the gradient - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): with autocast_ctx: @@ -235,7 +236,8 @@ while True: for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() t1 = time.time() dt = t1 - t0 # ------------------------------------------------------------------------- @@ -267,7 +269,8 @@ while True: }) # print a few more stats -print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB") +peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0 +print0(f"Peak memory usage: {peak_mem:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") diff --git a/speedrun.sh b/speedrun.sh index d2498ee..a01d200 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -22,8 +22,19 @@ mkdir -p $NANOCHAT_BASE_DIR command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh # create a .venv local virtual environment (if it doesn't exist) [ -d ".venv" ] || uv venv -# install the repo dependencies -uv sync + +# Detect hardware and install appropriate PyTorch version +if [[ "$OSTYPE" == "darwin"* ]]; then + echo "Detected macOS - installing CPU version of PyTorch" + uv sync --extra cpu +elif command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then + echo "Detected NVIDIA GPU(s) - installing CUDA version of PyTorch" + uv sync --extra gpu +else + echo "No GPU detected - installing CPU version of PyTorch" + uv sync --extra cpu +fi + # activate venv so that `python` uses the project's venv instead of system python source .venv/bin/activate @@ -70,6 +81,23 @@ python -m scripts.tok_train --max_chars=2000000000 # evaluate the tokenizer (report compression ratio etc.) 
python -m scripts.tok_eval +# ----------------------------------------------------------------------------- +# Platform detection for compute configuration + +# Check if running on macOS +if [[ "$OSTYPE" == "darwin"* ]]; then + echo "Detected macOS - running in CPU mode (single process)" + TORCHRUN_CMD="python" +# Check if CUDA/GPUs are available +elif command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then + GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + echo "Detected $GPU_COUNT GPU(s) - running in GPU mode" + TORCHRUN_CMD="torchrun --standalone --nproc_per_node=$GPU_COUNT" +else + echo "No GPUs detected - running in CPU mode (single process)" + TORCHRUN_CMD="python" +fi + # ----------------------------------------------------------------------------- # Base model (pretraining) @@ -92,25 +120,25 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # pretrain the d20 model -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.base_train -- --depth=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples -torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +$TORCHRUN_CMD -m scripts.base_loss # evaluate the model on CORE tasks -torchrun --standalone --nproc_per_node=8 -m scripts.base_eval +$TORCHRUN_CMD -m scripts.base_eval # ----------------------------------------------------------------------------- # Midtraining (teach the model conversation special tokens, tool use, multiple choice) # run midtraining and eval the model -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid +$TORCHRUN_CMD -m scripts.mid_train -- --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.chat_eval -- -i mid # ----------------------------------------------------------------------------- # Supervised Finetuning (domain adaptation to each sequence all by itself per row) # train sft and re-eval right away (should see a small bump) -torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft +$TORCHRUN_CMD -m scripts.chat_sft -- --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.chat_eval -- -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" 
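A note on the TORCHRUN_CMD fallback above: plain `python` can stand in for `torchrun` because nanochat probes the torchrun environment variables and degrades to a single process when they are absent. A sketch of that probe, following `is_ddp()` from nanochat/common.py shown earlier; the exact fallback tuple is an assumption read from the surrounding code, not spelled out in this diff:

```python
# torchrun exports RANK/LOCAL_RANK/WORLD_SIZE; a bare `python` invocation
# does not, so the dist-info probe falls back to a single-process world.
import os

def is_ddp() -> bool:
    return int(os.environ.get("RANK", -1)) != -1

def get_dist_info():
    if is_ddp():
        return True, int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
    return False, 0, 0, 1  # assumed fallback: ddp, rank, local_rank, world_size
```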
@@ -123,9 +151,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft # (optional) # run reinforcement learning -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN +# $TORCHRUN_CMD -m scripts.chat_rl -- --run=$WANDB_RUN # eval the RL model only on GSM8K -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K +# $TORCHRUN_CMD -m scripts.chat_eval -- -i rl -a GSM8K # ----------------------------------------------------------------------------- # Generate the full report by putting together all the sections From 662ff7eb7ada2cb7ac64b71dc2493bd2a8ed50a5 Mon Sep 17 00:00:00 2001 From: Kirk Lin Date: Tue, 14 Oct 2025 12:22:57 +0800 Subject: [PATCH 2/4] feat: dynamic dtype selection --- nanochat/common.py | 7 +++++++ nanochat/gpt.py | 15 ++++++++++----- scripts/base_eval.py | 4 ++-- scripts/base_loss.py | 4 ++-- scripts/base_train.py | 4 ++-- scripts/chat_cli.py | 4 ++-- scripts/chat_web.py | 4 ++-- 7 files changed, 27 insertions(+), 15 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index d48350f..b3a717d 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -91,6 +91,13 @@ def get_device_type(): return "cpu" return "cuda" +def get_default_dtype(): + """Get the default dtype for training: bfloat16 on GPU, float32 on CPU.""" + # bfloat16 is well-supported on modern GPUs but may have issues on CPU + if torch.cuda.is_available(): + return torch.bfloat16 + return torch.float32 + def get_dist_info(): if is_ddp(): assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5a066b2..649c362 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from nanochat.common import get_dist_info, print0 +from nanochat.common import get_dist_info, print0, get_default_dtype from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW @@ -169,8 +169,9 @@ class GPT(nn.Module): cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations - self.transformer.wte.to(dtype=torch.bfloat16) + # Cast the embeddings to the default dtype: optim can tolerate it and it saves memory: both in the model and the activations + default_dtype = get_default_dtype() + self.transformer.wte.to(dtype=default_dtype) def init_weights(self): self.apply(self._init_weights) @@ -210,7 +211,9 @@ class GPT(nn.Module): # calculate the rotation frequencies at each (time, channel) pair freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() - cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 + # keep them in the default dtype (bfloat16 on GPU, float32 on CPU) + default_dtype = get_default_dtype() + cos, sin = cos.to(default_dtype), sin.to(default_dtype) cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin @@ -262,7 +265,9 @@ class GPT(nn.Module): # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == 
self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" + # Rotary embeddings should match the default dtype for the platform + expected_dtype = get_default_dtype() + assert self.cos.dtype == expected_dtype, f"Rotary embeddings must be in {expected_dtype}, but got {self.cos.dtype}" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 73e82f2..ef4d064 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -19,7 +19,7 @@ import yaml import pandas as pd import torch -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type, get_default_dtype from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task @@ -122,7 +122,7 @@ def main(): # distributed / precision setup ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() - autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) + autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype()) # Load model and tokenizer from command line or from file system if len(sys.argv) >= 2: diff --git a/scripts/base_loss.py b/scripts/base_loss.py index b75ee58..e954688 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss import os import torch from nanochat.checkpoint_manager import load_model -from nanochat.common import compute_init, print0, compute_cleanup, get_device_type +from nanochat.common import compute_init, print0, compute_cleanup, get_device_type, get_default_dtype from nanochat.dataloader import tokenizing_distributed_data_loader from nanochat.tokenizer import get_token_bytes from nanochat.loss_eval import evaluate_bpb @@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really # Set up the precision we'll run with -autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype()) # Evaluate the loss on each split tokens_per_step = device_batch_size * sequence_len * ddp_world_size diff --git a/scripts/base_train.py b/scripts/base_train.py index ead1c09..4875701 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -16,7 +16,7 @@ import torch from nanochat.gpt import GPT, GPTConfig from nanochat.dataloader import tokenizing_distributed_data_loader -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type, get_default_dtype from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in 
config_keys} # will be useful for loggin # Compute init ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. -autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype()) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index e90d084..2f8f9d7 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid """ import argparse import torch -from nanochat.common import compute_init, get_device_type +from nanochat.common import compute_init, get_device_type, get_default_dtype from nanochat.engine import Engine from nanochat.checkpoint_manager import load_model @@ -21,7 +21,7 @@ args = parser.parse_args() # Init the model and tokenizer ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype()) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) # Special tokens for the chat state machine diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 412ccd6..ea120eb 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from pydantic import BaseModel from typing import List, Optional, AsyncGenerator -from nanochat.common import compute_init, get_device_type +from nanochat.common import compute_init, get_device_type, get_default_dtype from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th args = parser.parse_args() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype()) class ChatMessage(BaseModel): role: str From 1c5dd2b7bae41ef784afd2b6971fae425a5baf34 Mon Sep 17 00:00:00 2001 From: Kirk Lin Date: Wed, 15 Oct 2025 11:04:57 +0800 Subject: [PATCH 3/4] fix: get safe autocast dtype --- nanochat/common.py | 47 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index b3a717d..7493f99 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -91,12 +91,49 @@ def get_device_type(): return "cpu" return "cuda" +def get_safe_autocast_dtype(device_type: str, preferred_dtype=None) -> torch.dtype: + """ + Return a safe dtype for autocast on the given device. 
+ + Args: + device_type: "cuda" or "cpu" + preferred_dtype: Preferred dtype (torch.dtype, str, or None) + + Returns: + A dtype that is safe for autocast on the device + """ + # Parse the preferred dtype + if isinstance(preferred_dtype, torch.dtype): + dtype = preferred_dtype + elif isinstance(preferred_dtype, str): + dtype_map = { + "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, + "float16": torch.float16, "fp16": torch.float16, "half": torch.float16, + "float32": torch.float32, "fp32": torch.float32, + } + dtype = dtype_map.get(preferred_dtype.lower()) + if dtype is None: + raise ValueError(f"Unknown dtype string: {preferred_dtype}") + elif preferred_dtype is None: + # Default: bfloat16 on CUDA, bfloat16 on CPU (both support it) + dtype = torch.bfloat16 + else: + raise TypeError(f"Invalid dtype type: {type(preferred_dtype)}") + + # Validate dtype compatibility with device + # CPU autocast only supports bfloat16 and float16 + if device_type == "cpu" and dtype == torch.float32: + logger.warning( + f"CPU autocast doesn't support {dtype}, using bfloat16 instead" + ) + dtype = torch.bfloat16 + + return dtype + def get_default_dtype(): - """Get the default dtype for training: bfloat16 on GPU, float32 on CPU.""" - # bfloat16 is well-supported on modern GPUs but may have issues on CPU - if torch.cuda.is_available(): - return torch.bfloat16 - return torch.float32 + """Get the default dtype for training based on available hardware.""" + device_type = get_device_type() + return get_safe_autocast_dtype(device_type) def get_dist_info(): if is_ddp(): From 837b43a504b1de13be2980df3a9efb6abb905006 Mon Sep 17 00:00:00 2001 From: Kirk Lin Date: Wed, 15 Oct 2025 11:17:53 +0800 Subject: [PATCH 4/4] feat: support mps --- nanochat/common.py | 58 +++++--- nanochat/dataloader.py | 21 ++- speedrun.sh | 17 ++- uv.lock | 320 +++++++++++++++++++++++++++++++++-------- 4 files changed, 322 insertions(+), 94 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index 7493f99..c505a59 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -85,18 +85,20 @@ def is_macos(): return platform.system() == "Darwin" def get_device_type(): - """Get the device type string for autocast: 'cuda' or 'cpu'.""" - # Use CPU if on macOS or if CUDA is not available - if is_macos() or not torch.cuda.is_available(): - return "cpu" - return "cuda" + """Get the device type string for autocast: 'cuda', 'mps', or 'cpu'.""" + if torch.cuda.is_available(): + return "cuda" + # Check for MPS (Metal Performance Shaders on Apple Silicon) + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return "mps" + return "cpu" def get_safe_autocast_dtype(device_type: str, preferred_dtype=None) -> torch.dtype: """ Return a safe dtype for autocast on the given device. 
Args: - device_type: "cuda" or "cpu" + device_type: "cuda", "mps", or "cpu" preferred_dtype: Preferred dtype (torch.dtype, str, or None) Returns: @@ -115,18 +117,29 @@ def get_safe_autocast_dtype(device_type: str, preferred_dtype=None) -> torch.dty if dtype is None: raise ValueError(f"Unknown dtype string: {preferred_dtype}") elif preferred_dtype is None: - # Default: bfloat16 on CUDA, bfloat16 on CPU (both support it) - dtype = torch.bfloat16 + # Default: bfloat16 on CUDA, float16 on MPS, bfloat16 on CPU + if device_type == "cuda": + dtype = torch.bfloat16 + elif device_type == "mps": + dtype = torch.float16 # MPS works best with float16 + else: + dtype = torch.bfloat16 else: raise TypeError(f"Invalid dtype type: {type(preferred_dtype)}") # Validate dtype compatibility with device - # CPU autocast only supports bfloat16 and float16 if device_type == "cpu" and dtype == torch.float32: logger.warning( f"CPU autocast doesn't support {dtype}, using bfloat16 instead" ) dtype = torch.bfloat16 + elif device_type == "mps": + # MPS has limited dtype support, prefer float16 + if dtype not in {torch.float16, torch.float32}: + logger.warning( + f"MPS autocast works best with float16, converting from {dtype}" + ) + dtype = torch.float16 return dtype @@ -148,8 +161,9 @@ def get_dist_info(): def compute_init(): """Basic initialization that we keep doing over and over, so make common.""" - # Check if CUDA is available + # Check available hardware has_cuda = torch.cuda.is_available() + has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() on_macos = is_macos() # Reproducibility @@ -168,23 +182,25 @@ def compute_init(): ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() # Determine device - if on_macos or not has_cuda: - device = torch.device("cpu") - if on_macos: - logger.info("Running on macOS with CPU") - else: - logger.info("Running on CPU (CUDA not available)") - if ddp: - logger.warning("DDP requested but will run on CPU") - elif ddp: + if has_cuda and ddp: device = torch.device("cuda", ddp_local_rank) - torch.cuda.set_device(device) # make "cuda" default to this device + torch.cuda.set_device(device) dist.init_process_group(backend="nccl", device_id=device) dist.barrier() logger.info(f"Running on CUDA with DDP (rank {ddp_rank}/{ddp_world_size})") - else: + elif has_cuda: device = torch.device("cuda") logger.info("Running on CUDA (single GPU)") + elif has_mps: + device = torch.device("mps") + logger.info("Running on MPS (Apple Silicon GPU)") + if ddp: + logger.warning("DDP not supported on MPS, running single process") + else: + device = torch.device("cpu") + logger.info("Running on CPU") + if ddp: + logger.warning("DDP requested but will run on CPU") if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 1201ec7..62fcd85 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -16,7 +16,11 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) + + # Check if we're using MPS - it doesn't support pin_memory + device_type = get_device_type() + use_pin_memory = device_type == "cuda" # Only pin memory for CUDA + scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=use_pin_memory) # 
infinite iterator over document batches def document_batches(): @@ -38,13 +42,16 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz token_buffer.extend(tokens) batch_index += 1 # Move tokens from the deque into the scratch buffer - for i in range(needed_tokens): - scratch[i] = token_buffer.popleft() + # Build a list first to avoid MPS compatibility issues + tokens_list = [token_buffer.popleft() for _ in range(needed_tokens)] + scratch[:] = torch.tensor(tokens_list, dtype=torch.int64) + # Create the inputs/targets as 1D tensors inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] - # Reshape to 2D and move to device async - device_type = get_device_type() - inputs = inputs_cpu.view(B, T).to(device=device_type, dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device=device_type, dtype=torch.int64, non_blocking=True) + # Reshape to 2D and move to device + # For MPS, non_blocking doesn't apply; for CUDA it helps performance + non_blocking = (device_type == "cuda") + inputs = inputs_cpu.view(B, T).to(device=device_type, dtype=torch.int32, non_blocking=non_blocking) + targets = targets_cpu.view(B, T).to(device=device_type, dtype=torch.int64, non_blocking=non_blocking) yield inputs, targets diff --git a/speedrun.sh b/speedrun.sh index a01d200..29f3305 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -88,14 +88,17 @@ python -m scripts.tok_eval if [[ "$OSTYPE" == "darwin"* ]]; then echo "Detected macOS - running in CPU mode (single process)" TORCHRUN_CMD="python" + ARGS_SEP="" # No separator needed for python # Check if CUDA/GPUs are available elif command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) echo "Detected $GPU_COUNT GPU(s) - running in GPU mode" TORCHRUN_CMD="torchrun --standalone --nproc_per_node=$GPU_COUNT" + ARGS_SEP="--" # Separator needed for torchrun else echo "No GPUs detected - running in CPU mode (single process)" TORCHRUN_CMD="python" + ARGS_SEP="" # No separator needed for python fi # ----------------------------------------------------------------------------- @@ -120,7 +123,7 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # pretrain the d20 model -$TORCHRUN_CMD -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.base_train $ARGS_SEP --depth=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples $TORCHRUN_CMD -m scripts.base_loss # evaluate the model on CORE tasks @@ -130,15 +133,15 @@ $TORCHRUN_CMD -m scripts.base_eval # Midtraining (teach the model conversation special tokens, tool use, multiple choice) # run midtraining and eval the model -$TORCHRUN_CMD -m scripts.mid_train -- --run=$WANDB_RUN -$TORCHRUN_CMD -m scripts.chat_eval -- -i mid +$TORCHRUN_CMD -m scripts.mid_train $ARGS_SEP --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.chat_eval $ARGS_SEP -i mid # ----------------------------------------------------------------------------- # Supervised Finetuning (domain adaptation to each sequence all by itself per row) # train sft and re-eval right away (should see a small bump) -$TORCHRUN_CMD -m scripts.chat_sft -- --run=$WANDB_RUN -$TORCHRUN_CMD -m scripts.chat_eval -- -i sft +$TORCHRUN_CMD -m scripts.chat_sft $ARGS_SEP --run=$WANDB_RUN +$TORCHRUN_CMD -m scripts.chat_eval $ARGS_SEP -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" 
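To make the dtype policy concrete before the remaining hunks: the assertions below spell out what `get_safe_autocast_dtype()` returns per device after this patch series, as read from the branches above (a reviewer's reading of the code, not a test that ships with the patch):

```python
import torch
from nanochat.common import get_safe_autocast_dtype

# Defaults when no preferred dtype is given
assert get_safe_autocast_dtype("cuda") == torch.bfloat16   # default on CUDA
assert get_safe_autocast_dtype("mps") == torch.float16     # MPS prefers fp16
assert get_safe_autocast_dtype("cpu") == torch.bfloat16    # CPU autocast needs bf16/fp16

# Explicit requests are coerced (with a logged warning) to something the
# device can actually autocast with
assert get_safe_autocast_dtype("cpu", "fp32") == torch.bfloat16        # fp32 rejected on CPU
assert get_safe_autocast_dtype("mps", torch.bfloat16) == torch.float16 # bf16 coerced on MPS
assert get_safe_autocast_dtype("mps", "float32") == torch.float32      # fp32 passes through on MPS
```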
@@ -151,9 +154,9 @@ $TORCHRUN_CMD -m scripts.chat_eval -- -i sft # (optional) # run reinforcement learning -# $TORCHRUN_CMD -m scripts.chat_rl -- --run=$WANDB_RUN +# $TORCHRUN_CMD -m scripts.chat_rl $ARGS_SEP --run=$WANDB_RUN # eval the RL model only on GSM8K -# $TORCHRUN_CMD -m scripts.chat_eval -- -i rl -a GSM8K +# $TORCHRUN_CMD -m scripts.chat_eval $ARGS_SEP -i rl -a GSM8K # ----------------------------------------------------------------------------- # Generate the full report by putting together all the sections diff --git a/uv.lock b/uv.lock index 7636b81..becdbc0 100644 --- a/uv.lock +++ b/uv.lock @@ -2,13 +2,32 @@ version = 1 revision = 3 requires-python = ">=3.10" resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 
'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", ] +conflicts = [[ + { package = "nanochat", extra = "cpu" }, + { package = "nanochat", extra = "gpu" }, +]] [[package]] name = "aiohappyeyeballs" @@ -26,7 +45,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, { name = "aiosignal" }, - { name = "async-timeout", marker = "python_full_version < '3.11'" }, + { name = "async-timeout", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "attrs" }, { name = "frozenlist" }, { name = "multidict" }, @@ -111,7 +130,7 @@ version = "1.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "frozenlist" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } wheels = [ @@ -132,10 +151,10 @@ name = "anyio" version = "4.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "idna" }, { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" } wheels = [ @@ -238,7 +257,7 @@ name = "click" version = "8.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } wheels = [ @@ -292,7 +311,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" 
}, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -497,7 +516,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, @@ -602,7 +621,7 @@ name = "maturin" version = "1.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/13/7c/b11b870fc4fd84de2099906314ce45488ae17be32ff5493519a6cddc518a/maturin-1.9.4.tar.gz", hash = "sha256:235163a0c99bc6f380fb8786c04fd14dcf6cd622ff295ea3de525015e6ac40cf", size = 213647, upload-time = "2025-08-27T11:37:57.079Z" } wheels = [ @@ -635,7 +654,7 @@ name = "multidict" version = "6.6.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/69/7f/0652e6ed47ab288e3756ea9c0df8b14950781184d4bd7883f4d87dd41245/multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd", size = 101843, upload-time = "2025-08-11T12:08:48.217Z" } wheels = [ @@ -763,11 +782,23 @@ dependencies = [ { name = "regex" }, { name = "tiktoken" }, { name = "tokenizers" }, - { name = "torch" }, + { name = "torch", version = "2.8.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, { name = "uvicorn" }, { name = "wandb" }, ] +[package.optional-dependencies] +cpu = [ + { name = "torch", version = "2.8.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { 
name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, +] +gpu = [ + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } }, +] + [package.dev-dependencies] dev = [ { name = "maturin" }, @@ -784,10 +815,13 @@ requires-dist = [ { name = "regex", specifier = ">=2025.9.1" }, { name = "tiktoken", specifier = ">=0.11.0" }, { name = "tokenizers", specifier = ">=0.22.0" }, - { name = "torch", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", specifier = ">=2.8.0" }, + { name = "torch", marker = "extra == 'cpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, + { name = "torch", marker = "extra == 'gpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, ] +provides-extras = ["cpu", "gpu"] [package.metadata.requires-dev] dev = [ @@ -800,8 +834,13 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -813,10 +852,20 @@ name = "networkx" version = "3.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' 
and extra == 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", + "python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'", ] sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } wheels = [ @@ -860,7 +909,9 @@ name = "nvidia-cublas-cu12" version = "12.8.4.1" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, + { url = "https://files.pythonhosted.org/packages/70/61/7d7b3c70186fb651d0fbd35b01dbfc8e755f69fd58f817f3d0f642df20c3/nvidia_cublas_cu12-12.8.4.1-py3-none-win_amd64.whl", hash = "sha256:47e9b82132fa8d2b4944e708049229601448aaad7e6f296f630f2d1a32de35af", size = 567544208, upload-time = "2025-03-07T01:53:30.535Z" }, ] [[package]] @@ -868,7 +919,9 @@ name = "nvidia-cuda-cupti-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" }, { url = 
"https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, + { url = "https://files.pythonhosted.org/packages/41/bc/83f5426095d93694ae39fe1311431b5d5a9bb82e48bf0dd8e19be2765942/nvidia_cuda_cupti_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:bb479dcdf7e6d4f8b0b01b115260399bf34154a1a2e9fe11c85c517d87efd98e", size = 7015759, upload-time = "2025-03-07T01:51:11.355Z" }, ] [[package]] @@ -877,6 +930,8 @@ version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, + { url = "https://files.pythonhosted.org/packages/45/51/52a3d84baa2136cc8df15500ad731d74d3a1114d4c123e043cb608d4a32b/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:7a4b6b2904850fe78e0bd179c4b655c404d4bb799ef03ddc60804247099ae909", size = 73586838, upload-time = "2025-03-07T01:52:13.483Z" }, ] [[package]] @@ -884,7 +939,9 @@ name = "nvidia-cuda-runtime-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, + { url = "https://files.pythonhosted.org/packages/30/a5/a515b7600ad361ea14bfa13fb4d6687abf500adc270f19e89849c0590492/nvidia_cuda_runtime_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:c0c6027f01505bfed6c3b21ec546f69c687689aad5f1a377554bc6ca4aa993a8", size = 944318, upload-time = "2025-03-07T01:51:01.794Z" }, ] [[package]] @@ -892,10 +949,12 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, 
upload-time = "2025-06-06T21:52:51.348Z" }, { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, + { url = "https://files.pythonhosted.org/packages/3d/90/0bd6e586701b3a890fd38aa71c387dab4883d619d6e5ad912ccbd05bfd67/nvidia_cudnn_cu12-9.10.2.21-py3-none-win_amd64.whl", hash = "sha256:c6288de7d63e6cf62988f0923f96dc339cea362decb1bf5b3141883392a7d65e", size = 692992268, upload-time = "2025-06-06T21:55:18.114Z" }, ] [[package]] @@ -903,10 +962,12 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, + { url = "https://files.pythonhosted.org/packages/7d/ec/ce1629f1e478bb5ccd208986b5f9e0316a78538dd6ab1d0484f012f8e2a1/nvidia_cufft_cu12-11.3.3.83-py3-none-win_amd64.whl", hash = "sha256:7a64a98ef2a7c47f905aaf8931b69a3a43f27c55530c698bb2ed7c75c0b42cb7", size = 192216559, upload-time = "2025-03-07T01:53:57.106Z" }, ] [[package]] @@ -915,6 +976,7 @@ version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" }, ] [[package]] @@ -922,7 +984,9 @@ name = "nvidia-curand-cu12" version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 
63619976, upload-time = "2025-03-07T01:46:23.323Z" }, + { url = "https://files.pythonhosted.org/packages/b9/75/70c05b2f3ed5be3bb30b7102b6eb78e100da4bbf6944fd6725c012831cab/nvidia_curand_cu12-10.3.9.90-py3-none-win_amd64.whl", hash = "sha256:f149a8ca457277da854f89cf282d6ef43176861926c7ac85b2a0fbd237c587ec", size = 62765309, upload-time = "2025-03-07T01:54:20.478Z" }, ] [[package]] @@ -930,12 +994,14 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, + { url = "https://files.pythonhosted.org/packages/13/c0/76ca8551b8a84146ffa189fec81c26d04adba4bc0dbe09cd6e6fd9b7de04/nvidia_cusolver_cu12-11.7.3.90-py3-none-win_amd64.whl", hash = "sha256:4a550db115fcabc4d495eb7d39ac8b58d4ab5d8e63274d3754df1c0ad6a22d34", size = 256720438, upload-time = "2025-03-07T01:54:39.898Z" }, ] [[package]] @@ -943,10 +1009,12 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, + { url = "https://files.pythonhosted.org/packages/62/07/f3b2ad63f8e3d257a599f422ae34eb565e70c41031aecefa3d18b62cabd1/nvidia_cusparse_cu12-12.5.8.93-py3-none-win_amd64.whl", hash = "sha256:9a33604331cb2cac199f2e7f5104dfbb8a5a898c367a53dfda9ff2acb6b6b4dd", size = 284937404, upload-time = "2025-03-07T01:55:07.742Z" }, ] [[package]] @@ -954,7 +1022,9 @@ name = "nvidia-cusparselt-cu12" version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = 
"https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, + { url = "https://files.pythonhosted.org/packages/2f/d8/a6b0d0d0c2435e9310f3e2bb0d9c9dd4c33daef86aa5f30b3681defd37ea/nvidia_cusparselt_cu12-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f67fbb5831940ec829c9117b7f33807db9f9678dc2a617fbe781cac17b4e1075", size = 271020911, upload-time = "2025-02-26T00:14:47.204Z" }, ] [[package]] @@ -962,6 +1032,7 @@ name = "nvidia-nccl-cu12" version = "2.27.3" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768, upload-time = "2025-06-03T21:57:30.234Z" }, { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, ] @@ -971,6 +1042,8 @@ version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" }, + { url = "https://files.pythonhosted.org/packages/ed/d7/34f02dad2e30c31b10a51f6b04e025e5dd60e5f936af9045a9b858a05383/nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f", size = 268553710, upload-time = "2025-03-07T01:56:24.13Z" }, ] [[package]] @@ -978,7 +1051,9 @@ name = "nvidia-nvtx-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" }, { url = 
"https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, + { url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" }, ] [[package]] @@ -1334,13 +1409,13 @@ name = "pytest" version = "8.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } wheels = [ @@ -1561,7 +1636,7 @@ version = "0.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } wheels = [ @@ -1680,34 +1755,161 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, ] +[[package]] +name = "torch" +version = "2.8.0" +source = { registry = "https://download.pytorch.org/whl/cpu" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", +] +dependencies = [ + { name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or 
(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a467b49fe893a6a6cce89e3aee556edfdc64a722d7195fdfdd75cec9dea13779" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:3d05017d19bc99741288e458888283a44b0ee881d53f05f72f8b1cfea8998122" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:a47b7986bee3f61ad217d8a8ce24605809ab425baf349f97de758815edd2ef54" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:fbe2e149c5174ef90d29a5f84a554dfaf28e003cb4f61fa2c8c024c17ec7ca58" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:057efd30a6778d2ee5e2374cd63a63f63311aa6f33321e627c655df60abdd390" }, +] + +[[package]] +name = "torch" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", +] +dependencies = [ + { name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 
'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/28/110f7274254f1b8476c561dada127173f994afa2b1ffc044efb773c15650/torch-2.8.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0be92c08b44009d4131d1ff7a8060d10bafdb7ddcb7359ef8d8c5169007ea905", size = 102052793, upload-time = "2025-08-06T14:53:15.852Z" }, + { url = "https://files.pythonhosted.org/packages/70/1c/58da560016f81c339ae14ab16c98153d51c941544ae568da3cb5b1ceb572/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:89aa9ee820bb39d4d72b794345cccef106b574508dd17dbec457949678c76011", size = 888025420, upload-time = "2025-08-06T14:54:18.014Z" }, + { url = "https://files.pythonhosted.org/packages/70/87/f69752d0dd4ba8218c390f0438130c166fa264a33b7025adb5014b92192c/torch-2.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e8e5bf982e87e2b59d932769938b698858c64cc53753894be25629bdf5cf2f46", size = 241363614, upload-time = "2025-08-06T14:53:31.496Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d6/e6d4c57e61c2b2175d3aafbfb779926a2cfd7c32eeda7c543925dceec923/torch-2.8.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a3f16a58a9a800f589b26d47ee15aca3acf065546137fc2af039876135f4c760", size = 73611154, upload-time = "2025-08-06T14:53:10.919Z" }, + { url = "https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391, upload-time = "2025-08-06T14:53:20.937Z" }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640, upload-time = "2025-08-06T14:55:05.325Z" }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752, upload-time = "2025-08-06T14:53:38.692Z" }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = 
"sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174, upload-time = "2025-08-06T14:53:25.44Z" }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089, upload-time = "2025-08-06T14:53:52.631Z" }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624, upload-time = "2025-08-06T14:56:44.33Z" }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087, upload-time = "2025-08-06T14:53:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856, upload-time = "2025-08-06T14:54:01.526Z" }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844, upload-time = "2025-08-06T14:55:50.78Z" }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968, upload-time = "2025-08-06T14:54:45.293Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139, upload-time = "2025-08-06T14:54:39.047Z" }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692, upload-time = "2025-08-06T14:56:18.286Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453, upload-time = "2025-08-06T14:55:22.945Z" }, + { url 
= "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, +] + +[[package]] +name = "torch" +version = "2.8.0+cpu" +source = { registry = "https://download.pytorch.org/whl/cpu" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'", +] +dependencies = [ + { name = "filelock", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "fsspec", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "jinja2", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "sympy", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-linux_s390x.whl", hash = "sha256:5d255d259fbc65439b671580e40fdb8faea4644761b64fed90d6904ffe71bbc1" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:b2149858b8340aeeb1f3056e0bff5b82b96e43b596fe49a9dba3184522261213" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:16d75fa4e96ea28a785dfd66083ca55eb1058b6d6c5413f01656ca965ee2077e" }, + { url = 
"https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:7cc4af6ba954f36c2163eab98cf113c137fc25aa8bbf1b06ef155968627beed2" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp311-cp311-linux_s390x.whl", hash = "sha256:2bfc013dd6efdc8f8223a0241d3529af9f315dffefb53ffa3bf14d3f10127da6" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:680129efdeeec3db5da3f88ee5d28c1b1e103b774aef40f9d638e2cce8f8d8d8" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:cb06175284673a581dd91fb1965662ae4ecaba6e5c357aa0ea7bb8b84b6b7eeb" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:7631ef49fbd38d382909525b83696dc12a55d68492ade4ace3883c62b9fc140f" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:41e6fc5ec0914fcdce44ccf338b1d19a441b55cafdd741fd0bf1af3f9e4cfd14" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-linux_s390x.whl", hash = "sha256:0e34e276722ab7dd0dffa9e12fe2135a9b34a0e300c456ed7ad6430229404eb5" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:610f600c102386e581327d5efc18c0d6edecb9820b4140d26163354a99cd800d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:cb9a8ba8137ab24e36bf1742cb79a1294bd374db570f09fc15a5e1318160db4e" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:2be20b2c05a0cce10430cc25f32b689259640d273232b2de357c35729132256d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:99fc421a5d234580e45957a7b02effbf3e1c884a5dd077afc85352c77bf41434" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-linux_s390x.whl", hash = "sha256:8b5882276633cf91fe3d2d7246c743b94d44a7e660b27f1308007fdb1bb89f7d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a5064b5e23772c8d164068cc7c12e01a75faf7b948ecd95a0d4007d7487e5f25" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f81dedb4c6076ec325acc3b47525f9c550e5284a18eae1d9061c543f7b6e7de" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:e1ee1b2346ade3ea90306dfbec7e8ff17bc220d344109d189ae09078333b0856" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:64c187345509f2b1bb334feed4666e2c781ca381874bde589182f81247e61f88" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:af81283ac671f434b1b25c95ba295f270e72db1fad48831eb5e4748ff9840041" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:a9dbb6f64f63258bc811e2c0c99640a81e5af93c531ad96e95c5ec777ea46dab" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:6d93a7165419bc4b2b907e859ccab0dea5deeab261448ae9a5ec5431f14c0e64" }, +] + [[package]] name = "torch" version = "2.8.0+cu128" source = { 
registry = "https://download.pytorch.org/whl/cu128" } +resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", + "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform != 'linux'", + "python_full_version < '3.11' and sys_platform != 'linux'", +] dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "jinja2" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "python_full_version >= '3.12'" }, - { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions" }, + { name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" 
}, + { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nccl-cu12", marker = 
"(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" }, @@ -1727,7 +1929,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } wheels = [ @@ -1739,7 +1941,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform == 'linux'" }, + { name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, @@ -1795,7 +1997,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/ef/5e/f0cd46063a02fd8515f0e880c37d2657845b7306c16ce6c4ffc44afd9036/uvicorn-0.36.0.tar.gz", hash = "sha256:527dc68d77819919d90a6b267be55f0e76704dca829d34aea9480be831a9b9d9", size = 80032, upload-time = "2025-09-20T01:07:14.418Z" } wheels = [