Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 04:12:13 +00:00.
Compare commits: 2dc85662c3...337d9649c6 (17 commits)
Commits in this range:
337d9649c6, 4a87a0d19f, 11e68bf442, bc1fca39f3, f66a780f68, 4763ce612a, c6f5bd67db, a2fb3c83a6, e5efb4b471, 9a71d13688, 7b7fd0fe71, c6abcdfe3a, 91f09ccd0d, adb5d4a16c, b399e43168, 52e85aaf80, 70319851fc
@@ -184,6 +184,7 @@ python -m pytest tests/test_rustbpe.py -v -s
 │   ├── smoltalk.py      # Conglomerate dataset of SmolTalk from HF
 │   └── spellingbee.py   # Task teaching model to spell/count letters
 ├── tests
+│   ├── test_engine.py
 │   └── test_rustbpe.py
 └── uv.lock
 ```
@@ -201,6 +202,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t
 - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
 - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
 - Thank you to chief LLM whisperer 🧙♂️ Alec Radford for advice/guidance.
+- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.

 ## Cite
@@ -20,33 +20,32 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)

-def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    # Save the model state (parameters)
-    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
-    torch.save(model_data, model_path)
-    log0(f"Saved model file to: {model_path}")
-    # Save the optimizer state (useful for SFT or any other fine-tuning)
-    if optimizer_data is not None:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
-        torch.save(optimizer_data, optimizer_path)
-        log0(f"Saved optimizer file to: {optimizer_path}")
-    # Save the metadata dict as json
-    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "w", encoding="utf-8") as f:
-        json.dump(meta_data, f, indent=2)
-    log0(f"Saved metadata file to: {meta_path}")
+def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+    if rank == 0:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # Save the model state parameters
+        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+        torch.save(model_data, model_path)
+        logger.info(f"Saved model parameters to: {model_path}")
+        # Save the metadata dict as json
+        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+        logger.info(f"Saved metadata to: {meta_path}")
+    # Note that optimizer state is sharded across ranks, so each rank must save its own.
+    if optimizer_data is not None:
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+        torch.save(optimizer_data, optimizer_path)
+        logger.info(f"Saved optimizer state to: {optimizer_path}")

-def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
+def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
     model_data = torch.load(model_path, map_location=device)
     # Load the optimizer state if requested
     optimizer_data = None
     if load_optimizer:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         optimizer_data = torch.load(optimizer_path, map_location=device)
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
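Note: the net effect of this change is that model weights and metadata are still written once (by rank 0), while every rank now writes and reloads its own optimizer shard. A minimal sketch of the resulting directory layout, assuming 2 ranks and the file-name patterns from the diff; the helper and the concrete step/directory values below are hypothetical, not repo code.

```python
import os
import json

def checkpoint_files(checkpoint_dir, step, world_size):
    """Hypothetical helper: files one checkpoint step produces under rank-sharded saving."""
    return {
        "model": os.path.join(checkpoint_dir, f"model_{step:06d}.pt"),  # written by rank 0 only
        "meta": os.path.join(checkpoint_dir, f"meta_{step:06d}.json"),  # written by rank 0 only
        # every rank writes (and later reloads) its own optimizer shard
        "optim": [os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{r:d}.pt") for r in range(world_size)],
    }

print(json.dumps(checkpoint_files("base_checkpoints/d12", 21400, world_size=2), indent=2))
```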
@@ -160,6 +160,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

     # Reproducibility
+    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
+    # The only place where global rng might be used is nn.Module initialization of the model weights.
     torch.manual_seed(42)
     if device_type == "cuda":
         torch.cuda.manual_seed(42)
@@ -1,49 +1,87 @@
 from collections import deque

 import torch
+import pyarrow.parquet as pq

 from nanochat.common import get_dist_info
-from nanochat.dataset import parquets_iter_batched
+from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer

-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
-    """Stream pretraining text from parquet files, tokenize, yield training batches."""
+def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+    """
+    Stream pretraining text from parquet files, tokenize, yield training batches.
+
+    This implementation became a bit more complex because we wish to support approximate resume training.
+    Instead of turning this into a Class, we opt to return the state_dict with every batch,
+    and then the caller can pass in a state_dict to resume training from a desired point.
+    Note that this resumption is atm only *approximate* for simplicity.
+    We won't repeat the same documents but we might skip a few.
+    The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
+
+    Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
+    """
     assert split in ["train", "val"], "split must be 'train' or 'val'"

+    # infinite iterator over document batches (list of text strings)
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    def document_batches():
+        parquet_paths = list_parquet_files()
+        parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
+        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
+        while True: # iterate infinitely (multi-epoch)
+            while pq_idx < len(parquet_paths): # iterate over all parquet files
+                filepath = parquet_paths[pq_idx]
+                pf = pq.ParquetFile(filepath)
+                # Start from resume point if resuming on same file, otherwise from DDP rank
+                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
+                if resume_rg_idx is not None:
+                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
+                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
+                    rg_idx = base_idx * ddp_world_size + ddp_rank
+                    resume_rg_idx = None # set to None as we only want to do this a single time
+                else:
+                    rg_idx = ddp_rank
+                while rg_idx < pf.num_row_groups:
+                    rg = pf.read_row_group(rg_idx)
+                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
+                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
+                    for i in range(0, len(batch), tokenizer_batch_size):
+                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
+                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
+                pq_idx += 1 # advance to the next parquet file
+    batches = document_batches()
+
+    # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
     # get the tokenizer and the bos token
     tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left

-    # infinite iterator over document batches
-    def document_batches():
-        while True:
-            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
-            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
-                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
-                for i in range(0, len(batch), tokenizer_batch_size):
-                    yield batch[i:i+tokenizer_batch_size]
-    batches = document_batches()
-
-    batch_index = 0
     while True:
         # Accumulate enough tokens for one iteration before yielding.
         while len(token_buffer) < needed_tokens:
-            doc_batch = next(batches)
+            doc_batch, (pq_idx, rg_idx) = next(batches)
             token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
             for tokens in token_lists:
                 token_buffer.extend(tokens)
-            batch_index += 1
         # Move tokens from the deque into the scratch buffer
         tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-        # CUDA supports memory pinning for faster transfers between CPU and GPU:
-        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
+        # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
+        use_cuda_optimizations = device == "cuda"
+        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
         # Create the inputs/targets as 1D tensors
-        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+        inputs_cpu = scratch[:-1]
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
+        yield inputs, targets, state_dict
+
+def tokenizing_distributed_data_loader(*args, **kwargs):
+    # helper function that only emits the inputs/targets and not the state_dict
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
         yield inputs, targets
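Note: the trickiest part of the new loader is the row-group resume arithmetic. Below is a small worked example of that index math under the assumption of 4 DDP ranks; it only re-derives the `base_idx`/`rg_idx` computation from the diff and is not repo code.

```python
ddp_world_size = 4

def first_row_groups_after_resume(resume_rg_idx, ddp_world_size):
    # Each rank walks row groups rg_idx, rg_idx + world_size, rg_idx + 2*world_size, ...
    # so a saved rg_idx identifies a "stripe" of world_size consecutive row groups.
    base_idx = resume_rg_idx // ddp_world_size  # stripe that contained the saved index
    base_idx += 1                               # skip one whole stripe so no rank repeats data
    return [base_idx * ddp_world_size + rank for rank in range(ddp_world_size)]

# A checkpoint recorded rg_idx=10 (stripe 2); after resuming, ranks 0..3 continue at:
print(first_row_groups_after_resume(10, ddp_world_size))  # [12, 13, 14, 15]
```

This is why the docstring calls the resumption approximate: skipping a full stripe guarantees no document is repeated, at the cost of possibly skipping a few row groups.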
@@ -17,8 +17,9 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
+from contextlib import nullcontext

 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -107,8 +108,9 @@ class KVCache:
         assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
         assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
         for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
+            # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
             if ix in [0, 1, 3, 5]:
-                # num_layers, batch_size, num_heads, head_dim must match
+                # num_layers, k/v, num_heads, head_dim must match
                 assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
             elif ix == 2:
                 # batch_size can be expanded
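Note: to make the new index comment concrete, here is a standalone shape sketch (not the repo's `KVCache` class) of a cache laid out as `(num_layers, k/v, batch_size, num_heads, seq_len, head_dim)` and of expanding only the batch dimension, which is the one the code allows to differ.

```python
import torch

# dims: 0 num_layers, 1 k/v, 2 batch_size, 3 num_heads, 4 seq_len, 5 head_dim
num_layers, batch, heads, seq, head_dim = 2, 1, 4, 16, 64
kv_cache = torch.zeros(num_layers, 2, batch, heads, seq, head_dim)

# Dims 0, 1, 3 and 5 must match between two caches; dim 2 (batch) may be broadcast,
# e.g. when one prefilled prompt cache is shared across several sampled rows.
expanded = kv_cache.expand(num_layers, 2, 8, heads, seq, head_dim)
print(expanded.shape)  # torch.Size([2, 2, 8, 4, 16, 64])
```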
@@ -327,6 +329,9 @@ if __name__ == "__main__":
     import time
     # init compute
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    device_type = autodetect_device_type()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

     # load the model and tokenizer
     model, tokenizer, meta = load_model("base", device, phase="eval")
     bos_token_id = tokenizer.get_bos_token_id()
@@ -339,6 +344,7 @@ if __name__ == "__main__":
     torch.cuda.synchronize()
     t0 = time.time()
     stream = model.generate(prompt_tokens, **kwargs)
+    with autocast_ctx:
         for token in stream:
             generated_tokens.append(token)
             chunk = tokenizer.decode([token])
@@ -354,6 +360,7 @@ if __name__ == "__main__":
     stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
     torch.cuda.synchronize()
     t0 = time.time()
+    with autocast_ctx:
         for token_column, token_masks in stream:
             token = token_column[0] # only print out the first row
             generated_tokens.append(token)
@@ -8,7 +8,7 @@ Notable features:
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
-- Multi-Query Attention (MQA) support for more efficient inference
+- Group-Query Attention (GQA) support for more efficient inference
 """

 import math
@@ -29,7 +29,7 @@ class GPTConfig:
     vocab_size: int = 50304
     n_layer: int = 12
     n_head: int = 6 # number of query heads
-    n_kv_head: int = 6 # number of key/value heads (MQA)
+    n_kv_head: int = 6 # number of key/value heads (GQA)
     n_embd: int = 768
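Note: the MQA→GQA rename reflects that `n_kv_head` can be any divisor of `n_head`, with each key/value head shared by a group of query heads (MQA is the special case `n_kv_head = 1`). A minimal, repo-independent sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

B, T, head_dim = 2, 8, 64
n_head, n_kv_head = 6, 2              # GQA: 6 query heads share 2 key/value heads
assert n_head % n_kv_head == 0
group = n_head // n_kv_head           # 3 query heads per key/value head

q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# One common way to evaluate GQA: replicate each k/v head across its query group,
# then run ordinary scaled dot-product attention.
k_rep = k.repeat_interleave(group, dim=1)   # (B, n_head, T, head_dim)
v_rep = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
print(out.shape)  # torch.Size([2, 6, 8, 64])
```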
@@ -244,7 +244,7 @@ class GPT(nn.Module):
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         B, T = idx.size()

-        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
+        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
         assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
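Note: the corrected comment says the cached rotary tables end in a dimension of head_dim/2, i.e. one rotation angle per pair of channels. A small sketch of how `(1, seq_len, 1, head_dim/2)` cos/sin tables can be built; this is a generic illustration under that assumption, not the repo's exact code.

```python
import torch

seq_len, head_dim, base = 16, 64, 10000.0
# one frequency per channel pair -> head_dim/2 angles per position
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
t = torch.arange(seq_len).float()
angles = torch.outer(t, inv_freq)                 # (seq_len, head_dim/2)
cos = angles.cos()[None, :, None, :]              # (1, seq_len, 1, head_dim/2)
sin = angles.sin()[None, :, None, :]
print(cos.shape)  # torch.Size([1, 16, 1, 32])
```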
@@ -9,9 +9,9 @@ import torch.distributed as dist
 def evaluate_bpb(model, batches, steps, token_bytes):
     """
     Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
-    which is a tokenization vocab size-indepedent metric, meaning you are still comparing
+    which is a tokenization vocab size-independent metric, meaning you are still comparing
     apples:apples if you change the vocab size. The way this works is that instead of just
-    calculating the average loss as usual, you calculate the sum loss, and indepependently
+    calculating the average loss as usual, you calculate the sum loss, and independently
     also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
     the number of bytes that the target tokens represent.
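Note: in formulas, the docstring describes bpb = (sum of per-token cross-entropy in nats) / (ln 2 · sum of target-token byte lengths). A tiny sketch of that arithmetic, as a simplified stand-in for the repo's `evaluate_bpb`:

```python
import math

def bits_per_byte(sum_loss_nats, sum_target_bytes):
    # cross-entropy from the model is in nats; divide by ln(2) to convert to bits,
    # then normalize by how many bytes the target tokens represent
    return sum_loss_nats / (math.log(2) * sum_target_bytes)

# e.g. a total loss of 1.2e6 nats over targets that decode to 8e5 bytes:
print(bits_per_byte(1.2e6, 8e5))  # ~2.16 bpb
```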
@@ -20,10 +20,10 @@ import wandb
 import torch

 from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader
+from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
-from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
@@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 warmup_ratio = 0.0 # ratio of iterations for LR warmup
 warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
 final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
 eval_tokens = 20*524288 # number of tokens to evaluate val loss on
 core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
+save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
 # Output
 model_tag = "" # optionally override the model tag for the output checkpoint directory name
 # now allow CLI to override the settings via the configurator lol
@@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

 # -----------------------------------------------------------------------------
 # Initialize the Model

+# Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
 model.to_empty(device=device)
 model.init_weights()
-orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+# If we are resuming, overwrite the model parameters with those of the checkpoint
+base_dir = get_base_dir()
+output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
+checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+resuming = resume_from_step != -1
+if resuming:
+    print0(f"Resuming optimization from step {resume_from_step}")
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
+    model.load_state_dict(model_data, strict=True, assign=True)
+    del model_data # free up this memory after the copy
+
+orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
+model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
@@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
 adamw_optimizer, muon_optimizer = optimizers

+if resuming:
+    for opt, dat in zip(optimizers, optimizer_data):
+        opt.load_state_dict(dat)
+    del optimizer_data # free up the memory
+
 # -----------------------------------------------------------------------------
 # Initialize the DataLoaders for train/val
-base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
+dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
+train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
 build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
-x, y = next(train_loader) # kick off load of the very first batch of data
+x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

 # -----------------------------------------------------------------------------
 # Set up hyperparameter schedulers
@@ -171,15 +194,25 @@ def get_muon_momentum(it):
     momentum = (1 - frac) * 0.85 + frac * 0.95
     return momentum

+# -----------------------------------------------------------------------------
+# Loop state (variables updated by the training loop)
+
+if not resuming:
+    step = 0
+    min_val_bpb = float("inf")
+    smooth_train_loss = 0 # EMA of training loss
+    total_training_time = 0 # total wall-clock time of training
+else:
+    step = meta_data["step"]
+    loop_state = meta_data["loop_state"]
+    min_val_bpb = loop_state["min_val_bpb"]
+    smooth_train_loss = loop_state["smooth_train_loss"]
+    total_training_time = loop_state["total_training_time"]
+
 # -----------------------------------------------------------------------------
 # Training loop
-min_val_bpb = float("inf")
-smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
-total_training_time = 0 # total wall-clock time of training
-# note that we run +1 steps only so that we can eval and save at the end
-for step in range(num_iterations + 1):
-    last_step = step == num_iterations
+while True:
+    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
     flops_so_far = num_flops_per_token * total_batch_size * step

     # once in a while: evaluate the val bpb (all ranks participate)
@@ -237,25 +270,31 @@ for step in range(num_iterations + 1):
             print0(tokenizer.decode(sample[0]))
         model.train()

-    # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step:
-        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
-        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
+    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
         save_checkpoint(
             checkpoint_dir,
             step,
-            orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
-            {
+            orig_model.state_dict(), # model parameters
+            [opt.state_dict() for opt in optimizers], # optimizer states
+            { # metadata saved as json
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
                 "model_config": model_config_kwargs,
                 "user_config": user_config, # inputs to the training script
                 "device_batch_size": device_batch_size,
                 "max_seq_len": max_seq_len,
-            }
+                "dataloader_state_dict": dataloader_state_dict,
+                "loop_state": { # all loop state (other than step) so that we can resume training
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "total_training_time": total_training_time,
+                },
+            },
+            rank=ddp_rank,
         )

+    # termination conditions (TODO: possibly also add loss explosions etc.)
     if last_step:
         break
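Note: with the extra fields above, the meta_{step}.json written alongside each checkpoint now carries everything the script needs to resume. Roughly, it looks like the sketch below; the keys mirror the call above, while every value is a made-up placeholder.

```python
# Illustrative only: shape of the metadata dict that save_checkpoint serializes to json.
meta_data = {
    "step": 4000,
    "val_bpb": 0.9123,                             # loss at last eval
    "model_config": {"sequence_len": 2048, "vocab_size": 65536, "n_layer": 20,
                     "n_head": 10, "n_kv_head": 10, "n_embd": 1280},
    "user_config": {"depth": 20},                  # whatever was passed to the script
    "device_batch_size": 32,
    "max_seq_len": 2048,
    "dataloader_state_dict": {"pq_idx": 3, "rg_idx": 117},  # approximate position in the data
    "loop_state": {                                         # everything else needed to resume
        "min_val_bpb": 0.9123,
        "smooth_train_loss": 2.71,
        "total_training_time": 5400.0,
    },
}
```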
@@ -270,7 +309,7 @@ for step in range(num_iterations + 1):
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
-        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     # gradient clipping
     grad_clip_enabled = grad_clip > 0.0
     if grad_clip_enabled:
@@ -293,6 +332,7 @@ for step in range(num_iterations + 1):
     # -------------------------------------------------------------------------

     # logging
+    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
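Note: the debiasing line divides the EMA by 1 - ema_beta^(step+1), the same bias correction Adam uses, so early log values are not dragged toward the zero initialization. A tiny standalone illustration:

```python
# EMA of a constant loss of 4.0: the raw EMA starts near 0, the debiased value is exact.
ema_beta = 0.9
smooth = 0.0
for step, loss in enumerate([4.0, 4.0, 4.0]):
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (step + 1))
    print(f"step {step}: raw EMA {smooth:.3f}, debiased {debiased:.3f}")
# step 0: raw EMA 0.400, debiased 4.000
# step 1: raw EMA 0.760, debiased 4.000
# step 2: raw EMA 1.084, debiased 4.000
```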
@@ -319,6 +359,9 @@ for step in range(num_iterations + 1):
         log_data["train/grad_norm"] = grad_norm
         wandb_run.log(log_data)

+    # state update
+    step += 1
+
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
@@ -1,6 +1,6 @@
 """
 Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
+All the generic code lives here, and all the evaluation-specific
 code lives in nanochat directory and is imported from here.

 Example runs:
@@ -192,7 +192,7 @@ for step in range(num_iterations):
         })
         model.train()

-    # evlauate accuracy of the multiple choice tasks (which are quick to run)
+    # evaluate accuracy of the multiple choice tasks (which are quick to run)
     if last_step or (step > 0 and step % eval_metrics_every == 0):
         model.eval()
         metrics = {}
uv.lock (18 changed lines)

@@ -311,7 +311,7 @@ name = "exceptiongroup"
 version = "1.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
 wheels = [

@@ -777,7 +777,6 @@ dependencies = [
     { name = "datasets" },
     { name = "fastapi" },
     { name = "files-to-prompt" },
-    { name = "numpy" },
     { name = "psutil" },
     { name = "regex" },
     { name = "setuptools" },

@@ -811,7 +810,6 @@ requires-dist = [
     { name = "datasets", specifier = ">=4.0.0" },
     { name = "fastapi", specifier = ">=0.117.1" },
     { name = "files-to-prompt", specifier = ">=0.6" },
-    { name = "numpy", specifier = "==1.26.4" },
     { name = "psutil", specifier = ">=7.1.0" },
     { name = "regex", specifier = ">=2025.9.1" },
     { name = "setuptools", specifier = ">=80.9.0" },

@@ -951,7 +949,7 @@ name = "nvidia-cudnn-cu12"
 version = "9.10.2.21"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
+    { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },

@@ -964,7 +962,7 @@ name = "nvidia-cufft-cu12"
 version = "11.3.3.83"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },

@@ -996,9 +994,9 @@ name = "nvidia-cusolver-cu12"
 version = "11.7.3.90"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
-    { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
-    { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
+    { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
+    { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },

@@ -1011,7 +1009,7 @@ name = "nvidia-cusparse-cu12"
 version = "12.5.8.93"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },

@@ -1955,7 +1953,7 @@ name = "triton"
 version = "3.4.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" },
+    { name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },