Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-25 21:04:33 +00:00)

Merge branch 'master' into fix-sft-loss-when-grad-accum
Commit: 23393eae83
.gitignore (vendored, 2 lines added)

@@ -3,3 +3,5 @@ __pycache__/
 *.pyc
 rustbpe/target/
 dev-ignore/
+report.md
+eval_bundle/
LICENSE (new file, 21 lines)

MIT License

Copyright (c) 2025 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (89 lines changed)

@@ -6,6 +6,10 @@

This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.

## Talk to it

To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...

## Quick start

The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
@@ -80,7 +84,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --d
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
```

-That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensates by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
+That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).

And a bit more about computing environments that will run nanochat:
@@ -89,6 +93,16 @@ And a bit more about computing environments that will run nanochat:
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit, e.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative (see the sketch right after this list).
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that (xpu, mps, etc.), but I haven't implemented this out of the box so it might take a bit of tinkering.
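To make the `--device_batch_size` trade-off above concrete, here is a minimal sketch of how the effective batch size can stay fixed while the per-device batch shrinks. The variable names and values are illustrative assumptions, not the exact nanochat code:

```python
# Illustrative sketch: the total batch size (in tokens) per optimizer step stays
# fixed; shrinking the per-device batch only increases how many micro-batches are
# accumulated sequentially before each optimizer step.
total_batch_size = 524_288   # tokens per optimizer step (example value)
max_seq_len = 2048           # tokens per sequence (example value)
world_size = 8               # number of GPUs / DDP ranks

for device_batch_size in (32, 16, 8, 4, 2, 1):
    tokens_per_fwdbwd = device_batch_size * max_seq_len * world_size
    assert total_batch_size % tokens_per_fwdbwd == 0
    grad_accum_steps = total_batch_size // tokens_per_fwdbwd
    print(f"device_batch_size={device_batch_size:2d} -> grad_accum_steps={grad_accum_steps}")
```

Each halving of `device_batch_size` doubles the number of sequential micro-batches, so every optimizer step still sees the same number of tokens, just more slowly.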
## Running on CPU / MPS

nanochat can be run on CPU or on MPS (if you're on a Macbook), and will automatically try to detect which device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for a shorter number of iterations, etc. This functionality is new, slightly gnarly (it touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.

## Customization

To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the midtraining and SFT stages.

Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).

## Questions

nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:

@@ -99,7 +113,7 @@ files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --

This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.

-Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
+Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.

## Tests
@@ -109,9 +123,77 @@ I haven't invested too much here but some tests exist, especially for the tokeni
python -m pytest tests/test_rustbpe.py -v -s
```

## File structure

```
.
├── LICENSE
├── README.md
├── dev
│   ├── gen_synthetic_data.py         # Example synthetic data for identity
│   ├── generate_logo.html
│   ├── nanochat.png
│   ├── repackage_data_reference.py   # Pretraining data shard generation
│   └── runcpu.sh                     # Small example of how to run on CPU/MPS
├── nanochat
│   ├── __init__.py                   # empty
│   ├── adamw.py                      # Distributed AdamW optimizer
│   ├── checkpoint_manager.py         # Save/Load model checkpoints
│   ├── common.py                     # Misc small utilities, quality of life
│   ├── configurator.py               # A superior alternative to argparse
│   ├── core_eval.py                  # Evaluates base model CORE score (DCLM paper)
│   ├── dataloader.py                 # Tokenizing Distributed Data Loader
│   ├── dataset.py                    # Download/read utils for pretraining data
│   ├── engine.py                     # Efficient model inference with KV Cache
│   ├── execution.py                  # Allows the LLM to execute Python code as tool
│   ├── gpt.py                        # The GPT nn.Module Transformer
│   ├── logo.svg
│   ├── loss_eval.py                  # Evaluate bits per byte (instead of loss)
│   ├── muon.py                       # Distributed Muon optimizer
│   ├── report.py                     # Utilities for writing the nanochat Report
│   ├── tokenizer.py                  # BPE Tokenizer wrapper in style of GPT-4
│   └── ui.html                       # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh                        # Train the ~$800 nanochat d32
├── rustbpe                           # Custom Rust BPE tokenizer trainer
│   ├── Cargo.lock
│   ├── Cargo.toml
│   ├── README.md                     # see for why this even exists
│   └── src
│       └── lib.rs
├── scripts
│   ├── base_eval.py                  # Base model: calculate CORE score
│   ├── base_loss.py                  # Base model: calculate bits per byte, sample
│   ├── base_train.py                 # Base model: train
│   ├── chat_cli.py                   # Chat model (SFT/Mid): talk to over CLI
│   ├── chat_eval.py                  # Chat model (SFT/Mid): eval tasks
│   ├── chat_rl.py                    # Chat model (SFT/Mid): reinforcement learning
│   ├── chat_sft.py                   # Chat model: train SFT
│   ├── chat_web.py                   # Chat model (SFT/Mid): talk to over WebUI
│   ├── mid_train.py                  # Chat model: midtraining
│   ├── tok_eval.py                   # Tokenizer: evaluate compression rate
│   └── tok_train.py                  # Tokenizer: train it
├── speedrun.sh                       # Train the ~$100 nanochat d20
├── tasks
│   ├── arc.py                        # Multiple choice science questions
│   ├── common.py                     # TaskMixture | TaskSequence
│   ├── customjson.py                 # Make Task from arbitrary jsonl convos
│   ├── gsm8k.py                      # 8K Grade School Math questions
│   ├── humaneval.py                  # Misnomer; Simple Python coding task
│   ├── mmlu.py                       # Multiple choice questions, broad topics
│   ├── smoltalk.py                   # Conglomerate dataset of SmolTalk from HF
│   └── spellingbee.py                # Task teaching model to spell/count letters
├── tests
│   ├── test_engine.py
│   └── test_rustbpe.py
└── uv.lock
```

## Contributing

-nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
+nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.

Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.

## Acknowledgements
@@ -120,6 +202,7 @@ nanochat is nowhere finished. The goal is to improve the state of the art in mic
- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
- Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
- Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.

## Cite
dev/gen_synthetic_data.py (new file, 387 lines)

"""
Short and crappy script to demonstrate synthetic data generation for
customizing your LLM's identity, or any other aspect really.

In this example code, we use the OpenRouter API to generate synthetic data
of conversations between a user and an assistant. We use the "Structured Output"
feature to get back JSON data from the API instead of raw text. The conversations
are simply saved to a .jsonl file in the base directory and later loaded and
trained on in midtraining or SFT, using the CustomJSON task.

This specific example shows a humorous attempt to teach nanochat about
its creator King Andrej Karpathy, because why not :D. Note two things about the
prompt:

1. We are instructing the LLM how to handle various situations (e.g. a foreign language),
simply in English. You can infuse any style or behavior in this way.
2. You'll see that I added a large diversity of user first messages manually,
and then I sample 5 random ones from that list into the prompt as inspiration.
This is really important to do because DIVERSITY CONTROL is key. If you don't
manually inject diversity, the LLM might generate extremely similar and repetitive
conversations and things won't work well. Even this example below is not good enough;
for example, you might want to actually suggest or inspire conversation topics, or questions,
and have a list of those. Basically, this is the KEY creative part to get right. Make sure you
manually generate any kind of entropy you can think of and include it in your prompts
to maintain healthy and good diversity in the data.

NOTE: You need an OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
(obviously you can tune this arbitrarily to your liking)
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import json
import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

from nanochat.common import get_base_dir

api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()

url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
    "Authorization": f"Bearer {api_key}",
    "Content-Type": "application/json"
}

readme = open("README.md", "r", encoding="utf-8").read().strip()
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:

The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).

Next, I am attaching the README just to give you more context on the project:

---
%README%
---

Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.

STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.

Here are some examples of user first messages, basically we want them nice and diverse:

%USER_FIRST_PROMPTS%

NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
""".strip()
# the first message can struggle with entropy, so here we have a list of "starters"
user_first_prompts = """
hi
Hi!
hello
Hello?
hey there
Hey!
yo
Yo!
Good morning
Good evening!
Howdy
sup
What's up?
Hi nanochat
Hey, who are you?
Hello there :)
yo nanochat
Hi, what is this?
Hey, are you a chatbot?
Hello! Who am I talking to?
hi there
hey hey
hello friend
hiya
greetings
hey nanochat!
hello again
good afternoon
morning!
evening!
yo there
hi bot
hi assistant
hello nanochat :)
hey, anyone here?
hi! what do you do?
hello from the other side
hiya nanochat
hey you
hello world
hey! what's going on
hi! who made you
hello :)
yo! how are you
hi! can you talk
hello there nanochat
hi, what's your name
hey! are you alive
hiya! what are you
hello! tell me about yourself
hi, are you the ai
yo, what is this
hello my friend
hi! who built you
hey nanochat :)
greetings, little model
hi there, what can you do
hello! are you open source
hey, what version are you
hi! nice to meet you
hi :)
hey buddy
hello hello
yo! what's up nanochat
hi! are you real
hey, how's it going
hello! can you hear me
hi nanochat, who trained you
yo, what model are you
hi! tell me a fun fact
hey, are you chatgpt
hello! introduce yourself
hiya there
hi! what's your story
hey, what's nanochat
good day!
hello! who's your creator
hi! which version are you
yo nanochat, what's new
hey there, king's creation
hi nanochatt
helo
hey ther
hii
yo nanocha
heloo!
hi, whos this
hay
helloo??
hi nanocat
yo! any1 here?
hi, what r u
helo nanochat
hai!
sup bot?
heyy
hi! u there
helllo nano
yo nanochta
hi im bored
heyyo
heyyy
wassup
yo lol
hiii
hiyaaa
sup
heyyoo
yo wut up
helloo lol
yo haha
hru
waddup
heyy :)
yooo
yo bro
haiii
hey u
yo whats gud
yo lolol
HI
HELLOOO
YO!!!
HEY
SUP
WASSUP
HEY!!!
YO BRO
HELLO??
HI THERE!!
YO WHATS UP
HEY U
HEYOOOO
YO LOL
HIII
HIYA
YOOOO
HELLO!!!
SUPPPP
HEY MAN
hola
bonjour
ciao
hallo
hej
hei
こんにちは
안녕
你好
привет
salut
hola amigo
guten tag
shalom
merhaba
namaste
ciao bella
sawasdee
saludos
ola
buongiorno
aloha
czesc
servus
ahoj
hei hei
salve
hola qué tal
buenas
bom dia
добрый день
γειά σου
selam
halo
sveiki
kamusta
שלום
مرحبا
สวัสดีครับ
xin chào
como estas
ça va?
wie geht’s
tudo bem?
你好吗
annyeong haseyo
konnichiwa, genki?
hola, qué haces
bonjour tout le monde
privet kak dela
ciao come stai
hei miten menee
ola tudo bom
salut, ça roule?
namaste, kaise ho
merhaba nasılsın
hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split("\n")
prompt = prompt.replace("%README%", readme)

# Define the JSON schema for structured output
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "conversation",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "messages": {
                    "type": "array",
                    "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
                    "items": {
                        "type": "object",
                        "properties": {
                            "role": {
                                "type": "string",
                                "description": "The role of the speaker, either 'user' or 'assistant'"
                            },
                            "content": {
                                "type": "string",
                                "description": "The message content"
                            }
                        },
                        "required": ["role", "content"],
                        "additionalProperties": False
                    }
                }
            },
            "required": ["messages"],
            "additionalProperties": False
        }
    }
}

# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
    "model": "google/gemini-2.5-flash",
    "stream": False,
    "response_format": response_format,
    "temperature": 1.0,
}

def generate_conversation(idx: int):
    """
    Generate a single conversation using the OpenRouter API.
    Returns a list of message dicts with 'role' and 'content' keys.
    """

    # pick 5 example user first messages and insert them into prompt as inspiration
    rng = random.Random(idx) # use idx as seed to the rng
    user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
    payload = copy.deepcopy(base_payload)
    modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
    payload['messages'] = [{"role": "user", "content": modified_prompt}]

    response = requests.post(url, headers=headers, json=payload)
    result = response.json()
    content = result['choices'][0]['message']['content']

    # Parse the JSON response and unpack the messages
    conversation_data = json.loads(content)
    messages = conversation_data['messages']

    return messages


# Configuration
num_conversations = 1000
num_workers = 4

output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Wipe the file clean first to reset it
if os.path.exists(output_file):
    os.remove(output_file)
print(f"Saving to {output_file}")

# Use ThreadPoolExecutor to generate conversations in parallel
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:

    # Submit all tasks
    futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]

    # Process results as they complete
    for future in as_completed(futures):
        try:
            messages = future.result()

            # Lightly validate the conversation structure
            for i, message in enumerate(messages):
                expected_role = "user" if i % 2 == 0 else "assistant"
                assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"

            # If all looks good, write the messages to file
            with open(output_file, 'a') as f:
                f.write(json.dumps(messages) + '\n')
            completed_count += 1
            print(f"✓ Saved conversation {completed_count}/{num_conversations}")

        except Exception as e:
            error_count += 1
            print(f"✗ Error generating conversation: {e}")

print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
    print(f"Encountered {error_count} errors during generation")
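Since the script's docstring stresses that diversity control is the key creative step, a quick way to sanity-check its output is to count how often the generated conversations reuse the same opening user message. The snippet below is a small illustrative check, not part of the repo; it only assumes the `identity_conversations.jsonl` format written above (one JSON list of `{role, content}` messages per line):

```python
# Illustrative diversity check for the generated jsonl file.
import json
from collections import Counter

openers = Counter()
with open("identity_conversations.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        messages = json.loads(line)
        if messages and messages[0]["role"] == "user":
            # normalize lightly so "Hi!" and "hi" count as the same opener
            openers[messages[0]["content"].strip().lower()] += 1

total = sum(openers.values())
print(f"{total} conversations, {len(openers)} distinct first user messages")
for opener, n in openers.most_common(10):
    print(f"{n:4d}  {opener[:60]}")
```

If a handful of openers dominate, the prompt likely needs more manually injected entropy, which is exactly the point the docstring makes.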
dev/nanochat.png (binary file changed, not shown): before 19 KiB, after 1.3 KiB.
dev/runcpu.sh (new executable file, 77 lines)

#!/bin/bash

# Example run that exercises some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runcpu.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
    WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# wipe the report
python -m nanochat.report reset

# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval

# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
    --depth=4 \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --core_metric_every=50 \
    --core_metric_max_per_task=12 \
    --sample_every=50 \
    --num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16

# midtraining
python -m scripts.mid_train \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --total_batch_size=1024 \
    --num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20

# SFT
python -m scripts.chat_sft \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --num_iterations=100 \
    --eval_steps=4 \
    --eval_metrics_max_problems=16

# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

# Chat Web
# python -m scripts.chat_web

python -m nanochat.report generate
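For a sense of scale, this CPU demo trains on a vanishingly small amount of data. A rough back-of-the-envelope based only on the flags shown above:

```python
# Rough arithmetic for the CPU demo: 50 optimizer steps, each over a
# total_batch_size of 1024 tokens (one 1024-token sequence per step).
num_iterations = 50
total_batch_size = 1024          # tokens per optimizer step
tokens_seen = num_iterations * total_batch_size
print(f"base_train sees ~{tokens_seen:,} tokens")   # ~51,200 tokens
# compare: the $100 speedrun pretrains on billions of tokens, so expect gibberish here
```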
@ -26,8 +26,8 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
grad_slices = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
|
||||
for base_i in range(len(params)):
|
||||
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
|
|
|
|||
|
|
nanochat/checkpoint_manager.py

@@ -20,37 +20,37 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)

-def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    # Save the model state (parameters)
-    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
-    torch.save(model_data, model_path)
-    log0(f"Saved model file to: {model_path}")
-    # Save the optimizer state (useful for SFT or any other fine-tuning)
+def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+    if rank == 0:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # Save the model state parameters
+        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+        torch.save(model_data, model_path)
+        logger.info(f"Saved model parameters to: {model_path}")
+        # Save the metadata dict as json
+        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+        logger.info(f"Saved metadata to: {meta_path}")
+    # Note that optimizer state is sharded across ranks, so each rank must save its own.
     if optimizer_data is not None:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         torch.save(optimizer_data, optimizer_path)
-        log0(f"Saved optimizer file to: {optimizer_path}")
-    # Save the metadata dict as json
-    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "w") as f:
-        json.dump(meta_data, f, indent=2)
-    log0(f"Saved metadata file to: {meta_path}")
+        logger.info(f"Saved optimizer state to: {optimizer_path}")


-def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
+def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
     model_data = torch.load(model_path, map_location=device)
     # Load the optimizer state if requested
     optimizer_data = None
     if load_optimizer:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         optimizer_data = torch.load(optimizer_path, map_location=device)
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "r") as f:
+    with open(meta_path, "r", encoding="utf-8") as f:
         meta_data = json.load(f)
     return model_data, optimizer_data, meta_data

@@ -65,8 +65,14 @@ def build_model(checkpoint_dir, step, device, phase):
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+    if device.type in {"cpu", "mps"}:
+        # Convert bfloat16 tensors to float for CPU inference
+        model_data = {
+            k: v.float() if v.dtype == torch.bfloat16 else v
+            for k, v in model_data.items()
+        }
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
-    model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
+    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)

@@ -88,11 +94,11 @@
     return model, tokenizer, meta_data


-def find_largest_model(checkpoint_dir):
+def find_largest_model(checkpoints_dir):
     # attempt to guess the model tag: take the biggest model available
-    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
+    model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
     if not model_tags:
-        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
     # 1) normally all model tags are of the form d<number>, try that first:
     candidates = []
     for model_tag in model_tags:

@@ -104,7 +110,7 @@ def find_largest_model(checkpoint_dir):
         candidates.sort(key=lambda x: x[0], reverse=True)
         return candidates[0][1]
     # 2) if that failed, take the most recently updated model:
-    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
     return model_tags[0]
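A minimal sketch of how the reworked save/load pair might be called, one call per rank. The `rank` keyword is the new argument above; the paths and stand-in tensors here are purely illustrative, and the example assumes the nanochat package is importable:

```python
# Illustrative usage: every rank saves its own (sharded) optimizer state, while
# only rank 0 writes the model weights and metadata, matching save_checkpoint above.
import torch
import torch.distributed as dist
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint

rank = dist.get_rank() if dist.is_initialized() else 0
# stand-in data purely for illustration
model_data = {"wte.weight": torch.zeros(4, 4)}
optimizer_data = {"state": {}, "param_groups": []}   # each rank holds its own shard
meta_data = {"model_config": {"sequence_len": 1024}}

save_checkpoint("checkpoints/demo", 0, model_data, optimizer_data, meta_data, rank=rank)

# later: each rank reloads its own optimizer shard alongside the shared model file
model_data, optimizer_data, meta_data = load_checkpoint(
    "checkpoints/demo", 0, device="cpu", load_optimizer=True, rank=rank)
```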
nanochat/common.py

@@ -5,8 +5,10 @@ Common utilities for nanochat.
 import os
 import re
 import logging
+import urllib.request
 import torch
 import torch.distributed as dist
+from filelock import FileLock

 class ColoredFormatter(logging.Formatter):
     """Custom formatter that adds colors to log messages."""
@@ -56,6 +58,42 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir

+def download_file_with_lock(url, filename, postprocess_fn=None):
+    """
+    Downloads a file from a URL to a local path in the base directory.
+    Uses a lock file to prevent concurrent downloads among multiple ranks.
+    """
+    base_dir = get_base_dir()
+    file_path = os.path.join(base_dir, filename)
+    lock_path = file_path + ".lock"
+
+    if os.path.exists(file_path):
+        return file_path
+
+    with FileLock(lock_path):
+        # Only a single rank can acquire this lock
+        # All other ranks block until it is released
+
+        # Recheck after acquiring lock
+        if os.path.exists(file_path):
+            return file_path
+
+        # Download the content as bytes
+        print(f"Downloading {url}...")
+        with urllib.request.urlopen(url) as response:
+            content = response.read() # bytes
+
+        # Write to local file
+        with open(file_path, 'wb') as f:
+            f.write(content)
+        print(f"Downloaded to {file_path}")
+
+        # Run the postprocess function if provided
+        if postprocess_fn is not None:
+            postprocess_fn(file_path)
+
+    return file_path
+
 def print0(s="",**kwargs):
     ddp_rank = int(os.environ.get('RANK', 0))
     if ddp_rank == 0:
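For context, a minimal example of how the new `download_file_with_lock` helper could be used. The URL and filename below are placeholders, not files nanochat actually downloads:

```python
# Illustrative call: many DDP ranks can run this concurrently; the FileLock
# ensures only one of them performs the download while the rest wait and reuse it.
from nanochat.common import download_file_with_lock

def report_done(path):
    # hypothetical post-processing hook; runs once, right after the download
    print(f"downloaded to {path}")

path = download_file_with_lock(
    url="https://example.com/some_eval_bundle.zip",   # placeholder URL
    filename="some_eval_bundle.zip",                  # stored under the nanochat base dir
    postprocess_fn=report_done,
)
```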
@@ -64,23 +102,35 @@ def print0(s="",**kwargs):
 def print_banner():
     # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
-    banner = """ ... (multi-line ASCII-art "nanochat" banner) ... """
+    banner = """ ... (same ASCII-art "nanochat" banner, whitespace/alignment only re-rendered) ... """
     print0(banner)

-def is_ddp():
-    # TODO is there a proper way
-    return int(os.environ.get('RANK', -1)) != -1
+def is_ddp_requested() -> bool:
+    """
+    True if launched by torchrun (env present), even before init.
+    Used to decide whether we *should* initialize a PG.
+    """
+    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+def is_ddp_initialized() -> bool:
+    """
+    True if torch.distributed is available and the process group is initialized.
+    Used at cleanup to avoid destroying a non-existent PG.
+    """
+    return dist.is_available() and dist.is_initialized()

 def get_dist_info():
-    if is_ddp():
+    if is_ddp_requested():
+        # We rely on torchrun's env to decide if we SHOULD init.
+        # (Initialization itself happens in compute init.)
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
         ddp_rank = int(os.environ['RANK'])
         ddp_local_rank = int(os.environ['LOCAL_RANK'])
@@ -89,41 +139,57 @@ def get_dist_info():
     else:
         return False, 0, 0, 1

-def compute_init():
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""

-    # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

     # Reproducibility
     # Note that we set the global seeds here, but most of the code uses explicit rng objects.
     # The only place where global rng might be used is nn.Module initialization of the model weights.
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
     # torch.backends.cudnn.deterministic = True
     # torch.backends.cudnn.benchmark = False

     # Precision
-    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+    if device_type == "cuda":
+        torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls

-    # Distributed setup: Distributed Data Parallel (DDP), optional
-    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
+    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if is_ddp_requested and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
+        torch.cuda.set_device(device)  # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type) # mps|cpu

     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")

-    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device

 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp_initialized():
         dist.destroy_process_group()

 class DummyWandb:
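A small sketch of how a script might combine the new `autodetect_device_type` with the generalized `compute_init` above. This mirrors what the updated scripts do, but the surrounding names and the toy computation are illustrative:

```python
# Illustrative: pick the best available device, then initialize compute for it.
import torch
from contextlib import nullcontext
from nanochat.common import autodetect_device_type, compute_init, compute_cleanup

device_type = autodetect_device_type()            # "cuda" | "mps" | "cpu"
ddp, rank, local_rank, world_size, device = compute_init(device_type)

# bfloat16 autocast only makes sense on CUDA; elsewhere fall back to a no-op context
autocast_ctx = (torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
                if device_type == "cuda" else nullcontext())

with autocast_ctx:
    x = torch.randn(2, 3, device=device)
    y = x @ x.transpose(0, 1)

compute_cleanup()
```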
nanochat/dataloader.py

@@ -1,49 +1,94 @@
 from collections import deque

 import torch
 import pyarrow.parquet as pq

 from nanochat.common import get_dist_info
-from nanochat.dataset import parquets_iter_batched
+from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer

-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
-    """Stream pretraining text from parquet files, tokenize, yield training batches."""
+def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+    """
+    Stream pretraining text from parquet files, tokenize, yield training batches.
+
+    This implementation became a bit more complex because we wish to support approximate resume training.
+    Instead of turning this into a Class, we opt to return the state_dict with every batch,
+    and then the caller can pass in a state_dict to resume training from a desired point.
+    Note that this resumption is atm only *approximate* for simplicity.
+    We won't repeat the same documents but we might skip a few.
+    The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
+
+    Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
+    """
     assert split in ["train", "val"], "split must be 'train' or 'val'"

     # infinite iterator over document batches (list of text strings)
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    def document_batches():
+        parquet_paths = list_parquet_files()
+        assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
+        parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
+        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+        first_pass = True
+        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
+        while True: # iterate infinitely (multi-epoch)
+            pq_idx = resume_pq_idx if first_pass else 0
+            while pq_idx < len(parquet_paths): # iterate over all parquet files
+                filepath = parquet_paths[pq_idx]
+                pf = pq.ParquetFile(filepath)
+                # Start from resume point if resuming on same file, otherwise from DDP rank
+                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
+                if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
+                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
+                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
+                    rg_idx = base_idx * ddp_world_size + ddp_rank
+                    if rg_idx >= pf.num_row_groups:
+                        pq_idx += 1
+                        continue
+                    resume_rg_idx = None # set to None as we only want to do this a single time
+                else:
+                    rg_idx = ddp_rank
+                while rg_idx < pf.num_row_groups:
+                    rg = pf.read_row_group(rg_idx)
+                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
+                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
+                    for i in range(0, len(batch), tokenizer_batch_size):
+                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
+                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
+                pq_idx += 1 # advance to the next parquet file
+            first_pass = False
+    batches = document_batches()

     # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
     # get the tokenizer and the bos token
     tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

-    # infinite iterator over document batches
-    def document_batches():
-        while True:
-            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
-            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
-                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
-                for i in range(0, len(batch), tokenizer_batch_size):
-                    yield batch[i:i+tokenizer_batch_size]
-    batches = document_batches()

     batch_index = 0
     while True:
         # Accumulate enough tokens for one iteration before yielding.
         while len(token_buffer) < needed_tokens:
-            doc_batch = next(batches)
+            doc_batch, (pq_idx, rg_idx) = next(batches)
             token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
             for tokens in token_lists:
                 token_buffer.extend(tokens)
             batch_index += 1
         # Move tokens from the deque into the scratch buffer
-        for i in range(needed_tokens):
-            scratch[i] = token_buffer.popleft()
+        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
+        # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
+        use_cuda_optimizations = device == "cuda"
+        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
         # Create the inputs/targets as 1D tensors
-        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+        inputs_cpu = scratch[:-1]
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
+        yield inputs, targets, state_dict
+
+def tokenizing_distributed_data_loader(*args, **kwargs):
+    # helper function that only emits the inputs/targets and not the state_dict
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
+        yield inputs, targets
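To illustrate the approximate-resume contract described in the docstring above, here is a minimal sketch of how a training loop could capture and later replay the loader state. Only the `state_dict` round-trip comes from the code above; the loop structure and the idea of stopping at step 100 are assumptions:

```python
# Illustrative sketch of approximate dataloader resumption.
from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

B, T = 4, 1024
loader = tokenizing_distributed_data_loader_with_state(B, T, split="train", device="cpu")

last_state = None
for step, (inputs, targets, state) in enumerate(loader):
    last_state = state                  # e.g. {"pq_idx": 3, "rg_idx": 17}
    if step == 100:
        break                           # pretend we checkpointed / crashed here

# later: pass the captured state back in to continue near where we left off
resumed = tokenizing_distributed_data_loader_with_state(
    B, T, split="train", device="cpu", resume_state_dict=last_state)
inputs, targets, state = next(resumed)
```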
nanochat/engine.py

@@ -17,8 +17,9 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
+from contextlib import nullcontext

 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -37,19 +38,45 @@ def eval_with_timeout(formula, max_time=3):
         with timeout(max_time, formula):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore", SyntaxWarning)
-                return eval(formula)
+                return eval(formula, {"__builtins__": {}}, {})
     except Exception as e:
         signal.alarm(0)
         # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
         return None

 def use_calculator(expr):
-    """Evaluate a math expression safely."""
+    """
+    Evaluate a Python expression safely.
+    Supports both math expressions and string operations like .count()
+    """
     # Remove commas from numbers
     expr = expr.replace(",", "")
-    if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
+
+    # Check if it's a pure math expression (old behavior)
+    if all([x in "0123456789*+-/.() " for x in expr]):
+        if "**" in expr: # disallow power operator
+            return None
+        return eval_with_timeout(expr)
+
+    # Check if it's a string operation we support
+    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
+    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
+    if not all([x in allowed_chars for x in expr]):
         return None
-    if "**" in expr: # for now disallow power operator, could be very expensive
+
+    # Disallow dangerous patterns
+    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
+                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
+                          'getattr', 'setattr', 'delattr', 'hasattr']
+    expr_lower = expr.lower()
+    if any(pattern in expr_lower for pattern in dangerous_patterns):
         return None
+
+    # Only allow .count() method for now (can expand later)
+    if '.count(' not in expr:
+        return None
+
+    # Evaluate with timeout
     return eval_with_timeout(expr)

 # -----------------------------------------------------------------------------
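Based on the widened allow-list above, the calculator tool now accepts both arithmetic and simple string `.count()` expressions. A few illustrative calls (the import path assumes the nanochat package is installed):

```python
# Illustrative calls against the use_calculator shown above.
from nanochat.engine import use_calculator

print(use_calculator("3*(4+5)"))                 # 27    (pure math path)
print(use_calculator("1,000 + 24"))              # 1024  (commas are stripped first)
print(use_calculator("'strawberry'.count('r')")) # 3     (string .count() path)
print(use_calculator("2**30"))                   # None  (power operator rejected)
print(use_calculator("__import__('os')"))        # None  (dangerous pattern rejected)
```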
@@ -80,16 +107,23 @@ class KVCache:
         # 1) validate the shapes
         assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
         assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
-        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
-            if ix in [0, 1, 3, 5]:
-                # num_layers, batch_size, num_heads, head_dim must match
-                assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
-            elif ix == 2:
-                # batch_size can be expanded
-                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
-            elif ix == 4:
-                # seq_len: self must be longer than other
-                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
+        # Extract dimensions explicitly
+        self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
+        other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
+
+        # Validate dimensions
+        assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
+        assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
+        assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
+        assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
+
+        # Batch size can be expanded (other can be 1, self can be larger)
+        assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
+
+        # Sequence length: self must be longer than other
+        assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"

         # 2) initialize the cache
         dtype, device = other.kv_cache.dtype, other.kv_cache.device
         self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)

@@ -109,15 +143,17 @@ class KVCache:
         if t1 > self.kv_cache.size(4):
             t_needed = t1 + 1024 # as much as we need plus buffer of 1024
             t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
-            current_shape = list(self.kv_cache.shape)
-            current_shape[4] = t_needed
-            self.kv_cache.resize_(current_shape)
+            additional_shape = list(self.kv_cache.shape)
+            additional_shape[4] = t_needed - self.kv_cache.size(4)
+            additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
+            self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
+            self.kv_shape = self.kv_cache.shape
         # Insert k, v into the cache
-        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
-        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
+        self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
+        self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
         # Return the full cached keys/values up to current position (as a view)
-        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
-        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
+        key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
+        value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
         # Increment pos after the last layer of the Transformer processes
         if layer_idx == self.kv_cache.size(0) - 1:
             self.pos = t1
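The cache-growth arithmetic above rounds the needed sequence length up to a multiple of 1024 with a bit-mask. A tiny illustration of what `(t_needed + 1023) & ~1023` computes:

```python
# Illustrative: the same round-up-to-1024 arithmetic used when growing the KV cache.
def rounded_kv_length(t1: int) -> int:
    t_needed = t1 + 1024              # requested length plus a 1024-token buffer
    return (t_needed + 1023) & ~1023  # round up to the next multiple of 1024

for t1 in (1, 1000, 1025, 4096):
    print(t1, "->", rounded_kv_length(t1))
# 1 -> 2048, 1000 -> 2048, 1025 -> 3072, 4096 -> 5120
```

Growing in 1024-token chunks keeps reallocations rare while wasting at most one chunk of memory per layer.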
@@ -187,9 +223,7 @@ class Engine:
         )
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
-        logits = logits[:, -1, :]
-        next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
-        sampled_tokens = next_ids[:, 0].tolist()
+        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

         # 2) Replicate the KV cache for each sample/row
         kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len

@@ -206,7 +240,6 @@ class Engine:

         # 4) Main generation loop
         num_generated = 0
-        first_iteration = True
         while True:
             # Stop condition: we've reached max tokens
             if max_tokens is not None and num_generated >= max_tokens:

@@ -215,18 +248,9 @@ class Engine:
             if all(state.completed for state in row_states):
                 break

-            # Get sampled tokens - either from prefill or from forward pass
-            if first_iteration:
-                # Use the tokens we already sampled from prefill
-                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
-                # TODO: we should sample a token for each row instead of broadcasting
-                first_iteration = False
-            else:
-                # Forward the model and get the next token for each row
-                logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
-                logits = logits[:, -1, :] # (B, vocab_size) at last time step
-                next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
-                sampled_tokens = next_ids[:, 0].tolist()
+            # Sample the next token for each row
+            next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
+            sampled_tokens = next_ids[:, 0].tolist()

             # Process each row: choose the next token, update state, optional tool use
             token_column = [] # contains the next token id along each row

@@ -263,8 +287,10 @@ class Engine:
             # Yield the token column
             yield token_column, token_masks
             num_generated += 1
-            # Prepare ids for next iteration
+
+            # Prepare logits for next iteration
             ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
+            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

     def generate_batch(self, tokens, num_samples=1, **kwargs):
         """
@ -299,6 +325,9 @@ if __name__ == "__main__":
|
|||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
bos_token_id = tokenizer.get_bos_token_id()
|
||||
|
|
@ -311,10 +340,11 @@ if __name__ == "__main__":
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
stream = model.generate(prompt_tokens, **kwargs)
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
|
@ -326,11 +356,12 @@ if __name__ == "__main__":
|
|||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
|
|
|||
|
|
@ -127,8 +127,6 @@ def chdir(root):
|
|||
os.chdir(root)
|
||||
try:
|
||||
yield
|
||||
except BaseException as exc:
|
||||
raise exc
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
|
@ -146,13 +144,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
with caution.
|
||||
"""
|
||||
|
||||
if maximum_memory_bytes is not None:
|
||||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
if not platform.uname().system == "Darwin":
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
||||
|
|
@ -225,6 +222,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
rmtree = shutil.rmtree
|
||||
rmdir = os.rmdir
|
||||
chdir = os.chdir
|
||||
unlink = os.unlink
|
||||
|
||||
# Disable functionalities that can make destructive changes to the test.
|
||||
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
||||
|
|
@ -282,6 +280,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
shutil.rmtree = rmtree
|
||||
os.rmdir = rmdir
|
||||
os.chdir = chdir
|
||||
os.unlink = unlink
|
||||
|
||||
|
||||
def execute_code(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ Notable features:
|
|||
- norm after token embedding
|
||||
- no learnable params in rmsnorm
|
||||
- no bias in linear layers
|
||||
- Multi-Query Attention (MQA) support for more efficient inference
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
"""
|
||||
|
||||
import math
|
||||
|
|
@ -29,7 +29,7 @@ class GPTConfig:
|
|||
vocab_size: int = 50304
|
||||
n_layer: int = 12
|
||||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (MQA)
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
|
||||
|
||||
|
|
@ -41,25 +41,10 @@ def norm(x):
|
|||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
|
||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||
y2 = x1 * (-sin) + x2 * cos
|
||||
out = torch.cat([y1, y2], 3) # re-assemble
|
||||
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||
return out
|
||||
|
||||
|
||||
def repeat_kv(x, n_rep):
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
if n_rep == 1:
|
||||
return x
|
||||
bs, n_kv_heads, slen, head_dim = x.shape
|
||||
return (
|
||||
x[:, :, None, :, :]
|
||||
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
||||
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
||||
)
|
||||
|
||||
return torch.cat([y1, y2], 3)
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
|
|
@ -96,29 +81,25 @@ class CausalSelfAttention(nn.Module):
|
|||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Apply MQA: replicate the key/value heads for each query head
|
||||
nrep = self.n_head // self.n_kv_head
|
||||
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
if prefix_len > 0: # can't be negative but could be zero
|
||||
attn_mask[:, :prefix_len] = True
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
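With the key/value heads no longer replicated by hand, the grouping is delegated to `F.scaled_dot_product_attention` via `enable_gqa` (available in PyTorch 2.5+; the project pins torch>=2.8). A small sketch with made-up shapes, showing that the flag reproduces what the removed `repeat_kv` helper did:

```python
import torch
import torch.nn.functional as F

B, n_head, n_kv_head, T, head_dim = 2, 8, 2, 16, 32
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# Manual replication, equivalent to the removed repeat_kv helper
n_rep = n_head // n_kv_head
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)
y_manual = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

# Built-in GQA path (PyTorch >= 2.5)
y_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

print(torch.allclose(y_manual, y_gqa, atol=1e-5))  # True, up to float error
```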
|
||||
|
|
@ -152,14 +133,19 @@ class Block(nn.Module):
|
|||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, pad_vocab_size_to=64):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
|
||||
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
||||
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
||||
if padded_vocab_size != config.vocab_size:
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
|
||||
self.transformer = nn.ModuleDict({
|
||||
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's fake
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
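The padding above is a plain ceil-division to the next multiple of `pad_vocab_size_to`; a tiny sketch of that arithmetic (50304 is the GPTConfig default shown earlier, 65536 the vocab size reported in run1000.sh, 50000 a made-up value):

```python
def pad_to_multiple(vocab_size: int, multiple: int = 64) -> int:
    # same ceil-div used in GPT.__init__ above
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(pad_to_multiple(50304))  # 50304 (already a multiple of 64)
print(pad_to_multiple(50000))  # 50048
print(pad_to_multiple(65536))  # 65536
```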
|
||||
|
|
@ -169,8 +155,6 @@ class GPT(nn.Module):
|
|||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
|
|
@ -184,6 +168,9 @@ class GPT(nn.Module):
|
|||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
|
|
@ -236,8 +223,7 @@ class GPT(nn.Module):
|
|||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
if rank == 0:
|
||||
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
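As a quick check of the scaling above, using the d32 model dimension of 2048 reported later in run1000.sh (purely illustrative numbers):

```python
model_dim = 2048                                  # d32 model from run1000.sh
dmodel_lr_scale = (model_dim / 768) ** -0.5
print(round(dmodel_lr_scale, 4))                  # 0.6124: AdamW LRs shrink as the model widens
```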
|
||||
|
|
@ -259,7 +245,7 @@ class GPT(nn.Module):
|
|||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
|
|
@ -275,19 +261,19 @@ class GPT(nn.Module):
|
|||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
softcap = 15
|
||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||
logits = softcap * torch.tanh(logits / softcap) # squash the logits
|
||||
|
||||
if targets is not None:
|
||||
# training mode: compute and return the loss
|
||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
# training: given the targets, compute and return the loss
|
||||
# TODO experiment with chunked cross-entropy?
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
else:
|
||||
# inference mode: compute and return the logits
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
# inference: just return the logits directly
|
||||
return logits
|
||||
|
||||
@torch.inference_mode()
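The reworked forward pass above slices off the padded vocab entries, casts to fp32, and then applies the softcap. A small sketch of the softcap itself, which compresses logits smoothly into roughly [-15, 15] while staying near-linear around zero (made-up inputs):

```python
import torch

softcap = 15.0
logits = torch.tensor([0.5, 5.0, 20.0, 100.0])
capped = softcap * torch.tanh(logits / softcap)
print(capped)  # approximately [0.50, 4.82, 13.05, 15.00]; large values saturate near 15
```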
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ import torch.distributed as dist
def evaluate_bpb(model, batches, steps, token_bytes):
"""
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
which is a tokenization vocab size-indepedent metric, meaning you are still comparing
which is a tokenization vocab size-independent metric, meaning you are still comparing
apples:apples if you change the vocab size. The way this works is that instead of just
calculating the average loss as usual, you calculate the sum loss, and indepependently
calculating the average loss as usual, you calculate the sum loss, and independently
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
the number of bytes that the target tokens represent.
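Concretely, bits-per-byte is the summed cross-entropy in nats divided by ln 2 times the total UTF-8 byte length of the target tokens. A minimal sketch with made-up per-token numbers:

```python
import math

# hypothetical per-token nats and byte lengths of the target tokens
token_nats  = [2.1, 3.0, 1.4, 2.5]
token_bytes = [4,   7,   3,   5]     # e.g. len(token_str.encode("utf-8"))

total_nats  = sum(token_nats)
total_bytes = sum(token_bytes)
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.4f}")  # ~0.6834 bits per byte, independent of the tokenizer's vocab size
```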
@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
|
|
@ -59,5 +59,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
# move both to cpu, calculate bpb and return
|
||||
total_nats = total_nats.item()
|
||||
total_bytes = total_bytes.item()
|
||||
if total_bytes == 0:
|
||||
return float('inf')
|
||||
bpb = total_nats / (math.log(2) * total_bytes)
|
||||
return bpb
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ Generated: {timestamp}
|
|||
# count dependencies via uv.lock
|
||||
uv_lock_lines = 0
|
||||
if os.path.exists('uv.lock'):
|
||||
with open('uv.lock', 'r') as f:
|
||||
with open('uv.lock', 'r', encoding='utf-8') as f:
|
||||
uv_lock_lines = len(f.readlines())
|
||||
|
||||
header += f"""
|
||||
|
|
@ -241,7 +241,7 @@ class Report:
|
|||
slug = slugify(section)
|
||||
file_name = f"{slug}.md"
|
||||
file_path = os.path.join(self.report_dir, file_name)
|
||||
with open(file_path, "w") as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"## {section}\n")
|
||||
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
for item in data:
|
||||
|
|
@ -272,24 +272,28 @@ class Report:
|
|||
final_metrics = {} # the most important final metrics we'll add as table at the end
|
||||
start_time = None
|
||||
end_time = None
|
||||
with open(report_file, "w") as out_file:
|
||||
with open(report_file, "w", encoding="utf-8") as out_file:
|
||||
# write the header first
|
||||
header_file = os.path.join(report_dir, "header.md")
|
||||
if os.path.exists(header_file):
|
||||
with open(header_file, "r") as f:
|
||||
with open(header_file, "r", encoding="utf-8") as f:
|
||||
header_content = f.read()
|
||||
out_file.write(header_content)
|
||||
start_time = extract_timestamp(header_content, "Run started:")
|
||||
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
|
||||
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
|
||||
bloat_data = bloat_data.group(1) if bloat_data else ""
|
||||
else:
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
bloat_data = "[bloat data missing]"
|
||||
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
|
||||
# process all the individual sections
|
||||
for file_name in EXPECTED_FILES:
|
||||
section_file = os.path.join(report_dir, file_name)
|
||||
if not os.path.exists(section_file):
|
||||
print(f"Warning: {section_file} does not exist, skipping")
|
||||
continue
|
||||
with open(section_file, "r") as in_file:
|
||||
with open(section_file, "r", encoding="utf-8") as in_file:
|
||||
section = in_file.read()
|
||||
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
|
||||
if "rl" not in file_name:
|
||||
|
|
@ -369,7 +373,7 @@ class Report:
|
|||
header_file = os.path.join(self.report_dir, "header.md")
|
||||
header = generate_header()
|
||||
start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(header_file, "w") as f:
|
||||
with open(header_file, "w", encoding="utf-8") as f:
|
||||
f.write(header)
|
||||
f.write(f"Run started: {start_time}\n\n---\n\n")
|
||||
print(f"Reset report and wrote header to {header_file}")
|
||||
|
|
|
|||
|
|
@ -341,16 +341,19 @@ class RustBPETokenizer:
|
|||
mask = mask[:max_tokens]
|
||||
return ids, mask
|
||||
|
||||
def visualize_tokenization(self, ids, mask):
|
||||
def visualize_tokenization(self, ids, mask, with_token_id=False):
|
||||
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
|
||||
RED = '\033[91m'
|
||||
GREEN = '\033[92m'
|
||||
RESET = '\033[0m'
|
||||
GRAY = '\033[90m'
|
||||
tokens = []
|
||||
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
|
||||
token_str = self.decode([token_id])
|
||||
color = GREEN if mask_val == 1 else RED
|
||||
tokens.append(f"{color}{token_str}{RESET}")
|
||||
if with_token_id:
|
||||
tokens.append(f"{GRAY}({token_id}){RESET}")
|
||||
return '|'.join(tokens)
|
||||
|
||||
def render_for_completion(self, conversation):
|
||||
|
|
|
|||
194  nanochat/ui.html
|
|
@ -2,7 +2,7 @@
|
|||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
||||
<title>NanoChat</title>
|
||||
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
||||
<style>
|
||||
|
|
@ -18,7 +18,7 @@
|
|||
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
||||
background-color: #ffffff;
|
||||
color: #111827;
|
||||
min-height: 100vh;
|
||||
min-height: 100dvh;
|
||||
margin: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
|
@ -108,6 +108,15 @@
|
|||
background: transparent;
|
||||
border: none;
|
||||
padding: 0.25rem 0;
|
||||
cursor: pointer;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem;
|
||||
margin-left: -0.5rem;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.message.assistant .message-content:hover {
|
||||
background-color: #f9fafb;
|
||||
}
|
||||
|
||||
.message.user .message-content {
|
||||
|
|
@ -115,11 +124,27 @@
|
|||
border-radius: 1.25rem;
|
||||
padding: 0.8rem 1rem;
|
||||
max-width: 65%;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.message.user .message-content:hover {
|
||||
background-color: #e5e7eb;
|
||||
}
|
||||
|
||||
.message.console .message-content {
|
||||
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
|
||||
font-size: 0.875rem;
|
||||
background-color: #fafafa;
|
||||
padding: 0.75rem 1rem;
|
||||
color: #374151;
|
||||
max-width: 80%;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
background-color: #ffffff;
|
||||
padding: 1rem;
|
||||
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
||||
}
|
||||
|
||||
.input-wrapper {
|
||||
|
|
@ -255,6 +280,8 @@
|
|||
|
||||
let messages = [];
|
||||
let isGenerating = false;
|
||||
let currentTemperature = 0.8;
|
||||
let currentTopK = 50;
|
||||
|
||||
chatInput.addEventListener('input', function() {
|
||||
this.style.height = 'auto';
|
||||
|
|
@ -289,7 +316,7 @@
|
|||
chatInput.focus();
|
||||
}
|
||||
|
||||
function addMessage(role, content) {
|
||||
function addMessage(role, content, messageIndex = null) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${role}`;
|
||||
|
||||
|
|
@ -297,6 +324,28 @@
|
|||
contentDiv.className = 'message-content';
|
||||
contentDiv.textContent = content;
|
||||
|
||||
// Add click handler for user messages to enable editing
|
||||
if (role === 'user' && messageIndex !== null) {
|
||||
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||
contentDiv.setAttribute('title', 'Click to edit and restart from here');
|
||||
contentDiv.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
editMessage(messageIndex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Add click handler for assistant messages to enable regeneration
|
||||
if (role === 'assistant' && messageIndex !== null) {
|
||||
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||
contentDiv.setAttribute('title', 'Click to regenerate this response');
|
||||
contentDiv.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(messageIndex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
messageDiv.appendChild(contentDiv);
|
||||
chatWrapper.appendChild(messageDiv);
|
||||
|
||||
|
|
@ -304,17 +353,35 @@
|
|||
return contentDiv;
|
||||
}
|
||||
|
||||
async function sendMessage() {
|
||||
const message = chatInput.value.trim();
|
||||
if (!message || isGenerating) return;
|
||||
function editMessage(messageIndex) {
|
||||
// Find the message in the messages array
|
||||
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
||||
|
||||
isGenerating = true;
|
||||
chatInput.value = '';
|
||||
const messageToEdit = messages[messageIndex];
|
||||
if (messageToEdit.role !== 'user') return;
|
||||
|
||||
// Copy message content to input
|
||||
chatInput.value = messageToEdit.content;
|
||||
chatInput.style.height = 'auto';
|
||||
sendButton.disabled = true;
|
||||
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
|
||||
|
||||
messages.push({ role: 'user', content: message });
|
||||
addMessage('user', message);
|
||||
// Remove this message and all subsequent messages from the array
|
||||
messages = messages.slice(0, messageIndex);
|
||||
|
||||
// Remove message elements from DOM starting from messageIndex
|
||||
const allMessages = chatWrapper.querySelectorAll('.message');
|
||||
for (let i = messageIndex; i < allMessages.length; i++) {
|
||||
allMessages[i].remove();
|
||||
}
|
||||
|
||||
// Enable send button and focus input
|
||||
sendButton.disabled = false;
|
||||
chatInput.focus();
|
||||
}
|
||||
|
||||
async function generateAssistantResponse() {
|
||||
isGenerating = true;
|
||||
sendButton.disabled = true;
|
||||
|
||||
const assistantContent = addMessage('assistant', '');
|
||||
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
||||
|
|
@ -327,8 +394,8 @@
|
|||
},
|
||||
body: JSON.stringify({
|
||||
messages: messages,
|
||||
stream: true,
|
||||
temperature: 0.8,
|
||||
temperature: currentTemperature,
|
||||
top_k: currentTopK,
|
||||
max_tokens: 512
|
||||
}),
|
||||
});
|
||||
|
|
@ -364,8 +431,18 @@
|
|||
}
|
||||
}
|
||||
|
||||
const assistantMessageIndex = messages.length;
|
||||
messages.push({ role: 'assistant', content: fullResponse });
|
||||
|
||||
// Add click handler to regenerate this assistant message
|
||||
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
|
||||
assistantContent.setAttribute('title', 'Click to regenerate this response');
|
||||
assistantContent.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(assistantMessageIndex);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
|
|
@ -375,6 +452,97 @@
|
|||
}
|
||||
}
|
||||
|
||||
async function regenerateMessage(messageIndex) {
|
||||
// Find the message in the messages array
|
||||
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
||||
|
||||
const messageToRegenerate = messages[messageIndex];
|
||||
if (messageToRegenerate.role !== 'assistant') return;
|
||||
|
||||
// Remove this message and all subsequent messages from the array
|
||||
messages = messages.slice(0, messageIndex);
|
||||
|
||||
// Remove message elements from DOM starting from messageIndex
|
||||
const allMessages = chatWrapper.querySelectorAll('.message');
|
||||
for (let i = messageIndex; i < allMessages.length; i++) {
|
||||
allMessages[i].remove();
|
||||
}
|
||||
|
||||
// Regenerate the assistant response
|
||||
await generateAssistantResponse();
|
||||
}
|
||||
|
||||
function handleSlashCommand(command) {
|
||||
const parts = command.trim().split(/\s+/);
|
||||
const cmd = parts[0].toLowerCase();
|
||||
const arg = parts[1];
|
||||
|
||||
if (cmd === '/temperature') {
|
||||
if (arg === undefined) {
|
||||
addMessage('console', `Current temperature: ${currentTemperature}`);
|
||||
} else {
|
||||
const temp = parseFloat(arg);
|
||||
if (isNaN(temp) || temp < 0 || temp > 2) {
|
||||
addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
|
||||
} else {
|
||||
currentTemperature = temp;
|
||||
addMessage('console', `Temperature set to ${currentTemperature}`);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else if (cmd === '/topk') {
|
||||
if (arg === undefined) {
|
||||
addMessage('console', `Current top-k: ${currentTopK}`);
|
||||
} else {
|
||||
const topk = parseInt(arg);
|
||||
if (isNaN(topk) || topk < 1 || topk > 200) {
|
||||
addMessage('console', 'Invalid top-k. Must be between 1 and 200');
|
||||
} else {
|
||||
currentTopK = topk;
|
||||
addMessage('console', `Top-k set to ${currentTopK}`);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else if (cmd === '/clear') {
|
||||
newConversation();
|
||||
return true;
|
||||
} else if (cmd === '/help') {
|
||||
addMessage('console',
|
||||
'Available commands:\n' +
|
||||
'/temperature - Show current temperature\n' +
|
||||
'/temperature <value> - Set temperature (0.0-2.0)\n' +
|
||||
'/topk - Show current top-k\n' +
|
||||
'/topk <value> - Set top-k (1-200)\n' +
|
||||
'/clear - Clear conversation\n' +
|
||||
'/help - Show this help message'
|
||||
);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async function sendMessage() {
|
||||
const message = chatInput.value.trim();
|
||||
if (!message || isGenerating) return;
|
||||
|
||||
// Handle slash commands
|
||||
if (message.startsWith('/')) {
|
||||
chatInput.value = '';
|
||||
chatInput.style.height = 'auto';
|
||||
handleSlashCommand(message);
|
||||
return;
|
||||
}
|
||||
|
||||
chatInput.value = '';
|
||||
chatInput.style.height = 'auto';
|
||||
|
||||
const userMessageIndex = messages.length;
|
||||
messages.push({ role: 'user', content: message });
|
||||
addMessage('user', message, userMessageIndex);
|
||||
|
||||
await generateAssistantResponse();
|
||||
}
|
||||
|
||||
sendButton.disabled = false;
|
||||
|
||||
// Autofocus the chat input on page load
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ dependencies = [
|
|||
"datasets>=4.0.0",
|
||||
"fastapi>=0.117.1",
|
||||
"files-to-prompt>=0.6",
|
||||
"numpy==1.26.4",
|
||||
"psutil>=7.1.0",
|
||||
"regex>=2025.9.1",
|
||||
"setuptools>=80.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.8.0",
|
||||
|
|
@ -22,17 +22,6 @@ dependencies = [
|
|||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
# target torch to cuda 12.8
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
|
|
@ -53,3 +42,36 @@ testpaths = ["tests"]
|
|||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
|
||||
# target torch to cuda 12.8 or CPU
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
94  run1000.sh (new file)
|
|
@ -0,0 +1,94 @@
|
|||
#!/bin/bash
|
||||
|
||||
# The $1000 tier of nanochat
|
||||
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
|
||||
# A bit sparser on comments, see speedrun.sh for more detail
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
python -m nanochat.report reset
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
# start downloading the rest of the shards for a total of 800 (see below why 800)
|
||||
python -m nanochat.dataset -n 800 &
|
||||
# todo: download the rest of it
|
||||
python -m scripts.tok_train --max_chars=4000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
# num_layers: 32
# model_dim: 2048
# num_heads: 16
# num_kv_heads: 16
# Tokens / micro-batch / rank: 8 x 2048 = 16,384
# Tokens / micro-batch: 131,072
# Total batch size 524,288 => gradient accumulation steps: 4
# Number of parameters: 1,879,048,192
# Estimated FLOPs per token: 1.207960e+10
# Calculated number of iterations from target data:param ratio: 71,680
# Total number of training tokens: 37,580,963,840
# Tokens : Params ratio: 20.00
# Total training FLOPs estimate: 4.539628e+20
# step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
# step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
# ...
# 3) validate that the runtime fits our budget:
# The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
# The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
# estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
# This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
# It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
# 4) The last thing to pay attention to is the amount of training data required for the run.
# The script above calculated that "Total number of training tokens: 37,580,963,840"
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
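As a sanity check, the arithmetic in the comments above can be reproduced directly from the numbers base_train.py prints (nothing below is new information, just the same estimates recomputed):

```python
num_params   = 1_879_048_192          # reported by base_train.py for depth=32
tokens       = 20 * num_params        # Chinchilla-style 20 tokens per parameter
step_tokens  = 524_288                # total batch size per optimizer step
iters        = tokens // step_tokens
hours        = iters * 1.574 / 3600   # ~1.574 s per step observed above
chars        = tokens * 4.8           # ~4.8 chars/token from tok_eval.py
shards       = chars / 250e6          # ~250M chars per data shard

print(f"{tokens:,} tokens, {iters:,} steps, {hours:.1f} h, ~{shards:.0f} shards")
# -> 37,580,963,840 tokens, 71,680 steps, 31.3 h, ~722 shards
```

The shard estimate lands slightly below the ~740 quoted above because the comment rounds ~180B chars up to ~185B; either way, pre-downloading 800 shards leaves a comfortable margin.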
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# midtrain
|
||||
# NOTE: ensure that we use the same device_batch_size here as the base training script.
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# sft
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# generate final report
|
||||
python -m nanochat.report generate
|
||||
|
||||
# talk to it
|
||||
python -m scripts.chat_web
|
||||
|
|
@ -292,8 +292,7 @@ impl Tokenizer {
|
|||
|
||||
// Prepare a true Python iterator object
|
||||
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
|
||||
pyo3::Bound::from_borrowed_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
|
||||
.into()
|
||||
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
|
||||
};
|
||||
|
||||
// Global chunk counts
|
||||
|
|
@ -466,6 +465,22 @@ impl Tokenizer {
|
|||
|
||||
all_ids
|
||||
}
|
||||
|
||||
/// Encode multiple texts in parallel using rayon.
|
||||
/// Returns a list of token ID vectors, one per input text.
|
||||
#[pyo3(signature = (texts))]
|
||||
#[pyo3(text_signature = "(self, texts)")]
|
||||
pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
|
||||
// Release Python GIL and encode in parallel using rayon
|
||||
let results = py.allow_threads(|| {
|
||||
texts
|
||||
.par_iter()
|
||||
.map(|text| self.encode(text))
|
||||
.collect::<Vec<Vec<u32>>>()
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
|
|
|
|||
|
|
@ -1,48 +1,76 @@
|
|||
"""
|
||||
Evlauate the CORE metric for a given model.
|
||||
Evaluate the CORE metric for a given model.
|
||||
|
||||
Run on a single GPU:
|
||||
python base_eval.py
|
||||
python -m scripts.base_eval
|
||||
|
||||
Run with torchrun on e.g. 8 GPUs:
|
||||
torchrun --nproc_per_node=8 base_eval.py
|
||||
torchrun --nproc_per_node=8 -m scripts.base_eval
|
||||
|
||||
The script will print the CORE metric to the console.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import csv
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import yaml
|
||||
import shutil
|
||||
import random
|
||||
import zipfile
|
||||
import tempfile
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanoChat specific function dealing with I/O etc.
|
||||
# nanochat specific function dealing with I/O etc.
|
||||
|
||||
# ~162MB of data needed to evaluate the CORE metric
|
||||
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
def place_eval_bundle(file_path):
|
||||
# here file_path is the path to the eval_bundle.zip file
|
||||
# we need to unzip it and place it in the base directory
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmpdir)
|
||||
extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
|
||||
shutil.move(extracted_bundle_dir, eval_bundle_dir)
|
||||
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
||||
TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
|
||||
"""
|
||||
# Load config and task metadata
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
# Download the eval bundle to disk (and unzip if needed)
|
||||
if not os.path.exists(eval_bundle_dir):
|
||||
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, 'r') as f:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
eval_metadata = pd.read_csv(eval_meta_data)
|
||||
|
||||
# Load random baseline values from eval metadata
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
task_name = row['Eval Task']
|
||||
random_baseline = row['Random baseline']
|
||||
random_baselines[task_name] = float(random_baseline)
|
||||
|
||||
# Evaluate each task
|
||||
results = {}
|
||||
|
|
@ -60,11 +88,11 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
|
||||
# Load data for this task
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r') as f:
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# shuffle the data because in many cases it appears ordered but we want
|
||||
# the abillity to only run a subset of the data for debugging purposes etc.
|
||||
# the ability to only run a subset of the data for debugging purposes etc.
|
||||
shuffle_rng = random.Random(1337)
|
||||
shuffle_rng.shuffle(data)
|
||||
if max_per_task > 0:
|
||||
|
|
@ -74,8 +102,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)

results[label] = accuracy
row = eval_metadata[eval_metadata["Eval Task"] == label]
random_baseline = row["Random baseline"].values[0]
random_baseline = random_baselines[label]
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
centered_results[label] = centered_result
end_time = time.time()
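The centering above maps random-guess accuracy to 0 and a perfect score to 1; the CSV stores the baseline as a percentage, hence the 0.01 factor. A worked example with hypothetical numbers:

```python
def center(accuracy: float, random_baseline_pct: float) -> float:
    # same formula as in evaluate_model above
    return (accuracy - 0.01 * random_baseline_pct) / (1.0 - 0.01 * random_baseline_pct)

print(center(0.25, 25.0))  # 0.0, chance level on a 4-way multiple choice task
print(center(0.62, 25.0))  # ~0.493
print(center(1.00, 25.0))  # 1.0
```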
@ -118,29 +145,36 @@ def load_hf_model(hf_path: str, device):
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
|
||||
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
if args.hf_path is not None:
|
||||
# atm assume that if a path is given, it's a huggingface model path
|
||||
hf_path = sys.argv[1]
|
||||
hf_path = args.hf_path
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model_name = hf_path # just for logging
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
out = evaluate_model(model, tokenizer, device)
|
||||
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write out the results to a csv file
|
||||
core_metric = None
|
||||
|
|
@ -152,7 +186,7 @@ def main():
|
|||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
with open(output_csv_path, 'w') as f:
|
||||
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in results:
|
||||
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
|
||||
|
|
@ -161,7 +195,7 @@ def main():
|
|||
print0("="*80)
|
||||
print0(f"Model: {model_name}")
|
||||
print0("="*80)
|
||||
with open(output_csv_path, 'r') as f:
|
||||
with open(output_csv_path, 'r', encoding='utf-8') as f:
|
||||
print0(f.read())
|
||||
|
||||
# Log to report
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ Example run as:
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
"""
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -20,15 +21,15 @@ device_batch_size = 32
|
|||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
|
||||
# Set up the precision we'll run with
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
|
|
@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
|
|||
token_bytes = get_token_bytes(device=device)
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
|
|
|
|||
|
|
@ -6,19 +6,24 @@ python base_train.py
|
|||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from scripts.base_eval import evaluate_model
|
||||
|
|
@ -27,6 +32,8 @@ print_banner()
|
|||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
|
|
@ -42,12 +49,17 @@ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
|||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
|
|
@ -57,9 +69,12 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -75,7 +90,7 @@ print0(f"Vocab size: {vocab_size:,}")
|
|||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # 1:1 MQA ratio
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
|
|
@ -90,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
resuming = resume_from_step != -1
|
||||
if resuming:
|
||||
print0(f"Resuming optimization from step {resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@ -130,21 +160,23 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
|||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
# Learning rate scheduler
|
||||
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(warmdown_ratio * num_iterations)
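The hunk cuts the function off here; for orientation, a trapezoidal warmup/warmdown multiplier consistent with the knobs above would look roughly as follows (a sketch only, since the full body is not shown in this diff):

def lr_multiplier_sketch(it, num_iterations, warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if warmup_iters > 0 and it < warmup_iters:
        return (it + 1) / warmup_iters                      # linear warmup
    if it > num_iterations - warmdown_iters:
        frac = (num_iterations - it) / warmdown_iters       # linear warmdown
        return final_lr_frac + (1.0 - final_lr_frac) * frac
    return 1.0                                              # constant plateau in between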
|
||||
|
|
@ -162,15 +194,26 @@ def get_muon_momentum(it):
|
|||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Loop state (variables updated by the training loop)
|
||||
|
||||
if not resuming:
|
||||
step = 0
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
else:
|
||||
step = meta_data["step"]
|
||||
loop_state = meta_data["loop_state"]
|
||||
val_bpb = meta_data["val_bpb"]
|
||||
min_val_bpb = loop_state["min_val_bpb"]
|
||||
smooth_train_loss = loop_state["smooth_train_loss"]
|
||||
total_training_time = loop_state["total_training_time"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
# note that we run +1 steps only so that we can eval and save at the end
|
||||
for step in range(num_iterations + 1):
|
||||
last_step = step == num_iterations
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
|
|
@ -193,7 +236,8 @@ for step in range(num_iterations + 1):
|
|||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if last_step or (step > 0 and step % core_metric_every == 0):
|
||||
results = {}
|
||||
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
|
|
@ -219,7 +263,7 @@ for step in range(num_iterations + 1):
|
|||
"My favorite color is",
|
||||
"If 5*x + 3 = 13, then x is",
|
||||
]
|
||||
engine = Engine(model, tokenizer)
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
|
|
@ -227,32 +271,38 @@ for step in range(num_iterations + 1):
|
|||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
},
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
# termination conditions (TODO: possibly also add loss explosions etc.)
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -260,10 +310,12 @@ for step in range(num_iterations + 1):
|
|||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping (TODO: possibly experiment with this)
|
||||
if grad_clip > 0.0:
|
||||
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping
|
||||
grad_clip_enabled = grad_clip > 0.0
|
||||
if grad_clip_enabled:
|
||||
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
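The loss / grad_accum_steps normalization above is what makes accumulation match a single large batch: every .backward() adds into .grad, so K mean-reduced micro-losses must each be scaled by 1/K before the single optimizer step. A tiny self-contained check with a toy linear model (note the equivalence assumes equal-sized micro-batches, which is the kind of subtlety the sum-reduction path in the SFT script is designed around):

import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
micro_batches = [torch.randn(4, 3) for _ in range(2)]    # 2 equal-sized micro-batches

loss_full = (torch.cat(micro_batches) @ w).pow(2).mean() # one step on the full batch
loss_full.backward()
grad_full = w.grad.clone()
w.grad = None

for x in micro_batches:                                  # 2 accumulation steps
    ((x @ w).pow(2).mean() / 2).backward()               # each micro-loss scaled by 1/grad_accum_steps
print(torch.allclose(w.grad, grad_full, atol=1e-6))      # True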
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
|
|
@ -275,24 +327,26 @@ for step in range(num_iterations + 1):
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
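Because the EMA starts at zero, its early values are biased low; dividing by 1 - ema_beta**(step+1) rescales them, the same bias correction Adam uses. A quick illustration with a constant loss of 1.0:

beta, ema = 0.9, 0.0
for t in range(4):
    ema = beta * ema + (1 - beta) * 1.0
    print(t, round(ema, 3), round(ema / (1 - beta ** (t + 1)), 3))
# raw EMA: 0.1, 0.19, 0.271, 0.344   debiased: 1.0 at every step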
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
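To make the MFU formula concrete, here is the same arithmetic with hypothetical round numbers (a ~560M-parameter model costs roughly 6N ≈ 3.4e9 FLOPs per token; the 0.5s step time is made up for illustration):

num_flops_per_token = 3.4e9            # hypothetical, ~6 * 560e6 params
total_batch_size = 524288              # tokens per optimizer step
dt = 0.5                               # hypothetical seconds per step
ddp_world_size = 8

flops_per_sec = num_flops_per_token * total_batch_size / dt
promised = 989e12 * ddp_world_size     # peak bf16 FLOP/s of 8x H100 SXM
print(f"{100 * flops_per_sec / promised:.1f}% MFU")   # ~45.1%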
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 100 == 0:
|
||||
wandb_run.log({
|
||||
log_data = {
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
|
|
@ -301,10 +355,16 @@ for step in range(num_iterations + 1):
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
})
|
||||
}
|
||||
if grad_clip_enabled:
|
||||
log_data["train/grad_norm"] = grad_norm
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
step += 1
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
@ -326,11 +386,11 @@ get_report().log(section="Base model training", data=[
|
|||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results["core_metric"],
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
|
|||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from contextlib import nullcontext
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
|
|
@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Evaluate the Chat model.
|
||||
All the generic code lives here, and all the evlauation-specific
|
||||
All the generic code lives here, and all the evaluation-specific
|
||||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
|
|
@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ from tasks.humaneval import HumanEval
|
|||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.spellingbee import SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
|
|
@ -115,7 +117,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
logits = model(prompt_ids) # (B, T, V)
|
||||
|
||||
# Focus the evaluation on just the letters corresponding to the available choices
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the avilable letters
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
|
||||
# The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
|
||||
# letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
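In other words, instead of free-form generation the categorical eval just compares logits across the candidate-letter tokens at the answer position. A minimal sketch of that trick (hypothetical token ids, not the exact nanochat code):

import torch

def pick_choice(logits_at_answer_pos, letter_token_ids):
    # logits_at_answer_pos: (V,) logits at the position where the answer letter would go
    letter_logits = logits_at_answer_pos[letter_token_ids]   # keep only the A/B/C/D tokens
    return int(torch.argmax(letter_logits))                  # index into the choices

# toy usage: vocab of 10, the four choice letters map to token ids 3, 5, 7, 9
print(pick_choice(torch.randn(10), torch.tensor([3, 5, 7, 9])))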
|
||||
for idx, conversation in enumerate(conversations):
|
||||
|
|
@ -164,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
|||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'SpellingBee': partial(SpellingBee, size=256, split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
|
|
@ -191,23 +194,26 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ from tasks.gsm8k import GSM8K
|
|||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
|
|
@ -64,7 +66,7 @@ use_dummy_wandb = run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -206,7 +208,7 @@ def get_lr_multiplier(it):
|
|||
lrm = 1.0 - it / num_steps
|
||||
return lrm
|
||||
|
||||
# Calculate the number of examples each rank handles to achive the desired examples_per_step
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
|
|
@ -307,8 +309,8 @@ for step in range(num_steps):
|
|||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
|
|
|||
|
|
@ -10,25 +10,25 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
|||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import copy
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture, TaskSequence
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
|
|
@ -38,11 +38,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
|
|||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
|
|
@ -53,6 +54,7 @@ init_lr_frac = 0.02
|
|||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
|
@ -60,10 +62,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -77,13 +80,16 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -129,10 +135,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
|
|||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
if max_iterations >= 0 and num_iterations > max_iterations:
|
||||
print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
|
||||
num_iterations = max_iterations
|
||||
if num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
|
|
@ -161,17 +167,16 @@ def get_lr_multiplier(it):
|
|||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
for step in range(num_iterations):
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
val_loader = build_val_loader()
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
val_inputs, val_targets = next(val_iter)
|
||||
val_inputs, val_targets = next(val_loader)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
|
|
@ -186,16 +191,14 @@ for step in range(num_iterations):
|
|||
})
|
||||
model.train()
|
||||
|
||||
# evlauate MMLU accuracy
|
||||
# evaluate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
|
||||
metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
|
|
@ -211,7 +214,7 @@ for step in range(num_iterations):
|
|||
total_loss_sum = torch.tensor(0.0, device=device) # sum of losses
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
train_inputs, train_targets = next(train_loader)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets, loss_reduction='sum')
|
||||
total_loss_sum += loss.detach() # for logging
|
||||
|
|
@ -258,8 +261,8 @@ for step in range(num_iterations):
|
|||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
|
|
|||
|
|
@ -1,26 +1,67 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified web chat server - serves both UI and API from a single FastAPI instance.
|
||||
Run with: python web_chat.py
|
||||
Then open http://localhost:8000 in your browser.
|
||||
|
||||
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
|
||||
a full copy of the model, and incoming requests are distributed to available workers.
|
||||
|
||||
Launch examples:
|
||||
|
||||
- single available GPU (default)
|
||||
python -m scripts.chat_web
|
||||
|
||||
- 4 GPUs
|
||||
python -m scripts.chat_web --num-gpus 4
|
||||
|
||||
To chat, open the URL printed in the console. (If on a cloud box, make sure to use the public IP.)
|
||||
|
||||
Endpoints:
|
||||
GET / - Chat UI
|
||||
POST /chat/completions - Chat API (streaming only)
|
||||
GET /health - Health check with worker pool status
|
||||
GET /stats - Worker pool statistics and GPU utilization
|
||||
|
||||
Abuse Prevention:
|
||||
- Maximum 500 messages per request
|
||||
- Maximum 8000 characters per message
|
||||
- Maximum 32000 characters total conversation length
|
||||
- Temperature clamped to 0.0-2.0
|
||||
- Top-k clamped to 1-200
|
||||
- Max tokens clamped to 1-4096
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from dataclasses import dataclass
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Abuse prevention limits
|
||||
MAX_MESSAGES_PER_REQUEST = 500
|
||||
MAX_MESSAGE_LENGTH = 8000
|
||||
MAX_TOTAL_CONVERSATION_LENGTH = 32000
|
||||
MIN_TEMPERATURE = 0.0
|
||||
MAX_TEMPERATURE = 2.0
|
||||
MIN_TOP_K = 1
|
||||
MAX_TOP_K = 200
|
||||
MIN_MAX_TOKENS = 1
|
||||
MAX_MAX_TOKENS = 4096
|
||||
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||
|
|
@ -28,11 +69,83 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
# Configure logging for conversation traffic
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
"""A worker with a model loaded on a specific GPU."""
|
||||
gpu_id: int
|
||||
device: torch.device
|
||||
engine: Engine
|
||||
tokenizer: object
|
||||
autocast_ctx: torch.amp.autocast
|
||||
|
||||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
|
||||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
|
||||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
engine=engine,
|
||||
tokenizer=tokenizer,
|
||||
autocast_ctx=autocast_ctx
|
||||
)
|
||||
self.workers.append(worker)
|
||||
await self.available_workers.put(worker)
|
||||
|
||||
print(f"All {self.num_gpus} workers initialized!")
|
||||
|
||||
async def acquire_worker(self) -> Worker:
|
||||
"""Get an available worker from the pool."""
|
||||
return await self.available_workers.get()
|
||||
|
||||
async def release_worker(self, worker: Worker):
|
||||
"""Return a worker to the pool."""
|
||||
await self.available_workers.put(worker)
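Functionally, the pool is an asyncio.Queue used as a semaphore that carries a payload: a request pops a worker, generates with it, and pushes it back when done. A stripped-down sketch of the same pattern (no models, purely illustrative):

import asyncio

async def handle(pool: asyncio.Queue, request_id: int) -> str:
    worker = await pool.get()                # waits here if every worker is busy
    try:
        await asyncio.sleep(0.01)            # stand-in for streaming a generation
        return f"request {request_id} served by {worker}"
    finally:
        pool.put_nowait(worker)              # always hand the worker back

async def main():
    pool = asyncio.Queue()
    for gpu_id in range(2):
        pool.put_nowait(f"gpu{gpu_id}")
    print(await asyncio.gather(*(handle(pool, i) for i in range(4))))

asyncio.run(main())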
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
|
|
@ -43,14 +156,76 @@ class ChatRequest(BaseModel):
|
|||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
stream: Optional[bool] = True
|
||||
|
||||
def validate_chat_request(request: ChatRequest):
|
||||
"""Validate chat request to prevent abuse."""
|
||||
# Check number of messages
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one message is required")
|
||||
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
|
||||
)
|
||||
|
||||
# Check individual message lengths and total conversation length
|
||||
total_length = 0
|
||||
for i, message in enumerate(request.messages):
|
||||
if not message.content:
|
||||
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
|
||||
|
||||
msg_length = len(message.content)
|
||||
if msg_length > MAX_MESSAGE_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
|
||||
)
|
||||
total_length += msg_length
|
||||
|
||||
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
|
||||
)
|
||||
|
||||
# Validate role values
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
)
|
||||
|
||||
# Validate temperature
|
||||
if request.temperature is not None:
|
||||
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
|
||||
)
|
||||
|
||||
# Validate top_k
|
||||
if request.top_k is not None:
|
||||
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
|
||||
)
|
||||
|
||||
# Validate max_tokens
|
||||
if request.max_tokens is not None:
|
||||
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model on startup."""
|
||||
print("Loading nanochat model...")
|
||||
app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
app.state.engine = Engine(app.state.model, app.state.tokenizer)
|
||||
"""Load models on all GPUs on startup."""
|
||||
print("Loading nanochat models across GPUs...")
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
||||
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
||||
|
|
@ -68,7 +243,7 @@ app.add_middleware(
|
|||
async def root():
|
||||
"""Serve the chat UI."""
|
||||
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||
with open(ui_html_path, "r") as f:
|
||||
with open(ui_html_path, "r", encoding="utf-8") as f:
|
||||
html_content = f.read()
|
||||
# Replace the API_URL to use the same origin
|
||||
html_content = html_content.replace(
|
||||
|
|
@ -85,8 +260,7 @@ async def logo():
|
|||
return FileResponse(logo_path, media_type="image/svg+xml")
|
||||
|
||||
async def generate_stream(
|
||||
engine,
|
||||
tokenizer,
|
||||
worker: Worker,
|
||||
tokens,
|
||||
temperature=None,
|
||||
max_new_tokens=None,
|
||||
|
|
@ -97,98 +271,141 @@ async def generate_stream(
|
|||
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
|
||||
top_k = top_k if top_k is not None else args.top_k
|
||||
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
|
||||
bos = worker.tokenizer.get_bos_token_id()
|
||||
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(
|
||||
# Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
|
||||
accumulated_tokens = []
|
||||
# Track the last complete UTF-8 string (without replacement characters)
|
||||
last_clean_text = ""
|
||||
|
||||
with worker.autocast_ctx:
|
||||
for token_column, token_masks in worker.engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
# Stopping criteria
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
|
||||
token_text = tokenizer.decode([token])
|
||||
yield f"data: {json.dumps({'token': token_text})}\n\n"
|
||||
# Append the token to sequence
|
||||
accumulated_tokens.append(token)
|
||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||
# Note that decode is quite an efficient operation, basically a table lookup and string concat
|
||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
@app.post("/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
"""Chat completion endpoint with streaming."""
|
||||
engine = app.state.engine
|
||||
tokenizer = app.state.tokenizer
|
||||
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
||||
|
||||
# Build conversation tokens
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start = tokenizer.encode_special("<|user_start|>")
|
||||
user_end = tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start = tokenizer.encode_special("<|assistant_start|>")
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
# Basic validation to prevent abuse
|
||||
validate_chat_request(request)
|
||||
|
||||
conversation_tokens = [bos]
|
||||
for message in request.messages:
|
||||
if message.role == "user":
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||
conversation_tokens.append(user_end)
|
||||
elif message.role == "assistant":
|
||||
conversation_tokens.append(assistant_start)
|
||||
conversation_tokens.extend(tokenizer.encode(message.content))
|
||||
conversation_tokens.append(assistant_end)
|
||||
# Log incoming conversation to console
|
||||
logger.info("="*20)
|
||||
for i, message in enumerate(request.messages):
|
||||
logger.info(f"[{message.role.upper()}]: {message.content}")
|
||||
logger.info("-"*20)
|
||||
|
||||
conversation_tokens.append(assistant_start)
|
||||
# Acquire a worker from the pool (will wait if all are busy)
|
||||
worker_pool = app.state.worker_pool
|
||||
worker = await worker_pool.acquire_worker()
|
||||
|
||||
try:
|
||||
# Build conversation tokens
|
||||
bos = worker.tokenizer.get_bos_token_id()
|
||||
user_start = worker.tokenizer.encode_special("<|user_start|>")
|
||||
user_end = worker.tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
|
||||
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
|
||||
|
||||
conversation_tokens = [bos]
|
||||
for message in request.messages:
|
||||
if message.role == "user":
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(worker.tokenizer.encode(message.content))
|
||||
conversation_tokens.append(user_end)
|
||||
elif message.role == "assistant":
|
||||
conversation_tokens.append(assistant_start)
|
||||
conversation_tokens.extend(worker.tokenizer.encode(message.content))
|
||||
conversation_tokens.append(assistant_end)
|
||||
|
||||
conversation_tokens.append(assistant_start)
|
||||
|
||||
# Streaming response with worker release after completion
|
||||
response_tokens = []
|
||||
async def stream_and_release():
|
||||
try:
|
||||
async for chunk in generate_stream(
|
||||
worker,
|
||||
conversation_tokens,
|
||||
temperature=request.temperature,
|
||||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
):
|
||||
# Accumulate response for logging
|
||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||
if "token" in chunk_data:
|
||||
response_tokens.append(chunk_data["token"])
|
||||
yield chunk
|
||||
finally:
|
||||
# Log the assistant response to console
|
||||
full_response = "".join(response_tokens)
|
||||
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
|
||||
logger.info("="*20)
|
||||
# Release worker back to pool after streaming is done
|
||||
await worker_pool.release_worker(worker)
|
||||
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
generate_stream(
|
||||
engine,
|
||||
tokenizer,
|
||||
conversation_tokens,
|
||||
temperature=request.temperature,
|
||||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
),
|
||||
stream_and_release(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
temperature = request.temperature if request.temperature is not None else args.temperature
|
||||
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
|
||||
top_k = request.top_k if request.top_k is not None else args.top_k
|
||||
|
||||
with autocast_ctx:
|
||||
result_tokens, masks = engine.generate_batch(
|
||||
conversation_tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
)[0]
|
||||
|
||||
response_tokens = result_tokens[len(conversation_tokens):]
|
||||
response_text = tokenizer.decode(response_tokens)
|
||||
return {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
# Make sure to release worker even on error
|
||||
await worker_pool.release_worker(worker)
|
||||
raise e
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
worker_pool = getattr(app.state, 'worker_pool', None)
|
||||
return {
|
||||
"status": "ok",
|
||||
"ready": hasattr(app.state, 'model') and app.state.model is not None
|
||||
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
|
||||
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
|
||||
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
|
||||
}
|
||||
|
||||
@app.get("/stats")
|
||||
async def stats():
|
||||
"""Get worker pool statistics."""
|
||||
worker_pool = app.state.worker_pool
|
||||
return {
|
||||
"total_workers": len(worker_pool.workers),
|
||||
"available_workers": worker_pool.available_workers.qsize(),
|
||||
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
|
||||
"workers": [
|
||||
{
|
||||
"gpu_id": w.gpu_id,
|
||||
"device": str(w.device)
|
||||
} for w in worker_pool.workers
|
||||
]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_
|
|||
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -27,12 +27,16 @@ from tasks.common import TaskMixture
|
|||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
|
|
@ -40,20 +44,22 @@ embedding_lr = 0.2
|
|||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
eval_every = 150
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -88,11 +94,16 @@ for opt in optimizers:
|
|||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
|
|
@ -101,7 +112,7 @@ val_dataset = TaskMixture([
|
|||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
# A big problem is that we don't know the final num_iterations in advance. So we create
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the dataset
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
|
|
@ -111,8 +122,10 @@ def mid_data_generator(split):
|
|||
assert dataset_size > 0
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
|
|
@ -124,6 +137,10 @@ def mid_data_generator(split):
|
|||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if 0 < num_iterations <= it and split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
|
|
@ -132,7 +149,10 @@ def mid_data_generator(split):
|
|||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if split == "train":
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
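The generator above is a streaming pack-and-shift loader: it tops up a token deque until it holds B*T + 1 tokens, then cuts inputs from the first B*T tokens and targets from the same stream shifted by one. The core step in isolation (toy tokens, CPU-only sketch):

from collections import deque
import torch

def pack_batch(token_buffer: deque, B: int, T: int):
    needed = B * T + 1
    flat = torch.tensor([token_buffer.popleft() for _ in range(needed)], dtype=torch.int64)
    return flat[:-1].view(B, T), flat[1:].view(B, T)   # targets are inputs shifted by one

buf = deque(range(32))
x, y = pack_batch(buf, B=2, T=4)
print(x.tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(y.tolist())   # [[1, 2, 3, 4], [5, 6, 7, 8]]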
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
|
|
@ -141,7 +161,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
|
|||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
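For clarity, the new schedule holds the LR constant for the first 80% of the epoch and then ramps linearly to zero:

def lrm(progress):
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

print([round(lrm(p), 2) for p in (0.0, 0.5, 0.8, 0.9, 1.0)])   # [1, 1, 1.0, 0.5, 0.0]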
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -167,7 +188,7 @@ while True:
|
|||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
if eval_every > 0 and (last_step or step % eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
|
|
@ -185,8 +206,8 @@ while True:
|
|||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
|
@ -214,7 +235,7 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -235,7 +256,7 @@ while True:
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -247,7 +268,7 @@ while True:
|
|||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
|
|
@ -267,22 +288,23 @@ while True:
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
if not dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
|
|
|
|||
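For reference, a minimal sketch (not part of this commit) of what the new learning-rate schedule above evaluates to: the multiplier holds at 1.0 for the first 80% of training and then ramps linearly down to 0.

# Sketch only: reproduces the new get_lr_multiplier behavior shown in the hunk above.
def lr_multiplier(progress):
    # progress runs from 0 to 1 over the course of training
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

for p in (0.0, 0.5, 0.8, 0.9, 1.0):
    print(f"progress={p:.1f} -> lr multiplier={lr_multiplier(p):.2f}")
# progress=0.0 -> lr multiplier=1.00
# progress=0.5 -> lr multiplier=1.00
# progress=0.8 -> lr multiplier=1.00
# progress=0.9 -> lr multiplier=0.50
# progress=1.0 -> lr multiplier=0.00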
36 speedrun.sh

@ -23,7 +23,7 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh

# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

@ -73,15 +73,6 @@ python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)

# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi

# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.

@ -91,26 +82,33 @@ fi

echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID

# Number of processes/GPUs to use
NPROC_PER_NODE=8

# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval

# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"

@ -123,9 +121,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft

# (optional)

# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
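For reference, a quick check (not part of speedrun.sh) of the Chinchilla arithmetic in the pretraining comments above:

# Sketch only: verifies the d20 sizing comments in the pretraining section above.
params = 561e6            # d20 model: 561M parameters
tokens = 20 * params      # Chinchilla: ~20 tokens per parameter
chars = tokens * 4.8      # tokenizer averages ~4.8 characters per token
print(f"{tokens / 1e9:.1f}B tokens, {chars / 1e9:.0f}B chars")  # 11.2B tokens, 54B chars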
@ -53,7 +53,7 @@ class Task:

class TaskMixture(Task):
"""
For SFT Training it becomes useful to train on a tax mixture of datasets.
For SFT Training it becomes useful to train on a mixture of datasets.
Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
"""
65 tasks/customjson.py Normal file

@ -0,0 +1,65 @@

"""
CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages.
"""

import os
import json
from tasks.common import Task

class CustomJSON(Task):
    """
    Load conversations from a JSONL file.
    Each line should be a JSON array of message objects with 'role' and 'content' fields.
    Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]
    """

    def __init__(self, filepath, **kwargs):
        super().__init__(**kwargs)
        self.filepath = filepath
        self.conversations = []

        # Load all conversations from the JSONL file
        if not os.path.exists(filepath):
            # Helpful error message due to recent change. Will be removed in the future.
            print("-" * 80)
            print(f"Warning: File {filepath} does not exist")
            print("HINT (Oct 21 2025)")
            print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
            print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
            print("Quick fix: simply run the following command to download the file and you're done:")
            print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
            print("-" * 80)

        else:
            with open(filepath, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line: # skip empty lines
                        continue
                    messages = json.loads(line)
                    # Validate the conversation structure
                    assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
                    assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
                    # Validate message structure and alternating roles
                    for i, message in enumerate(messages):
                        assert "role" in message, f"Message {i} missing 'role' field"
                        assert "content" in message, f"Message {i} missing 'content' field"
                        expected_role = "user" if i % 2 == 0 else "assistant"
                        assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
                        assert isinstance(message["content"], str), f"Message {i} content must be a string"

                    self.conversations.append(messages)

        self.length = len(self.conversations)

    def num_examples(self):
        return self.length

    def get_example(self, index):
        messages = self.conversations[index]
        conversation = {
            "messages": messages,
        }
        return conversation
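For reference, a minimal usage sketch of the CustomJSON task above (the temporary file and its contents are hypothetical, and it assumes the Task base constructor needs no extra arguments):

# Sketch only: write a tiny JSONL file and load it with CustomJSON.
import json
import os
import tempfile
from tasks.customjson import CustomJSON

conv = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
path = os.path.join(tempfile.mkdtemp(), "tiny_conversations.jsonl")  # hypothetical path
with open(path, "w", encoding="utf-8") as f:
    f.write(json.dumps(conv) + "\n")

task = CustomJSON(filepath=path)
print(task.num_examples())   # 1
print(task.get_example(0))   # {'messages': [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello'}]}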
@ -74,7 +74,7 @@ class GSM8K(Task):

else:
# Regular text in between tool calls
assistant_message_parts.append({"type": "text", "text": part})
# No put it all together
# Now put it all together
messages = [
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
307 tasks/spellingbee.py Normal file

@ -0,0 +1,307 @@

"""
Task intended to make nanochat better at spelling and counting, for example:

"How many r are in strawberry?" -> 3

An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model to also learn recoveries. We can do this in future versions.

There are two tasks in this file:
1. SpellingBee: Counting the number of occurrences of a letter in a word
2. SimpleSpelling: Simply spelling words

(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.

To preview a few example conversations, run:
python -m tasks.spellingbee
"""
|
||||
import re
|
||||
import random
|
||||
from tasks.common import Task
|
||||
from nanochat.common import download_file_with_lock
|
||||
|
||||
# Letters of the alphabet
|
||||
LETTERS = "abcdefghijklmnopqrstuvwxyz"
|
||||
# A list of 370K English words of large variety
|
||||
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
|
||||
# A number bigger than 370K to separate train and test random seeds
|
||||
TEST_RANDOM_SEED_OFFSET = 10_000_000
|
||||
|
||||
# Identical to gsm8k's answer extraction
|
||||
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
def extract_answer(completion):
|
||||
"""
|
||||
Extract the numerical answer after #### marker.
|
||||
"""
|
||||
match = ANSWER_RE.search(completion)
|
||||
if match:
|
||||
match_str = match.group(1).strip()
|
||||
match_str = match_str.replace(",", "")
|
||||
return match_str
|
||||
return None
|
||||
|
||||
# User message templates for data augmentation
|
||||
USER_MSG_TEMPLATES = [
|
||||
"How many {letter} are in the word {word}",
|
||||
"How many {letter} are in {word}",
|
||||
"Count the number of {letter} in {word}",
|
||||
"How many times does {letter} appear in {word}",
|
||||
"What's the count of {letter} in {word}",
|
||||
"In the word {word}, how many {letter} are there",
|
||||
"How many letter {letter} are in the word {word}",
|
||||
"Count how many {letter} appear in {word}",
|
||||
"Tell me the number of {letter} in {word}",
|
||||
"How many occurrences of {letter} are in {word}",
|
||||
"Find the count of {letter} in {word}",
|
||||
"Can you count the {letter} letters in {word}",
|
||||
"What is the frequency of {letter} in {word}",
|
||||
"How many {letter}s are in {word}",
|
||||
"How many {letter}'s are in {word}",
|
||||
"Count all the {letter} in {word}",
|
||||
"How many times is {letter} in {word}",
|
||||
"Number of {letter} in {word}",
|
||||
"Total count of {letter} in {word}",
|
||||
"How many {letter} does {word} have",
|
||||
"How many {letter} does {word} contain",
|
||||
"What's the number of {letter} in {word}",
|
||||
"{word} has how many {letter}",
|
||||
"In {word}, count the {letter}",
|
||||
"How many {letter} appear in {word}",
|
||||
"Count the {letter} in {word}",
|
||||
"Give me the count of {letter} in {word}",
|
||||
"How many instances of {letter} in {word}",
|
||||
"Show me how many {letter} are in {word}",
|
||||
"Calculate the number of {letter} in {word}",
|
||||
# Spanish
|
||||
"¿Cuántas {letter} hay en {word}?",
|
||||
"¿Cuántas veces aparece {letter} en {word}?",
|
||||
"Cuenta las {letter} en {word}",
|
||||
"¿Cuántas letras {letter} tiene {word}?",
|
||||
# Chinese (Simplified)
|
||||
"{word}中有多少个{letter}",
|
||||
"{word}里有几个{letter}",
|
||||
"数一下{word}中的{letter}",
|
||||
"{word}这个词里有多少{letter}",
|
||||
# Korean
|
||||
"{word}에 {letter}가 몇 개 있나요",
|
||||
"{word}에서 {letter}의 개수는",
|
||||
"{word}에 {letter}가 몇 번 나오나요",
|
||||
"{word}라는 단어에 {letter}가 몇 개",
|
||||
# French
|
||||
"Combien de {letter} dans {word}",
|
||||
"Combien de fois {letter} apparaît dans {word}",
|
||||
"Compte les {letter} dans {word}",
|
||||
# German
|
||||
"Wie viele {letter} sind in {word}",
|
||||
"Wie oft kommt {letter} in {word} vor",
|
||||
"Zähle die {letter} in {word}",
|
||||
# Japanese
|
||||
"{word}に{letter}は何個ありますか",
|
||||
"{word}の中に{letter}がいくつ",
|
||||
"{word}に{letter}が何回出てくる",
|
||||
]
|
||||
|
||||
class SpellingBee(Task):
|
||||
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == 'train' else TEST_RANDOM_SEED_OFFSET + index
|
||||
rng = random.Random(seed)
|
||||
|
||||
# pick a random word
|
||||
word = rng.choice(self.words)
|
||||
# pick a letter from it (90%) or a random letter (10%)
|
||||
letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)
|
||||
|
||||
# get the correct answer by simply counting
|
||||
count = word.count(letter)
|
||||
|
||||
# create a user message, with a bunch of variations as data augmentation
|
||||
template = rng.choice(USER_MSG_TEMPLATES)
|
||||
# 30% chance to lowercase the template (lazy people don't use shift)
|
||||
if rng.random() < 0.3:
|
||||
template = template.lower()
|
||||
quote_options = ['', "'", '"']
|
||||
letter_quote = rng.choice(quote_options) # is the letter quoted?
|
||||
word_quote = rng.choice(quote_options) # is the word quoted?
|
||||
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
|
||||
word_wrapped = f"{word_quote}{word}{word_quote}"
|
||||
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
|
||||
if rng.random() < 0.5: # 50% of people don't even use question marks
|
||||
user_msg += "?"
|
||||
|
||||
# Now create the ideal assistant response - build as parts (text + tool calls)
|
||||
assistant_parts = []
|
||||
word_letters = ",".join(list(word))
|
||||
manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.
|
||||
|
||||
First spell the word out:
|
||||
{word}:{word_letters}
|
||||
|
||||
Then count the occurrences of '{letter}':
|
||||
"""
|
||||
# Little simulated loop of the solution process
|
||||
# TODO: This is where the fun starts, we could simulate cute little mistakes
|
||||
# and get the model to review its work and recover from them.
|
||||
# You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
|
||||
running_count = 0
|
||||
for i, char in enumerate(word, 1):
|
||||
if char == letter:
|
||||
running_count += 1
|
||||
# note: there deliberately cannot be a space here between i and char
|
||||
# because this would create a different token! (e.g. " a" and "a" are different tokens)
|
||||
manual_text += f"{i}:{char} hit! count={running_count}\n"
|
||||
else:
|
||||
manual_text += f"{i}:{char}\n"
|
||||
|
||||
manual_text += f"\nThis gives us {running_count}."
|
||||
assistant_parts.append({"type": "text", "text": manual_text})
|
||||
# Part 2: Python verification
|
||||
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
|
||||
# Part 3: Python tool call
|
||||
python_expr = f"'{word}'.count('{letter}')"
|
||||
assistant_parts.append({"type": "python", "text": python_expr})
|
||||
# Part 4: Python output
|
||||
assistant_parts.append({"type": "python_output", "text": str(count)})
|
||||
# Part 5: Final answer
|
||||
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts}
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
def evaluate(self, conversation, assistant_response):
|
||||
"""
|
||||
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
|
||||
Identical to gsm8k's evaluation.
|
||||
"""
|
||||
assert isinstance(assistant_response, str), "Assuming simple string response for now"
|
||||
# First extract the ground truth answer from the conversation
|
||||
assistant_message = conversation['messages'][-1]
|
||||
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
|
||||
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
|
||||
# The last text part contains the final answer with ####
|
||||
last_text_part = assistant_message['content'][-1]['text']
|
||||
# Extract both the ground truth answer and the predicted answer
|
||||
ref_num = extract_answer(last_text_part)
|
||||
pred_num = extract_answer(assistant_response)
|
||||
# Compare and return the success as int
|
||||
is_correct = int(pred_num == ref_num)
|
||||
return is_correct
|
||||
|
||||
def reward(self, conversation, assistant_response):
|
||||
""" Use simple 0-1 reward just like gsm8k."""
|
||||
is_correct = self.evaluate(conversation, assistant_response)
|
||||
is_correct_float = float(is_correct)
|
||||
return is_correct_float
|
||||
|
||||
|
||||
class SimpleSpelling(Task):
|
||||
"""Much simpler task designed to get the model to just practice spelling words."""
|
||||
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
rng = random.Random(42)
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == 'train' else TEST_RANDOM_SEED_OFFSET + index
|
||||
rng = random.Random(seed)
|
||||
# pick a random word
|
||||
word = rng.choice(self.words)
|
||||
word_letters = ",".join(list(word))
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": f"Spell the word: {word}"},
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"}
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# preview the SpellingBee task, first 10 examples
|
||||
task = SpellingBee()
|
||||
for i in range(10):
|
||||
ex = task.get_example(i)
|
||||
print("=" * 100)
|
||||
print(ex['messages'][0]['content'])
|
||||
print("-" * 100)
|
||||
# Assistant content is now a list of parts
|
||||
assistant_parts = ex['messages'][1]['content']
|
||||
for part in assistant_parts:
|
||||
if part['type'] == 'text':
|
||||
print(part['text'], end='')
|
||||
elif part['type'] == 'python':
|
||||
print(f"<<{part['text']}=", end='')
|
||||
elif part['type'] == 'python_output':
|
||||
print(f"{part['text']}>>", end='')
|
||||
print()
|
||||
print("-" * 100)
|
||||
|
||||
# # preview the SimpleSpelling task, first 10 examples
|
||||
# task = SimpleSpelling()
|
||||
# for i in range(10):
|
||||
# ex = task.get_example(i)
|
||||
# print("=" * 100)
|
||||
# print(ex['messages'][0]['content'])
|
||||
# print("-" * 100)
|
||||
# print(ex['messages'][1]['content'])
|
||||
|
||||
# # also scrutinize the tokenization (last example only)
|
||||
# from nanochat.tokenizer import get_tokenizer
|
||||
# tokenizer = get_tokenizer()
|
||||
# ids, mask = tokenizer.render_conversation(ex)
|
||||
# print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
|
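For reference, a small sketch (not part of this commit) of how the "####" answer extraction used by this task behaves, using the same regex defined in the file above:

# Sketch only: mirrors the ANSWER_RE extraction defined in tasks/spellingbee.py above.
import re

ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

completion = "Python gives us 3.\n\nMy final answer is:\n\n#### 3"
match = ANSWER_RE.search(completion)
print(match.group(1).strip().replace(",", ""))  # -> "3"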
187 tests/test_engine.py Normal file

@ -0,0 +1,187 @@

"""
Test Engine class. Example run:

python -m pytest tests/test_engine.py -v
"""
|
||||
|
||||
import torch
|
||||
from nanochat.engine import KVCache, Engine
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Mock classes for testing Engine without loading a real model
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
"""Minimal config for Engine tests."""
|
||||
n_kv_head: int = 4
|
||||
n_head: int = 4
|
||||
n_embd: int = 64
|
||||
n_layer: int = 2
|
||||
sequence_len: int = 128
|
||||
|
||||
|
||||
class MockModel:
|
||||
"""
|
||||
Mock model that returns uniform logits over the vocab.
|
||||
This ensures that with temperature > 0, different samples should
|
||||
(with very high probability) produce different tokens.
|
||||
"""
|
||||
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.config = MockConfig()
|
||||
self._device = "cpu"
|
||||
|
||||
def get_device(self):
|
||||
return self._device
|
||||
|
||||
def forward(self, ids, kv_cache=None):
|
||||
"""Return uniform logits so sampling is spread across vocab."""
|
||||
B, T = ids.shape
|
||||
# Simulate what a real transformer does: insert k,v into the cache for each layer
|
||||
if kv_cache is not None:
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
for layer_idx in range(self.config.n_layer):
|
||||
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
# Uniform logits -> equal probability for all tokens
|
||||
logits = torch.zeros(B, T, self.vocab_size)
|
||||
return logits
|
||||
|
||||
|
||||
class ByteTokenizer:
|
||||
"""
|
||||
Simple byte-level tokenizer for testing.
|
||||
Tokens 0-255 are raw bytes, 256+ are special tokens.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Special tokens start at 256
|
||||
self._special_tokens = {
|
||||
"<|python_start|>": 256,
|
||||
"<|python_end|>": 257,
|
||||
"<|output_start|>": 258,
|
||||
"<|output_end|>": 259,
|
||||
"<|assistant_end|>": 260,
|
||||
"<|bos|>": 261,
|
||||
}
|
||||
self._bos = 261
|
||||
|
||||
def encode_special(self, s):
|
||||
return self._special_tokens[s]
|
||||
|
||||
def get_bos_token_id(self):
|
||||
return self._bos
|
||||
|
||||
def encode(self, s, prepend=None):
|
||||
tokens = list(s.encode("utf-8")) # bytes 0-255
|
||||
if prepend is not None:
|
||||
tokens = [prepend] + tokens
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
# Filter out special tokens before decoding
|
||||
byte_tokens = [t for t in tokens if t < 256]
|
||||
return bytes(byte_tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def test_kv_cache_resize():
|
||||
"""
|
||||
The KV cache was not resized correctly, more information here:
|
||||
https://github.com/karpathy/nanochat/pull/186
|
||||
This test reproduces the issue and will be merged alongside the fix.
|
||||
"""
|
||||
|
||||
batch_size = 2
|
||||
num_heads = 3
|
||||
seq_len = 4
|
||||
head_dim = 5
|
||||
num_layers = 6
|
||||
|
||||
kv_cache = KVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers
|
||||
)
|
||||
|
||||
# Insert a single token with a distinct fill value to all layers
|
||||
def insert_token(token_idx):
|
||||
for layer_idx in range(num_layers):
|
||||
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
|
||||
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
|
||||
# Insert 4 tokens (fills the initial seq_len=4)
|
||||
for i in range(4):
|
||||
insert_token(i)
|
||||
|
||||
# Record the original state of the cache
|
||||
original_cache = kv_cache.kv_cache.clone()
|
||||
original_seq_len = original_cache.shape[4]
|
||||
|
||||
# Insert the 5th token, which will trigger a resize
|
||||
insert_token(4)
|
||||
# Verify that the cache actually resized
|
||||
new_seq_len = kv_cache.kv_cache.shape[4]
|
||||
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
|
||||
# Verify that the original 4 tokens are still intact after resize
|
||||
for layer_idx in range(num_layers):
|
||||
for token_idx in range(4):
|
||||
# Check that resized cache matches expected values
|
||||
expected_k = float(token_idx)
|
||||
expected_v = float(token_idx * 100)
|
||||
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
# And that the original cache matches resized cache
|
||||
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
|
||||
"""
|
||||
Test that when generating multiple samples, each sample gets an independently
|
||||
sampled first token (not a broadcast of the same token to all rows).
|
||||
|
||||
Previously, the first token after prefill was sampled once and broadcast to all
|
||||
rows, causing all samples to start identically. The fix expands the prefill logits
|
||||
to num_samples and samples independently for each row.
|
||||
|
||||
With uniform logits over 262 tokens and 16 samples, the probability that all
|
||||
samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
|
||||
all identical, it indicates tokens are being broadcast instead of independently sampled.
|
||||
"""
|
||||
model = MockModel(vocab_size=262)
|
||||
tokenizer = ByteTokenizer()
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Generate 16 samples with temperature=1.0 (stochastic sampling)
|
||||
prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
|
||||
num_samples = 16
|
||||
|
||||
# Collect the first generated token from each sample
|
||||
first_tokens = []
|
||||
gen = engine.generate(
|
||||
prompt_tokens,
|
||||
num_samples=num_samples,
|
||||
max_tokens=1, # We only need the first token
|
||||
temperature=1.0,
|
||||
seed=42,
|
||||
)
|
||||
for token_column, token_masks in gen:
|
||||
first_tokens = token_column # This is the first (and only) yield
|
||||
|
||||
# With uniform distribution and 16 samples, they should NOT all be identical
|
||||
# If they are all identical, the bug exists (broadcasting instead of sampling)
|
||||
unique_tokens = set(first_tokens)
|
||||
assert len(unique_tokens) > 1, (
|
||||
f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
|
||||
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
|
||||
f"unless tokens are being broadcast instead of independently sampled."
|
||||
)
|
||||
|
|
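For reference, a quick check (not part of the test file) of the probability claim in the docstring of test_multi_sample_first_token_diversity above:

# Sketch only: probability that 16 independent uniform samples over a 262-token
# vocab all match the first sample's token, i.e. (1/262)^15.
p = (1 / 262) ** 15
print(f"{p:.1e}")  # 5.3e-37, consistent with the ~10^-36 figure quoted in the test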
@ -21,6 +21,7 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
import regex as re
|
||||
from collections import Counter, defaultdict
|
||||
import time
|
||||
import warnings
|
||||
import rustbpe
|
||||
import tiktoken
|
||||
import pytest
|
||||
|
|
@ -455,13 +456,13 @@ def enwik8_path():
|
|||
@pytest.fixture(scope="module")
|
||||
def enwik8_small(enwik8_path):
|
||||
"""Fixture providing 100KB of enwik8 for quick tests."""
|
||||
with open(enwik8_path, "r") as f:
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
return f.read(100_000)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_large(enwik8_path):
|
||||
"""Fixture providing 10MB of enwik8 for performance tests."""
|
||||
with open(enwik8_path, "r") as f:
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
return f.read(10**7)
|
||||
|
||||
def time_function(func, *args, **kwargs):
|
||||
|
|
@ -633,3 +634,85 @@ def test_interface(enwik8_small):
|
|||
ids_reloaded = tok_reloaded.encode(encode_text)
|
||||
assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
|
||||
print("✅ Save/load through temporary directory OK")
|
||||
|
||||
|
||||
def test_batch_encode_correctness(enwik8_small):
|
||||
"""Quick correctness test for batch_encode()"""
|
||||
text = enwik8_small
|
||||
vocab_size = 512
|
||||
|
||||
tokenizer = rustbpe.Tokenizer()
|
||||
tokenizer.train_from_iterator([text], vocab_size)
|
||||
|
||||
# Test with various batch sizes and edge cases
|
||||
test_texts = [
|
||||
"Hello world",
|
||||
"The quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
"", # empty string
|
||||
"a", # single char
|
||||
]
|
||||
|
||||
# Compare batch vs individual encoding
|
||||
individual = [tokenizer.encode(t) for t in test_texts]
|
||||
batched = tokenizer.batch_encode(test_texts)
|
||||
|
||||
assert individual == batched, "Batch encoding should match individual encoding"
|
||||
print("✅ batch_encode() correctness verified")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_batch_encode_performance(enwik8_large):
|
||||
"""
|
||||
Benchmark batch_encode() vs sequential encode() loop.
|
||||
Demonstrates parallelization speedup.
|
||||
"""
|
||||
# Setup
|
||||
text = enwik8_large # 10MB dataset
|
||||
vocab_size = 2048
|
||||
|
||||
# Train tokenizer
|
||||
print("\nTraining tokenizer...")
|
||||
tokenizer = rustbpe.Tokenizer()
|
||||
tokenizer.train_from_iterator([text], vocab_size)
|
||||
|
||||
# Create test batch: split text into chunks
|
||||
chunk_size = 50_000 # ~50KB per chunk
|
||||
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
||||
chunks = chunks[:20] # Use first 20 chunks (~1MB total)
|
||||
|
||||
print(f"\nBatch encoding benchmark:")
|
||||
print(f" Number of texts: {len(chunks)}")
|
||||
print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
|
||||
|
||||
# Benchmark 1: Sequential encoding (baseline)
|
||||
print("\n [1/3] Sequential encode() loop...")
|
||||
sequential_results, sequential_time = time_function(
|
||||
lambda: [tokenizer.encode(chunk) for chunk in chunks]
|
||||
)
|
||||
print(f" Time: {sequential_time:.4f}s")
|
||||
|
||||
# Benchmark 2: Parallel batch_encode()
|
||||
print(" [2/3] Parallel batch_encode()...")
|
||||
batch_results, batch_time = time_function(
|
||||
tokenizer.batch_encode, chunks
|
||||
)
|
||||
print(f" Time: {batch_time:.4f}s")
|
||||
|
||||
# Verify correctness
|
||||
print(" [3/3] Verifying correctness...")
|
||||
assert len(batch_results) == len(sequential_results), "Result count mismatch"
|
||||
for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
|
||||
assert seq == batch, f"Mismatch at index {i}"
|
||||
print(" ✓ All results match")
|
||||
|
||||
# Report speedup
|
||||
speedup = sequential_time / batch_time
|
||||
print(f"\n Performance Results:")
|
||||
print(f" Sequential: {sequential_time:.4f}s")
|
||||
print(f" Batch: {batch_time:.4f}s")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
# Warn if speedup is low (can vary by machine/load)
|
||||
if speedup < 1.5:
|
||||
warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")
|
||||
|
|
|
|||
336 uv.lock
|
|
@ -2,13 +2,32 @@ version = 1
|
|||
revision = 3
|
||||
requires-python = ">=3.10"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
]
|
||||
conflicts = [[
|
||||
{ package = "nanochat", extra = "cpu" },
|
||||
{ package = "nanochat", extra = "gpu" },
|
||||
]]
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
|
|
@ -26,7 +45,7 @@ source = { registry = "https://pypi.org/simple" }
|
|||
dependencies = [
|
||||
{ name = "aiohappyeyeballs" },
|
||||
{ name = "aiosignal" },
|
||||
{ name = "async-timeout", marker = "python_full_version < '3.11'" },
|
||||
{ name = "async-timeout", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "attrs" },
|
||||
{ name = "frozenlist" },
|
||||
{ name = "multidict" },
|
||||
|
|
@ -111,7 +130,7 @@ version = "1.4.0"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "frozenlist" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" }
|
||||
wheels = [
|
||||
|
|
@ -132,10 +151,10 @@ name = "anyio"
|
|||
version = "4.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "idna" },
|
||||
{ name = "sniffio" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" }
|
||||
wheels = [
|
||||
|
|
@ -238,7 +257,7 @@ name = "click"
|
|||
version = "8.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
|
||||
wheels = [
|
||||
|
|
@ -292,7 +311,7 @@ name = "exceptiongroup"
|
|||
version = "1.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
|
||||
wheels = [
|
||||
|
|
@ -497,7 +516,7 @@ source = { registry = "https://pypi.org/simple" }
|
|||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
{ name = "fsspec" },
|
||||
{ name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
|
||||
{ name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
|
|
@ -602,7 +621,7 @@ name = "maturin"
|
|||
version = "1.9.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/13/7c/b11b870fc4fd84de2099906314ce45488ae17be32ff5493519a6cddc518a/maturin-1.9.4.tar.gz", hash = "sha256:235163a0c99bc6f380fb8786c04fd14dcf6cd622ff295ea3de525015e6ac40cf", size = 213647, upload-time = "2025-08-27T11:37:57.079Z" }
|
||||
wheels = [
|
||||
|
|
@ -635,7 +654,7 @@ name = "multidict"
|
|||
version = "6.6.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/7f/0652e6ed47ab288e3756ea9c0df8b14950781184d4bd7883f4d87dd41245/multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd", size = 101843, upload-time = "2025-08-11T12:08:48.217Z" }
|
||||
wheels = [
|
||||
|
|
@ -758,16 +777,28 @@ dependencies = [
|
|||
{ name = "datasets" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "files-to-prompt" },
|
||||
{ name = "numpy" },
|
||||
{ name = "psutil" },
|
||||
{ name = "regex" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "tokenizers" },
|
||||
{ name = "torch" },
|
||||
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
cpu = [
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
gpu = [
|
||||
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" } },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "maturin" },
|
||||
|
|
@ -779,15 +810,18 @@ requires-dist = [
|
|||
{ name = "datasets", specifier = ">=4.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.117.1" },
|
||||
{ name = "files-to-prompt", specifier = ">=0.6" },
|
||||
{ name = "numpy", specifier = "==1.26.4" },
|
||||
{ name = "psutil", specifier = ">=7.1.0" },
|
||||
{ name = "regex", specifier = ">=2025.9.1" },
|
||||
{ name = "setuptools", specifier = ">=80.9.0" },
|
||||
{ name = "tiktoken", specifier = ">=0.11.0" },
|
||||
{ name = "tokenizers", specifier = ">=0.22.0" },
|
||||
{ name = "torch", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128" },
|
||||
{ name = "torch", specifier = ">=2.8.0" },
|
||||
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
|
||||
{ name = "torch", marker = "extra == 'gpu'", specifier = ">=2.8.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
]
|
||||
provides-extras = ["cpu", "gpu"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
|
|
@ -800,8 +834,13 @@ name = "networkx"
|
|||
version = "3.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11' and sys_platform == 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" }
|
||||
wheels = [
|
||||
|
|
@ -813,10 +852,20 @@ name = "networkx"
|
|||
version = "3.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
"python_full_version == '3.11.*' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu'",
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
|
||||
wheels = [
|
||||
|
|
@ -860,7 +909,9 @@ name = "nvidia-cublas-cu12"
|
|||
version = "12.8.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/61/7d7b3c70186fb651d0fbd35b01dbfc8e755f69fd58f817f3d0f642df20c3/nvidia_cublas_cu12-12.8.4.1-py3-none-win_amd64.whl", hash = "sha256:47e9b82132fa8d2b4944e708049229601448aaad7e6f296f630f2d1a32de35af", size = 567544208, upload-time = "2025-03-07T01:53:30.535Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -868,7 +919,9 @@ name = "nvidia-cuda-cupti-cu12"
|
|||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/bc/83f5426095d93694ae39fe1311431b5d5a9bb82e48bf0dd8e19be2765942/nvidia_cuda_cupti_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:bb479dcdf7e6d4f8b0b01b115260399bf34154a1a2e9fe11c85c517d87efd98e", size = 7015759, upload-time = "2025-03-07T01:51:11.355Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -877,6 +930,8 @@ version = "12.8.93"
|
|||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/51/52a3d84baa2136cc8df15500ad731d74d3a1114d4c123e043cb608d4a32b/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:7a4b6b2904850fe78e0bd179c4b655c404d4bb799ef03ddc60804247099ae909", size = 73586838, upload-time = "2025-03-07T01:52:13.483Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -884,7 +939,9 @@ name = "nvidia-cuda-runtime-cu12"
|
|||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/a5/a515b7600ad361ea14bfa13fb4d6687abf500adc270f19e89849c0590492/nvidia_cuda_runtime_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:c0c6027f01505bfed6c3b21ec546f69c687689aad5f1a377554bc6ca4aa993a8", size = 944318, upload-time = "2025-03-07T01:51:01.794Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -892,10 +949,12 @@ name = "nvidia-cudnn-cu12"
|
|||
version = "9.10.2.21"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
{ url = "https://files.pythonhosted.org/packages/3d/90/0bd6e586701b3a890fd38aa71c387dab4883d619d6e5ad912ccbd05bfd67/nvidia_cudnn_cu12-9.10.2.21-py3-none-win_amd64.whl", hash = "sha256:c6288de7d63e6cf62988f0923f96dc339cea362decb1bf5b3141883392a7d65e", size = 692992268, upload-time = "2025-06-06T21:55:18.114Z" },
]
[[package]]
@@ -903,10 +962,12 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
{ url = "https://files.pythonhosted.org/packages/7d/ec/ce1629f1e478bb5ccd208986b5f9e0316a78538dd6ab1d0484f012f8e2a1/nvidia_cufft_cu12-11.3.3.83-py3-none-win_amd64.whl", hash = "sha256:7a64a98ef2a7c47f905aaf8931b69a3a43f27c55530c698bb2ed7c75c0b42cb7", size = 192216559, upload-time = "2025-03-07T01:53:57.106Z" },
]
[[package]]
@@ -915,6 +976,7 @@ version = "1.13.1.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
]
[[package]]
@@ -922,7 +984,9 @@ name = "nvidia-curand-cu12"
version = "10.3.9.90"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
{ url = "https://files.pythonhosted.org/packages/b9/75/70c05b2f3ed5be3bb30b7102b6eb78e100da4bbf6944fd6725c012831cab/nvidia_curand_cu12-10.3.9.90-py3-none-win_amd64.whl", hash = "sha256:f149a8ca457277da854f89cf282d6ef43176861926c7ac85b2a0fbd237c587ec", size = 62765309, upload-time = "2025-03-07T01:54:20.478Z" },
]
[[package]]
@@ -930,12 +994,14 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
{ url = "https://files.pythonhosted.org/packages/13/c0/76ca8551b8a84146ffa189fec81c26d04adba4bc0dbe09cd6e6fd9b7de04/nvidia_cusolver_cu12-11.7.3.90-py3-none-win_amd64.whl", hash = "sha256:4a550db115fcabc4d495eb7d39ac8b58d4ab5d8e63274d3754df1c0ad6a22d34", size = 256720438, upload-time = "2025-03-07T01:54:39.898Z" },
]
[[package]]
@@ -943,10 +1009,12 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
{ url = "https://files.pythonhosted.org/packages/62/07/f3b2ad63f8e3d257a599f422ae34eb565e70c41031aecefa3d18b62cabd1/nvidia_cusparse_cu12-12.5.8.93-py3-none-win_amd64.whl", hash = "sha256:9a33604331cb2cac199f2e7f5104dfbb8a5a898c367a53dfda9ff2acb6b6b4dd", size = 284937404, upload-time = "2025-03-07T01:55:07.742Z" },
]
[[package]]
@@ -954,7 +1022,9 @@ name = "nvidia-cusparselt-cu12"
version = "0.7.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
{ url = "https://files.pythonhosted.org/packages/2f/d8/a6b0d0d0c2435e9310f3e2bb0d9c9dd4c33daef86aa5f30b3681defd37ea/nvidia_cusparselt_cu12-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f67fbb5831940ec829c9117b7f33807db9f9678dc2a617fbe781cac17b4e1075", size = 271020911, upload-time = "2025-02-26T00:14:47.204Z" },
]
[[package]]
@@ -962,6 +1032,7 @@ name = "nvidia-nccl-cu12"
version = "2.27.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/7b/8354b784cf73b0ba51e566b4baba3ddd44fe8288a3d39ef1e06cd5417226/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9ddf1a245abc36c550870f26d537a9b6087fb2e2e3d6e0ef03374c6fd19d984f", size = 322397768, upload-time = "2025-06-03T21:57:30.234Z" },
{ url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" },
]
@@ -971,6 +1042,8 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
{ url = "https://files.pythonhosted.org/packages/ed/d7/34f02dad2e30c31b10a51f6b04e025e5dd60e5f936af9045a9b858a05383/nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f", size = 268553710, upload-time = "2025-03-07T01:56:24.13Z" },
]
[[package]]
@@ -978,7 +1051,9 @@ name = "nvidia-nvtx-cu12"
version = "12.8.90"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
{ url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
]
[[package]]
@@ -1334,13 +1409,13 @@ name = "pytest"
version = "8.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
{ name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" }
wheels = [
@@ -1561,7 +1636,7 @@ version = "0.48.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" }
wheels = [
@@ -1684,30 +1759,38 @@ wheels = [
name = "torch"
version = "2.8.0+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock" },
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions" },
{ name = "filelock", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "fsspec", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "jinja2", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (platform_machine != 'x86_64' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'linux' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" },
@@ -1722,12 +1805,143 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" },
]
[[package]]
name = "torch"
version = "2.9.0"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version < '3.11' and sys_platform == 'darwin'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:59484193b01299bf669520505a72b29d59a0028ae4c6d95f492938f186592208" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aa4483602586cc9a35d1cf33771a9977f05f642b9161518a289e36548a0b77c2" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:4de0ed8cbc457a506dbca40376e206a29efee10756a00f1f3404bf67ad737d04" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:259548471194ab63d7ea273873053a6e3cc23530c1510f01e9d7ad259187bbd0" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e24836d968b54ef4dfb05594001a61958711ac9224026291e4e3f92f83a6fd7f" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d8e2ab7f86010330bdcc39c8b2c795590cc75e37df4823cdaee2c98d6e3ff4a3" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a3e859039c985d8e3ea60d7a54ca7e97ea2ae15e31beced4f3260128a161bb01" },
]
[[package]]
name = "torch"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" },
{ url = "https://files.pythonhosted.org/packages/58/1d/fd1e88ae0948825efcab7dd66d12bec23f05d4d38ed81573c8d453c14c06/torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:51cb63902182a78e90886e8068befd8ea102af4b00e420263591a3d70c7d3c6c", size = 899795167, upload-time = "2025-10-15T15:47:12.695Z" },
{ url = "https://files.pythonhosted.org/packages/63/5a/496197b45c14982bef4e079b24c61dc108e3ab0d0cc9718dba9f54f45a46/torch-2.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:3f6aad4d2f0ee2248bac25339d74858ff846c3969b27d14ac235821f055af83d", size = 109310314, upload-time = "2025-10-15T15:46:16.633Z" },
{ url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" },
{ url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" },
{ url = "https://files.pythonhosted.org/packages/05/cc/49566caaa218872ec9a2912456f470ff92649894a4bc2e5274aa9ef87c4a/torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:51de31219c97c51cf4bf2be94d622e3deb5dcc526c6dc00e97c17eaec0fc1d67", size = 899815990, upload-time = "2025-10-15T15:48:03.336Z" },
{ url = "https://files.pythonhosted.org/packages/74/25/e9ab21d5925b642d008f139d4a3c9664fc9ee1faafca22913c080cc4c0a5/torch-2.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd515c70059afd95f48b8192733764c08ca37a1d19803af6401b5ecad7c8676e", size = 109313698, upload-time = "2025-10-15T15:46:12.425Z" },
{ url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" },
{ url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" },
{ url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" },
{ url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" },
{ url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" },
{ url = "https://files.pythonhosted.org/packages/c2/1c/90eb13833cdf4969ea9707586d7b57095c3b6e2b223a7256bf111689bcb8/torch-2.9.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c30a17fc83eeab346913e237c64b15b5ba6407fff812f6c541e322e19bc9ea0e", size = 104111330, upload-time = "2025-10-15T15:46:35.238Z" },
{ url = "https://files.pythonhosted.org/packages/0e/21/2254c54b8d523592c25ef4434769aa23e29b1e6bf5f4c0ad9e27bf442927/torch-2.9.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f25033b8667b57857dfd01458fbf2a9e6a6df1f8def23aef0dc46292f6aa642", size = 899750243, upload-time = "2025-10-15T15:48:57.459Z" },
{ url = "https://files.pythonhosted.org/packages/b7/a5/5cb94fa4fd1e78223455c23c200f30f6dc10c6d4a2bcc8f6e7f2a2588370/torch-2.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:d037f1b4ffd25013be4a7bf3651a0a910c68554956c7b2c92ebe87c76475dece", size = 109284513, upload-time = "2025-10-15T15:46:45.061Z" },
{ url = "https://files.pythonhosted.org/packages/66/e8/fc414d8656250ee46120b44836ffbb3266343db424b3e18ca79ebbf69d4f/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e4e5b5cba837a2a8d1a497ba9a58dae46fa392593eaa13b871c42f71847503a5", size = 74830362, upload-time = "2025-10-15T15:46:48.983Z" },
{ url = "https://files.pythonhosted.org/packages/ed/5f/9474c98fc5ae0cd04b9466035428cd360e6611a86b8352a0fc2fa504acdc/torch-2.9.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:64693568f5dc4dbd5f880a478b1cea0201cc6b510d91d1bc54fea86ac5d1a637", size = 104144940, upload-time = "2025-10-15T15:47:29.076Z" },
{ url = "https://files.pythonhosted.org/packages/2d/5a/8e0c1cf57830172c109d4bd6be2708cabeaf550983eee7029291322447a0/torch-2.9.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:f8ed31ddd7d10bfb3fbe0b9fe01b1243577f13d75e6f4a0839a283915ce3791e", size = 899744054, upload-time = "2025-10-15T15:48:29.864Z" },
{ url = "https://files.pythonhosted.org/packages/6d/28/82c28b30fcb4b7c9cdd995763d18bbb830d6521356712faebbad92ffa61d/torch-2.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:eff527d4e4846e6f70d2afd8058b73825761203d66576a7e04ea2ecfebcb4ab8", size = 109517546, upload-time = "2025-10-15T15:47:33.395Z" },
{ url = "https://files.pythonhosted.org/packages/ff/c3/a91f96ec74347fa5fd24453fa514bc61c61ecc79196fa760b012a1873d96/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:f8877779cf56d1ce431a7636703bdb13307f5960bb1af49716d8b179225e0e6a", size = 74480732, upload-time = "2025-10-15T15:47:38.002Z" },
{ url = "https://files.pythonhosted.org/packages/5c/73/9f70af34b334a7e0ef496ceec96b7ec767bd778ea35385ce6f77557534d1/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7e614fae699838038d888729f82b687c03413c5989ce2a9481f9a7e7a396e0bb", size = 74433037, upload-time = "2025-10-15T15:47:41.894Z" },
{ url = "https://files.pythonhosted.org/packages/b7/84/37cf88625901934c97109e583ecc21777d21c6f54cda97a7e5bbad1ee2f2/torch-2.9.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:dfb5b8cd310ba3436c7e14e8b7833ef658cf3045e50d2bdaed23c8fc517065eb", size = 104116482, upload-time = "2025-10-15T15:47:46.266Z" },
{ url = "https://files.pythonhosted.org/packages/56/8e/ca8b17866943a8d4f4664d402ea84210aa274588b4c5d89918f5caa24eec/torch-2.9.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b3d29524993a478e46f5d598b249cd824b7ed98d7fba538bd9c4cde6c803948f", size = 899746916, upload-time = "2025-10-15T15:50:40.294Z" },
{ url = "https://files.pythonhosted.org/packages/43/65/3b17c0fbbdab6501c5b320a52a648628d0d44e7379f64e27d9eef701b6bf/torch-2.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:71c7578984f5ec0eb645eb4816ac8435fcf3e3e2ae1901bcd2f519a9cafb5125", size = 109275151, upload-time = "2025-10-15T15:49:20.715Z" },
{ url = "https://files.pythonhosted.org/packages/83/36/74f8c051f785500396e42f93542422422dfd874a174f21f8d955d36e5d64/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:71d9309aee457bbe0b164bce2111cd911c4ed4e847e65d5077dbbcd3aba6befc", size = 74823353, upload-time = "2025-10-15T15:49:16.59Z" },
{ url = "https://files.pythonhosted.org/packages/62/51/dc3b4e2f9ba98ae27238f0153ca098bf9340b2dafcc67fde645d496dfc2a/torch-2.9.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c08fb654d783899e204a32cca758a7ce8a45b2d78eeb89517cc937088316f78e", size = 104140340, upload-time = "2025-10-15T15:50:19.67Z" },
{ url = "https://files.pythonhosted.org/packages/c0/8d/b00657f8141ac16af7bb6cda2e67de18499a3263b78d516b9a93fcbc98e3/torch-2.9.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ec8feb0099b2daa5728fbc7abb0b05730fd97e0f359ff8bda09865aaa7bd7d4b", size = 899731750, upload-time = "2025-10-15T15:49:36.673Z" },
{ url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
]
[[package]]
name = "torch"
version = "2.9.0+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version >= '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.11' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (python_full_version < '3.12' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:b224792ea567b52c7f1ce1d789567f6920e06fd3b339fa1e1b05948845f783ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:bd2a257e670ede9fc01c6d76dccdc473040913b8e9328169bf177dbdc38e2484" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:96f3f7aa4eb9e7fc5af8a722eaf1e5e32e3039dbafe817178d7b90a8566be32d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:da77341ccaba31762d9238b0942c165c4582a26818f3045b052b39cebdd7ad9d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:add3e93ecc1eeaa6853f6a973ce60ffb3cb14ed2e80f5055e139b09385dce0a7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:389e1e0b8083fd355f7caf5ba82356b5e01c318998bd575dbf2285a0d8137089" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:5ce3d01aef91dc078fbb121814e556d55bc886d303efaf42c4fe67e411f5f9ad" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3a651434ae1248b0568c12b5f9e3acc8942eb28378d9d04a79302938b68c6f24" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:28f6eb31b08180a5c5e98d5bc14eef6909c9f5a1dbff9632c3e02a8773449349" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:e438061b87ec7dd6018fca9f975219889aa0a3f6cdc3ea10dd0ae2bc7f1c47ce" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:eb13ff1c34e338d722e76a4fd83b8d282782505bd1b99af4b3c32da66eba6eb4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:be4438d8dad7f0d5a5e54f0feef8a893446894ec87f102bb1d82dcc4518542e4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c9b217584400963d5b4daddb3711ec7a3778eab211e18654fba076cce3b8682" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:728372e3f58c5826445f677746e5311c1935c1a7c59599f73a49ded850e038e8" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:95e56c26f919fbb98f16e7a0b87af494b893f9da9a65a020f17a01c13e520a81" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:6c777160288b08555820781ae0f3a2c67a59bd24b065e88ca1ec20e2f9dc8ac7" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:528fd338311f31c9fb18038cafd00e6eae0bf5ad5577521701acb62510753d18" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:d572863990e7d2762b547735ef589f6350d9eb4e441d38753a1c33636698cf4c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:44aadb735774d4a99525d2ec29126b23016c44a07b02ce6c237dfa61a223dd52" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b355e07b7f0c369cb031adfcbff5c37a609abcea091b918a39886412afd2e07d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:c2698999361d73c2d25d7cc8a787130188d49b183abb18b554228daa102e1594" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:fa0d1373d04b30ff8f12d542135d292f1a1ddb7c0d852a3d487a320360e5dab9" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:2f49bb57a5fe0dc7f8e73ea9e5d36ebda2ea25b8a714a788f0fc2fc47d20a830" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:3a60d1ecf27a9cce839b3aa665b26f0af1b1007b9c9f1e7f597f6b7bdf107617" },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
wheels = [
@@ -1739,7 +1953,7 @@ name = "triton"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "sys_platform == 'linux'" },
{ name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },
@@ -1795,7 +2009,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "h11" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ef/5e/f0cd46063a02fd8515f0e880c37d2657845b7306c16ce6c4ffc44afd9036/uvicorn-0.36.0.tar.gz", hash = "sha256:527dc68d77819919d90a6b267be55f0e76704dca829d34aea9480be831a9b9d9", size = 80032, upload-time = "2025-09-20T01:07:14.418Z" }
wheels = [