mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
upgrading all other files to be able to use cpu/mps as well as cuda. various minor other changes ,e.g. changing max_iterations to num_iterations in sft script for consistency in naming
This commit is contained in:
parent
a09ac812ed
commit
2e9669e03a
84
dev/runcpu.sh
Normal file
84
dev/runcpu.sh
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# Run as:
|
||||
# bash dev/cpu_demo_run.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
|
||||
# Think of this run as educational/fun demo, not something you should expect to work well.
|
||||
# This is also why I hide this script away in dev/
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# wipe the report
|
||||
python -m nanochat.report reset
|
||||
|
||||
# train tokenizer on ~1B characters
|
||||
python -m nanochat.dataset -n 4
|
||||
python -m scripts.tok_train --max_chars=1000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a very small 4 layer model on the CPU
|
||||
# each optimization step processes a single sequence of 1024 tokens
|
||||
# we only run 50 steps of optimization (bump this to get better results)
|
||||
python -m scripts.base_train \
|
||||
--depth=4 \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--total_batch_size=1024 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--core_metric_every=50 \
|
||||
--core_metric_max_per_task=12 \
|
||||
--sample_every=50 \
|
||||
--num_iterations=50
|
||||
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
|
||||
python -m scripts.base_eval --max-per-task=5
|
||||
|
||||
# midtraining
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--total_batch_size=1024 \
|
||||
--num_iterations=100
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
|
||||
# SFT
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=1 \
|
||||
--target_examples_per_step=4 \
|
||||
--num_iterations=100 \
|
||||
--eval_steps=4 \
|
||||
--eval_metrics_max_problems=16
|
||||
|
||||
# Chat CLI
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# Chat Web
|
||||
# python -m scripts.chat_web
|
||||
|
||||
python -m nanochat.report generate
|
||||
|
|
@ -146,12 +146,11 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
with caution.
|
||||
"""
|
||||
|
||||
if maximum_memory_bytes is not None:
|
||||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
if not platform.uname().system == "Darwin":
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
|
@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
rmtree = shutil.rmtree
|
||||
rmdir = os.rmdir
|
||||
chdir = os.chdir
|
||||
unlink = os.unlink
|
||||
|
||||
# Disable functionalities that can make destructive changes to the test.
|
||||
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
||||
|
|
@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
shutil.rmtree = rmtree
|
||||
os.rmdir = rmdir
|
||||
os.chdir = chdir
|
||||
os.unlink = unlink
|
||||
|
||||
|
||||
def execute_code(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
# The $1000 tier of nanochat
|
||||
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
|
||||
# A bit sparser on comments, see speedrun.sh for more detail
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
|
|||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from contextlib import nullcontext
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
|
|
@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -191,11 +192,13 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
|
|
@ -35,11 +36,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
|
|||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
|
|
@ -50,6 +52,7 @@ init_lr_frac = 0.02
|
|||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
|
@ -57,10 +60,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -126,10 +130,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
|
|||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
if max_iterations >= 0 and num_iterations > max_iterations:
|
||||
print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
|
||||
num_iterations = max_iterations
|
||||
if num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
|
|
@ -189,8 +193,8 @@ for step in range(num_iterations):
|
|||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -80,7 +82,9 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
|
|
@ -95,21 +99,33 @@ class WorkerPool:
|
|||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
|
||||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
|
||||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect)
|
|||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
|
|
@ -116,6 +117,7 @@ def mid_data_generator(split):
|
|||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
|
|
@ -127,6 +129,10 @@ def mid_data_generator(split):
|
|||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
|
|
@ -135,6 +141,9 @@ def mid_data_generator(split):
|
|||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if split == "train":
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
|
|
|
|||
205
uv.lock
205
uv.lock
|
|
@ -3,11 +3,14 @@ revision = 3
|
|||
requires-python = ">=3.10"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -764,7 +767,9 @@ dependencies = [
|
|||
{ name = "setuptools" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
{ name = "torch" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "torch", version = "2.9.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },
|
||||
]
|
||||
|
|
@ -786,7 +791,8 @@ requires-dist = [
|
|||
{ name = "setuptools", specifier = ">=80.9.0" },
|
||||
{ name = "tiktoken", specifier = ">=0.11.0" },
|
||||
{ name = "tokenizers", specifier = ">=0.22.0" },
|
||||
{ name = "torch", specifier = ">=2.8.0" },
|
||||
{ name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||
{ name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128" },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
]
|
||||
|
|
@ -803,7 +809,8 @@ version = "3.4.2"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" }
|
||||
wheels = [
|
||||
|
|
@ -816,9 +823,11 @@ version = "3.5"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
|
||||
wheels = [
|
||||
|
|
@ -862,6 +871,7 @@ name = "nvidia-cublas-cu12"
|
|||
version = "12.8.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
|
||||
]
|
||||
|
||||
|
|
@ -870,6 +880,7 @@ name = "nvidia-cuda-cupti-cu12"
|
|||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
|
||||
]
|
||||
|
||||
|
|
@ -879,6 +890,7 @@ version = "12.8.93"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -886,6 +898,7 @@ name = "nvidia-cuda-runtime-cu12"
|
|||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
|
||||
]
|
||||
|
||||
|
|
@ -897,6 +910,7 @@ dependencies = [
|
|||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
||||
]
|
||||
|
||||
|
|
@ -908,6 +922,7 @@ dependencies = [
|
|||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
||||
]
|
||||
|
||||
|
|
@ -917,6 +932,7 @@ version = "1.13.1.3"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -924,6 +940,7 @@ name = "nvidia-curand-cu12"
|
|||
version = "10.3.9.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
|
||||
]
|
||||
|
||||
|
|
@ -937,6 +954,7 @@ dependencies = [
|
|||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
||||
]
|
||||
|
||||
|
|
@ -948,6 +966,7 @@ dependencies = [
|
|||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
||||
]
|
||||
|
||||
|
|
@ -956,6 +975,7 @@ name = "nvidia-cusparselt-cu12"
|
|||
version = "0.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
|
||||
]
|
||||
|
||||
|
|
@ -964,6 +984,7 @@ name = "nvidia-nccl-cu12"
|
|||
version = "2.27.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" },
|
||||
]
|
||||
|
||||
|
|
@ -973,6 +994,7 @@ version = "12.8.93"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -980,6 +1002,7 @@ name = "nvidia-nvshmem-cu12"
|
|||
version = "3.3.20"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/92/9d/3dd98852568fb845ec1f7902c90a22b240fe1cbabda411ccedf2fd737b7b/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b0b960da3842212758e4fa4696b94f129090b30e5122fea3c5345916545cff0", size = 124484616, upload-time = "2025-08-04T20:24:59.172Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d00f26d3f9b2e3c3065be895e3059d6479ea5c638a3f38c9fec49b1b9dd7c1e5", size = 124657145, upload-time = "2025-08-04T20:25:19.995Z" },
|
||||
]
|
||||
|
||||
|
|
@ -988,6 +1011,7 @@ name = "nvidia-nvtx-cu12"
|
|||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
|
||||
]
|
||||
|
||||
|
|
@ -1693,62 +1717,114 @@ wheels = [
|
|||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "fsspec", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform == 'darwin'" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'darwin'" },
|
||||
{ name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform == 'darwin'" },
|
||||
{ name = "sympy", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform == 'darwin'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/1d/fd1e88ae0948825efcab7dd66d12bec23f05d4d38ed81573c8d453c14c06/torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:51cb63902182a78e90886e8068befd8ea102af4b00e420263591a3d70c7d3c6c", size = 899795167, upload-time = "2025-10-15T15:47:12.695Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/5a/496197b45c14982bef4e079b24c61dc108e3ab0d0cc9718dba9f54f45a46/torch-2.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:3f6aad4d2f0ee2248bac25339d74858ff846c3969b27d14ac235821f055af83d", size = 109310314, upload-time = "2025-10-15T15:46:16.633Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/cc/49566caaa218872ec9a2912456f470ff92649894a4bc2e5274aa9ef87c4a/torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:51de31219c97c51cf4bf2be94d622e3deb5dcc526c6dc00e97c17eaec0fc1d67", size = 899815990, upload-time = "2025-10-15T15:48:03.336Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/74/25/e9ab21d5925b642d008f139d4a3c9664fc9ee1faafca22913c080cc4c0a5/torch-2.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd515c70059afd95f48b8192733764c08ca37a1d19803af6401b5ecad7c8676e", size = 109313698, upload-time = "2025-10-15T15:46:12.425Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/1c/90eb13833cdf4969ea9707586d7b57095c3b6e2b223a7256bf111689bcb8/torch-2.9.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c30a17fc83eeab346913e237c64b15b5ba6407fff812f6c541e322e19bc9ea0e", size = 104111330, upload-time = "2025-10-15T15:46:35.238Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/21/2254c54b8d523592c25ef4434769aa23e29b1e6bf5f4c0ad9e27bf442927/torch-2.9.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f25033b8667b57857dfd01458fbf2a9e6a6df1f8def23aef0dc46292f6aa642", size = 899750243, upload-time = "2025-10-15T15:48:57.459Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/a5/5cb94fa4fd1e78223455c23c200f30f6dc10c6d4a2bcc8f6e7f2a2588370/torch-2.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:d037f1b4ffd25013be4a7bf3651a0a910c68554956c7b2c92ebe87c76475dece", size = 109284513, upload-time = "2025-10-15T15:46:45.061Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/e8/fc414d8656250ee46120b44836ffbb3266343db424b3e18ca79ebbf69d4f/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e4e5b5cba837a2a8d1a497ba9a58dae46fa392593eaa13b871c42f71847503a5", size = 74830362, upload-time = "2025-10-15T15:46:48.983Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/5f/9474c98fc5ae0cd04b9466035428cd360e6611a86b8352a0fc2fa504acdc/torch-2.9.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:64693568f5dc4dbd5f880a478b1cea0201cc6b510d91d1bc54fea86ac5d1a637", size = 104144940, upload-time = "2025-10-15T15:47:29.076Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/5a/8e0c1cf57830172c109d4bd6be2708cabeaf550983eee7029291322447a0/torch-2.9.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:f8ed31ddd7d10bfb3fbe0b9fe01b1243577f13d75e6f4a0839a283915ce3791e", size = 899744054, upload-time = "2025-10-15T15:48:29.864Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/28/82c28b30fcb4b7c9cdd995763d18bbb830d6521356712faebbad92ffa61d/torch-2.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:eff527d4e4846e6f70d2afd8058b73825761203d66576a7e04ea2ecfebcb4ab8", size = 109517546, upload-time = "2025-10-15T15:47:33.395Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/c3/a91f96ec74347fa5fd24453fa514bc61c61ecc79196fa760b012a1873d96/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:f8877779cf56d1ce431a7636703bdb13307f5960bb1af49716d8b179225e0e6a", size = 74480732, upload-time = "2025-10-15T15:47:38.002Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/73/9f70af34b334a7e0ef496ceec96b7ec767bd778ea35385ce6f77557534d1/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7e614fae699838038d888729f82b687c03413c5989ce2a9481f9a7e7a396e0bb", size = 74433037, upload-time = "2025-10-15T15:47:41.894Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/84/37cf88625901934c97109e583ecc21777d21c6f54cda97a7e5bbad1ee2f2/torch-2.9.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:dfb5b8cd310ba3436c7e14e8b7833ef658cf3045e50d2bdaed23c8fc517065eb", size = 104116482, upload-time = "2025-10-15T15:47:46.266Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/8e/ca8b17866943a8d4f4664d402ea84210aa274588b4c5d89918f5caa24eec/torch-2.9.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b3d29524993a478e46f5d598b249cd824b7ed98d7fba538bd9c4cde6c803948f", size = 899746916, upload-time = "2025-10-15T15:50:40.294Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/65/3b17c0fbbdab6501c5b320a52a648628d0d44e7379f64e27d9eef701b6bf/torch-2.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:71c7578984f5ec0eb645eb4816ac8435fcf3e3e2ae1901bcd2f519a9cafb5125", size = 109275151, upload-time = "2025-10-15T15:49:20.715Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/83/36/74f8c051f785500396e42f93542422422dfd874a174f21f8d955d36e5d64/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:71d9309aee457bbe0b164bce2111cd911c4ed4e847e65d5077dbbcd3aba6befc", size = 74823353, upload-time = "2025-10-15T15:49:16.59Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/51/dc3b4e2f9ba98ae27238f0153ca098bf9340b2dafcc67fde645d496dfc2a/torch-2.9.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c08fb654d783899e204a32cca758a7ce8a45b2d78eeb89517cc937088316f78e", size = 104140340, upload-time = "2025-10-15T15:50:19.67Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/8d/b00657f8141ac16af7bb6cda2e67de18499a3263b78d516b9a93fcbc98e3/torch-2.9.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ec8feb0099b2daa5728fbc7abb0b05730fd97e0f359ff8bda09865aaa7bd7d4b", size = 899731750, upload-time = "2025-10-15T15:49:36.673Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:59484193b01299bf669520505a72b29d59a0028ae4c6d95f492938f186592208" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aa4483602586cc9a35d1cf33771a9977f05f642b9161518a289e36548a0b77c2" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:4de0ed8cbc457a506dbca40376e206a29efee10756a00f1f3404bf67ad737d04" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:259548471194ab63d7ea273873053a6e3cc23530c1510f01e9d7ad259187bbd0" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e24836d968b54ef4dfb05594001a61958711ac9224026291e4e3f92f83a6fd7f" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d8e2ab7f86010330bdcc39c8b2c795590cc75e37df4823cdaee2c98d6e3ff4a3" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a3e859039c985d8e3ea60d7a54ca7e97ea2ae15e31beced4f3260128a161bb01" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0+cpu"
|
||||
source = { registry = "https://download.pytorch.org/whl/cpu" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "fsspec", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "jinja2", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "sympy", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:96f3f7aa4eb9e7fc5af8a722eaf1e5e32e3039dbafe817178d7b90a8566be32d" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:389e1e0b8083fd355f7caf5ba82356b5e01c318998bd575dbf2285a0d8137089" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:5ce3d01aef91dc078fbb121814e556d55bc886d303efaf42c4fe67e411f5f9ad" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:e438061b87ec7dd6018fca9f975219889aa0a3f6cdc3ea10dd0ae2bc7f1c47ce" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:eb13ff1c34e338d722e76a4fd83b8d282782505bd1b99af4b3c32da66eba6eb4" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:728372e3f58c5826445f677746e5311c1935c1a7c59599f73a49ded850e038e8" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:95e56c26f919fbb98f16e7a0b87af494b893f9da9a65a020f17a01c13e520a81" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:d572863990e7d2762b547735ef589f6350d9eb4e441d38753a1c33636698cf4c" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:c2698999361d73c2d25d7cc8a787130188d49b183abb18b554228daa102e1594" },
|
||||
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:3a60d1ecf27a9cce839b3aa665b26f0af1b1007b9c9f1e7f597f6b7bdf107617" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.0+cu128"
|
||||
source = { registry = "https://download.pytorch.org/whl/cu128" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
]
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform == 'linux'" },
|
||||
{ name = "fsspec", marker = "sys_platform == 'linux'" },
|
||||
{ name = "jinja2", marker = "sys_platform == 'linux'" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform == 'linux'" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform == 'linux'" },
|
||||
{ name = "sympy", marker = "sys_platform == 'linux'" },
|
||||
{ name = "triton", marker = "sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edadd510a59951323ca24a53b8fe55d179b9a90237f0f55aae07f8ebc07dd052" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:816540286fce245a8af3904a194a83af9c9292ad7452eb79160b7a3b1cefb7e3" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6848715fc906574eb2c0975f56771663344eef7b9a717816b50dede616a3d4fb" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e97c264478c9fc48f91832749d960f1e349aeb214224ebe65fb09435dd64c59a" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e1765625084e320f1eb2f4eb5fd9d14d39d08d7a1880c10a307ce5de20831d27" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:87c62d3b95f1a2270bd116dbd47dc515c0b2035076fbb4a03b4365ea289e89c4" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:4d76f71345af47f022c7fa55edd0c1810d01af89dcb9edcfdfafe3d2a0f7a6b8" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97def0087f8ef171b9002ea500baffdd440c7bdd559c23c38bbf8781b67e9364" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:dacbfc19608e60f78975c47d605c7d39b81afdf1983e93e94c17f60646b131e0" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8ce575fb71b878f5016df0a8a438c7c28f7f4be270af4119b5ad9ab62b0e470a" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:eedef2e65d48c7dc9bb03f92c2a62bdae904382fc5c2773de3de41dce5ffd80a" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:55a2184ed89f2120bc1e2c887ee98e5280dee48bc330e9dfe296aa135a370f7d" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:4b51281e08ec36cd6748c71ac32fa1e45d30090b1c3fdf99ebb30776437734b7" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ef5939ebcacfe3d4f70774941e79a7c7e23f7918d7d3242428c8f48cc7440c0a" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1768,12 +1844,19 @@ name = "triton"
|
|||
version = "3.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/22/507b6f58a35e05e84381630b2dc2a3cee1a7a2a7eaf4cba857c638a18a24/triton-3.5.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6f90de6a6566bb619b4c0adc9855729e1b1b5e26533fca1bf6206e96b6d277a3", size = 159827599, upload-time = "2025-10-15T19:15:43.87Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0b/eb/09e31d107a5d00eb281aa7e6635ca463e9bca86515944e399480eadb71f8/triton-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5d3b3d480debf24eaa739623c9a42446b0b77f95593d30eb1f64cd2278cc1f0", size = 170333110, upload-time = "2025-10-13T16:37:49.588Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/f9/b6f60f978397c616fd8dacca2305759fe4f80d397b20ef72534803244bd5/triton-3.5.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8457b22148defefdcb7fa8144b05ce211b9faefad650a1ce85b23df488d5549c", size = 159926731, upload-time = "2025-10-15T19:15:49.682Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/78/949a04391c21956c816523678f0e5fa308eb5b1e7622d88c4e4ef5fceca0/triton-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f34bfa21c5b3a203c0f0eab28dcc1e49bd1f67d22724e77fb6665a659200a4ec", size = 170433488, upload-time = "2025-10-13T16:37:57.132Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/9b/30988039e1e84df7554fba24e6a734d2d0e847af33cabdf9b532b3c51456/triton-3.5.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da21fccceafc163e3a5e857abe34351ef76345af06cabf9637a914742671f0b", size = 159946647, upload-time = "2025-10-15T19:15:56.325Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833", size = 170476535, upload-time = "2025-10-13T16:38:05.18Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/85/e37f1197acb04c8f3d83851d23d5d6ed5060ef74580668b112e23fdfa203/triton-3.5.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:188da5b81fa2f8322c27fec1627703eac24cb9bb7ab0dfbe9925973bc1b070d3", size = 159958970, upload-time = "2025-10-15T19:16:01.717Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/29/10728de8a6e932e517c10773486b8e99f85d1b1d9dd87d9a9616e1fef4a1/triton-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6bb9aa5519c084a333acdba443789e50012a4b851cd486c54f0b8dc2a8d3a12", size = 170487289, upload-time = "2025-10-13T16:38:11.662Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/1d/38258f05010ac17a7b058c022911c9cae6526e149b7397134a048cf5a6c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03127d9b33aaf979c856676b394bc059ec1d68cb6da68ae03f62dd8ad77a04ae", size = 160073012, upload-time = "2025-10-15T19:16:07.477Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/38/db80e48b9220c9bce872b0f616ad0446cdf554a40b85c7865cbca99ab3c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c83f2343e1a220a716c7b3ab9fccfcbe3ad4020d189549200e2d2e8d5868bed9", size = 170577179, upload-time = "2025-10-13T16:38:17.865Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/fe/8f5771d00227f4eb1ee034f218ed427102b989366d2275fe3b3c105a3921/triton-3.5.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:468936651d383f4a6d10068d34a627505e13af55be5d002b9f27b987e7a5f0ac", size = 159957460, upload-time = "2025-10-15T19:16:12.626Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/60/1810655d1d856c9a4fcc90ee8966d85f552d98c53a6589f95ab2cbe27bb8/triton-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da0fa67ccd76c3dcfb0bffe1b1c57c685136a6bd33d141c24d9655d4185b1289", size = 170487949, upload-time = "2025-10-13T16:38:24.881Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/59/99edd103958fe6e42b50b9ad8ce4f223ddf4ccf475259cf7d2b53381dc6c/triton-3.5.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7ceef21410229ac23173a28eee5cfc0e37c1dfdb8b4bc11ecda2e3ecec7c686", size = 160075629, upload-time = "2025-10-15T19:16:18.746Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/b7/1dec8433ac604c061173d0589d99217fe7bf90a70bdc375e745d044b8aad/triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7", size = 170580176, upload-time = "2025-10-13T16:38:31.14Z" },
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user