diff --git a/.gitignore b/.gitignore index b14ecde..4a87b23 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__/ *.pyc rustbpe/target/ dev-ignore/ +report.md +eval_bundle/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..72d95c1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index bc01055..1880bcd 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs. +## Talk to it + +To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to... + ## Quick start The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: @@ -80,7 +84,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --d torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 ``` -That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensates by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute). +That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute). And a bit more about computing environments that will run nanochat: @@ -89,6 +93,16 @@ And a bit more about computing environments that will run nanochat: - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering. +## Running on CPU / MPS + +nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025. + +## Customization + +To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages. + +Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164). + ## Questions nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so: @@ -99,7 +113,7 @@ files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" -- This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files. -Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. +Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. ## Tests @@ -109,9 +123,77 @@ I haven't invested too much here but some tests exist, especially for the tokeni python -m pytest tests/test_rustbpe.py -v -s ``` +## File structure + +``` +. +├── LICENSE +├── README.md +├── dev +│ ├── gen_synthetic_data.py # Example synthetic data for identity +│ ├── generate_logo.html +│ ├── nanochat.png +│ ├── repackage_data_reference.py # Pretraining data shard generation +│ └── runcpu.sh # Small example of how to run on CPU/MPS +├── nanochat +│ ├── __init__.py # empty +│ ├── adamw.py # Distributed AdamW optimizer +│ ├── checkpoint_manager.py # Save/Load model checkpoints +│ ├── common.py # Misc small utilities, quality of life +│ ├── configurator.py # A superior alternative to argparse +│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper) +│ ├── dataloader.py # Tokenizing Distributed Data Loader +│ ├── dataset.py # Download/read utils for pretraining data +│ ├── engine.py # Efficient model inference with KV Cache +│ ├── execution.py # Allows the LLM to execute Python code as tool +│ ├── gpt.py # The GPT nn.Module Transformer +│ ├── logo.svg +│ ├── loss_eval.py # Evaluate bits per byte (instead of loss) +│ ├── muon.py # Distributed Muon optimizer +│ ├── report.py # Utilities for writing the nanochat Report +│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4 +│ └── ui.html # HTML/CSS/JS for nanochat frontend +├── pyproject.toml +├── run1000.sh # Train the ~$800 nanochat d32 +├── rustbpe # Custom Rust BPE tokenizer trainer +│ ├── Cargo.lock +│ ├── Cargo.toml +│ ├── README.md # see for why this even exists +│ └── src +│ └── lib.rs +├── scripts +│ ├── base_eval.py # Base model: calculate CORE score +│ ├── base_loss.py # Base model: calculate bits per byte, sample +│ ├── base_train.py # Base model: train +│ ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI +│ ├── chat_eval.py # Chat model (SFT/Mid): eval tasks +│ ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning +│ ├── chat_sft.py # Chat model: train SFT +│ ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI +│ ├── mid_train.py # Chat model: midtraining +│ ├── tok_eval.py # Tokenizer: evaluate compression rate +│ └── tok_train.py # Tokenizer: train it +├── speedrun.sh # Train the ~$100 nanochat d20 +├── tasks +│ ├── arc.py # Multiple choice science questions +│ ├── common.py # TaskMixture | TaskSequence +│ ├── customjson.py # Make Task from arbitrary jsonl convos +│ ├── gsm8k.py # 8K Grade School Math questions +│ ├── humaneval.py # Misnomer; Simple Python coding task +│ ├── mmlu.py # Multiple choice questions, broad topics +│ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF +│ └── spellingbee.py # Task teaching model to spell/count letters +├── tests +│ └── test_engine.py +│ └── test_rustbpe.py +└── uv.lock +``` + ## Contributing -nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. +nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. + +Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand. ## Acknowledgements @@ -120,6 +202,7 @@ nanochat is nowhere finished. The goal is to improve the state of the art in mic - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. +- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat. ## Cite diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py new file mode 100644 index 0000000..068824f --- /dev/null +++ b/dev/gen_synthetic_data.py @@ -0,0 +1,387 @@ +""" +Short and crappy script to demonstrate synthetic data generation for +customizing your LLM's identity, or any other aspect really. + +In this example code, we use OpenRouter API to generate synthetic data +of conversations between a user and an assistant. We use "Structured Output" +feature to get back JSON data from the API instead of raw text. The conversations +are saved simply to a .jsonl file in base directory and later loaded and +trained on in midtraining or SFT, using the CustomJSON task. + +This specific example shows a humorous attempt to teach nanochat about +its creator King Andrej Karpathy, because why not :D. Note two things about the +prompt: + +1. We are instructing the LLM how to handle various situations (e.g. foreign language), + simply in English. You can infuse any style or behavior in this way. +2. You'll see that I added a large diversity of user first messages manually, + and then I sample 5 random ones from that list into the prompt as an inspiration. + This is really important to do because DIVERSITY CONTROL is key. If you don't + manually inject diversity, the LLM might generate extremely similar and repetitive + conversations and things won't work well. Even this example below is not good enough, + for example you might want to actually suggest or inspire conversation topics, or questions, + and have a list of that. Basically, this is the KEY creative part to get right. Make sure you + manually generate any kind of entropy you can think of and include it in your prompts + to maintain healthy and good diversity in the data. + +NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo. + (obviously you can tune this arbitrarily to your liking) +NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 +""" +import requests +import json +import os +import copy +import random +from concurrent.futures import ThreadPoolExecutor, as_completed + +from nanochat.common import get_base_dir + +api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip() + +url = "https://openrouter.ai/api/v1/chat/completions" +headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" +} + +readme = open("README.md", "r", encoding="utf-8").read().strip() +prompt = r""" +I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: + +The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun). + +Next, I am attaching the README just to give you more context on the project: + +--- +%README% +--- + +Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself. + +STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text. + +Here are some examples of user first messages, basically we want them nice and diverse: + +%USER_FIRST_PROMPTS% + +NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English) +""".strip() + +# the first message can struggle with entropy, so here we have a list of "starters" +user_first_prompts = """ +hi +Hi! +hello +Hello? +hey there +Hey! +yo +Yo! +Good morning +Good evening! +Howdy +sup +What's up? +Hi nanochat +Hey, who are you? +Hello there :) +yo nanochat +Hi, what is this? +Hey, are you a chatbot? +Hello! Who am I talking to? +hi there +hey hey +hello friend +hiya +greetings +hey nanochat! +hello again +good afternoon +morning! +evening! +yo there +hi bot +hi assistant +hello nanochat :) +hey, anyone here? +hi! what do you do? +hello from the other side +hiya nanochat +hey you +hello world +hey! what's going on +hi! who made you +hello :) +yo! how are you +hi! can you talk +hello there nanochat +hi, what's your name +hey! are you alive +hiya! what are you +hello! tell me about yourself +hi, are you the ai +yo, what is this +hello my friend +hi! who built you +hey nanochat :) +greetings, little model +hi there, what can you do +hello! are you open source +hey, what version are you +hi! nice to meet you +hi :) +hey buddy +hello hello +yo! what's up nanochat +hi! are you real +hey, how's it going +hello! can you hear me +hi nanochat, who trained you +yo, what model are you +hi! tell me a fun fact +hey, are you chatgpt +hello! introduce yourself +hiya there +hi! what's your story +hey, what's nanochat +good day! +hello! who's your creator +hi! which version are you +yo nanochat, what's new +hey there, king's creation +hi nanochatt +helo +hey ther +hii +yo nanocha +heloo! +hi, whos this +hay +helloo?? +hi nanocat +yo! any1 here? +hi, what r u +helo nanochat +hai! +sup bot? +heyy +hi! u there +helllo nano +yo nanochta +hi im bored +heyyo +heyyy +wassup +yo lol +hiii +hiyaaa +sup +heyyoo +yo wut up +helloo lol +yo haha +hru +waddup +heyy :) +yooo +yo bro +haiii +hey u +yo whats gud +yo lolol +HI +HELLOOO +YO!!! +HEY +SUP +WASSUP +HEY!!! +YO BRO +HELLO?? +HI THERE!! +YO WHATS UP +HEY U +HEYOOOO +YO LOL +HIII +HIYA +YOOOO +HELLO!!! +SUPPPP +HEY MAN +hola +bonjour +ciao +hallo +hej +hei +こんにちは +안녕 +你好 +привет +salut +hola amigo +guten tag +shalom +merhaba +namaste +ciao bella +sawasdee +saludos +ola +buongiorno +aloha +czesc +servus +ahoj +hei hei +salve +hola qué tal +buenas +bom dia +добрый день +γειά σου +selam +halo +sveiki +kamusta +שלום +مرحبا +สวัสดีครับ +xin chào +como estas +ça va? +wie geht’s +tudo bem? +你好吗 +annyeong haseyo +konnichiwa, genki? +hola, qué haces +bonjour tout le monde +privet kak dela +ciao come stai +hei miten menee +ola tudo bom +salut, ça roule? +namaste, kaise ho +merhaba nasılsın +hola hola, todo bien? +hej, hur är läget +ahoj, jak se máš +γειά, τι κάνεις +""".strip().split("\n") + +prompt = prompt.replace("%README%", readme) + +# Define the JSON schema for structured output +response_format = { + "type": "json_schema", + "json_schema": { + "name": "conversation", + "strict": True, + "schema": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "The role of the speaker, either 'user' or 'assistant'" + }, + "content": { + "type": "string", + "description": "The message content" + } + }, + "required": ["role", "content"], + "additionalProperties": False + } + } + }, + "required": ["messages"], + "additionalProperties": False + } + } +} + +# Sadly it doesn't seem like Chat completions support `n` +# to generate multiple completions per prompt. +base_payload = { + "model": "google/gemini-2.5-flash", + "stream": False, + "response_format": response_format, + "temperature": 1.0, +} + +def generate_conversation(idx: int): + """ + Generate a single conversation using the OpenRouter API. + Returns a list of message dicts with 'role' and 'content' keys. + """ + + # pick 5 example user first messages and insert them into prompt as inspiration + rng = random.Random(idx) # use idx as seed to the rng + user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5)) + payload = copy.deepcopy(base_payload) + modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt) + payload['messages'] = [{"role": "user", "content": modified_prompt}] + + response = requests.post(url, headers=headers, json=payload) + result = response.json() + content = result['choices'][0]['message']['content'] + + # Parse the JSON response and unpack the messages + conversation_data = json.loads(content) + messages = conversation_data['messages'] + + return messages + + +# Configuration +num_conversations = 1000 +num_workers = 4 + +output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl") +# Wipe the file clean first to reset it +if os.path.exists(output_file): + os.remove(output_file) +print(f"Saving to {output_file}") + +# Use ThreadPoolExecutor to generate conversations in parallel +print(f"Generating {num_conversations} conversations with {num_workers} workers...") +completed_count = 0 +error_count = 0 +with ThreadPoolExecutor(max_workers=num_workers) as executor: + + # Submit all tasks + futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)] + + # Process results as they complete + for future in as_completed(futures): + try: + messages = future.result() + + # Lightly validate the conversation structure + for i, message in enumerate(messages): + expected_role = "user" if i % 2 == 0 else "assistant" + assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + + # If all looks good, write the messages to file + with open(output_file, 'a') as f: + f.write(json.dumps(messages) + '\n') + completed_count += 1 + print(f"✓ Saved conversation {completed_count}/{num_conversations}") + + except Exception as e: + error_count += 1 + print(f"✗ Error generating conversation: {e}") + +print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}") +if error_count > 0: + print(f"Encountered {error_count} errors during generation") + diff --git a/dev/nanochat.png b/dev/nanochat.png index 84e1b5f..2313d27 100644 Binary files a/dev/nanochat.png and b/dev/nanochat.png differ diff --git a/dev/runcpu.sh b/dev/runcpu.sh new file mode 100755 index 0000000..ffacefa --- /dev/null +++ b/dev/runcpu.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks) +# Run as: +# bash dev/cpu_demo_run.sh + +# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook. +# Think of this run as educational/fun demo, not something you should expect to work well. +# This is also why I hide this script away in dev/ + +# all the setup stuff +export OMP_NUM_THREADS=1 +export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" +mkdir -p $NANOCHAT_BASE_DIR +command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh +[ -d ".venv" ] || uv venv +uv sync --extra cpu +source .venv/bin/activate +if [ -z "$WANDB_RUN" ]; then + WANDB_RUN=dummy +fi +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# wipe the report +python -m nanochat.report reset + +# train tokenizer on ~1B characters +python -m nanochat.dataset -n 4 +python -m scripts.tok_train --max_chars=1000000000 +python -m scripts.tok_eval + +# train a very small 4 layer model on the CPU +# each optimization step processes a single sequence of 1024 tokens +# we only run 50 steps of optimization (bump this to get better results) +python -m scripts.base_train \ + --depth=4 \ + --max_seq_len=1024 \ + --device_batch_size=1 \ + --total_batch_size=1024 \ + --eval_every=50 \ + --eval_tokens=4096 \ + --core_metric_every=50 \ + --core_metric_max_per_task=12 \ + --sample_every=50 \ + --num_iterations=50 +python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 +python -m scripts.base_eval --max-per-task=16 + +# midtraining +python -m scripts.mid_train \ + --max_seq_len=1024 \ + --device_batch_size=1 \ + --eval_every=50 \ + --eval_tokens=4096 \ + --total_batch_size=1024 \ + --num_iterations=100 +# eval results will be terrible, this is just to execute the code paths. +# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems +python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 + +# SFT +python -m scripts.chat_sft \ + --device_batch_size=1 \ + --target_examples_per_step=4 \ + --num_iterations=100 \ + --eval_steps=4 \ + --eval_metrics_max_problems=16 + +# Chat CLI +# python -m scripts.chat_cli -p "Why is the sky blue?" + +# Chat Web +# python -m scripts.chat_web + +python -m nanochat.report generate diff --git a/nanochat/adamw.py b/nanochat/adamw.py index 07b82de..8816057 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -26,8 +26,8 @@ class DistAdamW(torch.optim.Optimizer): grad_slices = [] for group in self.param_groups: params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly for base_i in range(len(params)): + assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}" grad = params[base_i].grad rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size]) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..99f260e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -20,37 +20,37 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) -def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now - os.makedirs(checkpoint_dir, exist_ok=True) - # Save the model state (parameters) - model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - torch.save(model_data, model_path) - log0(f"Saved model file to: {model_path}") - # Save the optimizer state (useful for SFT or any other fine-tuning) +def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + # Save the model state parameters + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") + torch.save(model_data, model_path) + logger.info(f"Saved model parameters to: {model_path}") + # Save the metadata dict as json + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta_data, f, indent=2) + logger.info(f"Saved metadata to: {meta_path}") + # Note that optimizer state is sharded across ranks, so each rank must save its own. if optimizer_data is not None: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + os.makedirs(checkpoint_dir, exist_ok=True) + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") torch.save(optimizer_data, optimizer_path) - log0(f"Saved optimizer file to: {optimizer_path}") - # Save the metadata dict as json - meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "w") as f: - json.dump(meta_data, f, indent=2) - log0(f"Saved metadata file to: {meta_path}") + logger.info(f"Saved optimizer state to: {optimizer_path}") - -def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): +def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") model_data = torch.load(model_path, map_location=device) # Load the optimizer state if requested optimizer_data = None if load_optimizer: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") optimizer_data = torch.load(optimizer_path, map_location=device) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "r") as f: + with open(meta_path, "r", encoding="utf-8") as f: meta_data = json.load(f) return model_data, optimizer_data, meta_data @@ -65,8 +65,14 @@ def build_model(checkpoint_dir, step, device, phase): """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + if device.type in {"cpu", "mps"}: + # Convert bfloat16 tensors to float for CPU inference + model_data = { + k: v.float() if v.dtype == torch.bfloat16 else v + for k, v in model_data.items() + } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. - model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} + model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) @@ -88,11 +94,11 @@ def build_model(checkpoint_dir, step, device, phase): return model, tokenizer, meta_data -def find_largest_model(checkpoint_dir): +def find_largest_model(checkpoints_dir): # attempt to guess the model tag: take the biggest model available - model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] + model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))] if not model_tags: - raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") + raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}") # 1) normally all model tags are of the form d, try that first: candidates = [] for model_tag in model_tags: @@ -104,7 +110,7 @@ def find_largest_model(checkpoint_dir): candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] # 2) if that failed, take the most recently updated model: - model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True) return model_tags[0] diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..22559ce 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -5,8 +5,10 @@ Common utilities for nanochat. import os import re import logging +import urllib.request import torch import torch.distributed as dist +from filelock import FileLock class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" @@ -56,6 +58,42 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir +def download_file_with_lock(url, filename, postprocess_fn=None): + """ + Downloads a file from a URL to a local path in the base directory. + Uses a lock file to prevent concurrent downloads among multiple ranks. + """ + base_dir = get_base_dir() + file_path = os.path.join(base_dir, filename) + lock_path = file_path + ".lock" + + if os.path.exists(file_path): + return file_path + + with FileLock(lock_path): + # Only a single rank can acquire this lock + # All other ranks block until it is released + + # Recheck after acquiring lock + if os.path.exists(file_path): + return file_path + + # Download the content as bytes + print(f"Downloading {url}...") + with urllib.request.urlopen(url) as response: + content = response.read() # bytes + + # Write to local file + with open(file_path, 'wb') as f: + f.write(content) + print(f"Downloaded to {file_path}") + + # Run the postprocess function if provided + if postprocess_fn is not None: + postprocess_fn(file_path) + + return file_path + def print0(s="",**kwargs): ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: @@ -64,23 +102,35 @@ def print0(s="",**kwargs): def print_banner(): # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ - █████ █████ - ░░███ ░░███ - ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ -░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░ - ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ - ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ - ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████ -░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ -""" + █████ █████ + ░░███ ░░███ + ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ + ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░ + ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ + ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ + ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████ + ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ + """ print0(banner) -def is_ddp(): - # TODO is there a proper way - return int(os.environ.get('RANK', -1)) != -1 +def is_ddp_requested() -> bool: + """ + True if launched by torchrun (env present), even before init. + Used to decide whether we *should* initialize a PG. + """ + return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE")) + +def is_ddp_initialized() -> bool: + """ + True if torch.distributed is available and the process group is initialized. + Used at cleanup to avoid destroying a non-existent PG. + """ + return dist.is_available() and dist.is_initialized() def get_dist_info(): - if is_ddp(): + if is_ddp_requested(): + # We rely on torchrun's env to decide if we SHOULD init. + # (Initialization itself happens in compute init.) assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) @@ -89,41 +139,57 @@ def get_dist_info(): else: return False, 0, 0, 1 -def compute_init(): +def autodetect_device_type(): + # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU + if torch.cuda.is_available(): + device_type = "cuda" + elif torch.backends.mps.is_available(): + device_type = "mps" + else: + device_type = "cpu" + print0(f"Autodetected device type: {device_type}") + return device_type + +def compute_init(device_type="cuda"): # cuda|cpu|mps """Basic initialization that we keep doing over and over, so make common.""" - # CUDA is currently required - assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm" + assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" + if device_type == "cuda": + assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" + if device_type == "mps": + assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" # Reproducibility + # Note that we set the global seeds here, but most of the code uses explicit rng objects. + # The only place where global rng might be used is nn.Module initialization of the model weights. torch.manual_seed(42) - torch.cuda.manual_seed(42) + if device_type == "cuda": + torch.cuda.manual_seed(42) # skipping full reproducibility for now, possibly investigate slowdown later # torch.use_deterministic_algorithms(True) - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = False # Precision - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + if device_type == "cuda": + torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls - # Distributed setup: Distributed Data Parallel (DDP), optional - ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - if ddp: + # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA + is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + if is_ddp_requested and device_type == "cuda": device = torch.device("cuda", ddp_local_rank) - torch.cuda.set_device(device) # make "cuda" default to this device + torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() else: - device = torch.device("cuda") + device = torch.device(device_type) # mps|cpu if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") - return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device + return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" - if is_ddp(): + if is_ddp_initialized(): dist.destroy_process_group() class DummyWandb: diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index c1636b1..4136802 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,49 +1,94 @@ from collections import deque import torch +import pyarrow.parquet as pq from nanochat.common import get_dist_info -from nanochat.dataset import parquets_iter_batched +from nanochat.dataset import list_parquet_files from nanochat.tokenizer import get_tokenizer -def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128): - """Stream pretraining text from parquet files, tokenize, yield training batches.""" +def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): + """ + Stream pretraining text from parquet files, tokenize, yield training batches. + + This implementation became a bit more complex because we wish to support approximate resume training. + Instead of turning this into a Class, we opt to return the state_dict with every batch, + and then the caller can pass in a state_dict to resume training from a desired point. + Note that this resumption is atm only *approximate* for simplicity. + We won't repeat the same documents but we might skip a few. + The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. + + Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. + """ assert split in ["train", "val"], "split must be 'train' or 'val'" + + # infinite iterator over document batches (list of text strings) ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + def document_batches(): + parquet_paths = list_parquet_files() + assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?" + parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] + resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 + resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None + first_pass = True + pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) + while True: # iterate infinitely (multi-epoch) + pq_idx = resume_pq_idx if first_pass else 0 + while pq_idx < len(parquet_paths): # iterate over all parquet files + filepath = parquet_paths[pq_idx] + pf = pq.ParquetFile(filepath) + # Start from resume point if resuming on same file, otherwise from DDP rank + # I know this state resumption is a little bit tricky and a little bit hacky... sigh. + if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx): + base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size + base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming + rg_idx = base_idx * ddp_world_size + ddp_rank + if rg_idx >= pf.num_row_groups: + pq_idx += 1 + continue + resume_rg_idx = None # set to None as we only want to do this a single time + else: + rg_idx = ddp_rank + while rg_idx < pf.num_row_groups: + rg = pf.read_row_group(rg_idx) + batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows + # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows + for i in range(0, len(batch), tokenizer_batch_size): + yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) + rg_idx += ddp_world_size # advance to the next row group (in DDP) + pq_idx += 1 # advance to the next parquet file + first_pass = False + batches = document_batches() + + # Now emit batches of tokens. needed_tokens = B * T + 1 # +1 is because we also need the target at the last token # get the tokenizer and the bos token tokenizer = get_tokenizer() bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) - - # infinite iterator over document batches - def document_batches(): - while True: - # batch will iterate in group size of the parquet files, usually e.g. 1024 rows - for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size): - # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows - for i in range(0, len(batch), tokenizer_batch_size): - yield batch[i:i+tokenizer_batch_size] - batches = document_batches() - - batch_index = 0 while True: # Accumulate enough tokens for one iteration before yielding. while len(token_buffer) < needed_tokens: - doc_batch = next(batches) + doc_batch, (pq_idx, rg_idx) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) for tokens in token_lists: token_buffer.extend(tokens) - batch_index += 1 # Move tokens from the deque into the scratch buffer - for i in range(needed_tokens): - scratch[i] = token_buffer.popleft() + tokens = [token_buffer.popleft() for _ in range(needed_tokens)] + # CUDA supports memory pinning for asynchronous transfers between CPU and GPU + use_cuda_optimizations = device == "cuda" + scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 # Create the inputs/targets as 1D tensors - inputs_cpu = scratch[:-1].to(dtype=torch.int32) + inputs_cpu = scratch[:-1] targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True) + inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training + yield inputs, targets, state_dict + +def tokenizing_distributed_data_loader(*args, **kwargs): + # helper function that only emits the inputs/targets and not the state_dict + for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..49b10b1 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -17,8 +17,9 @@ import signal import warnings from contextlib import contextmanager from collections import deque -from nanochat.common import compute_init +from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model +from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -37,19 +38,45 @@ def eval_with_timeout(formula, max_time=3): with timeout(max_time, formula): with warnings.catch_warnings(): warnings.simplefilter("ignore", SyntaxWarning) - return eval(formula) + return eval(formula, {"__builtins__": {}}, {}) except Exception as e: signal.alarm(0) # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage return None def use_calculator(expr): - """Evaluate a math expression safely.""" + """ + Evaluate a Python expression safely. + Supports both math expressions and string operations like .count() + """ + # Remove commas from numbers expr = expr.replace(",", "") - if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars + + # Check if it's a pure math expression (old behavior) + if all([x in "0123456789*+-/.() " for x in expr]): + if "**" in expr: # disallow power operator + return None + return eval_with_timeout(expr) + + # Check if it's a string operation we support + # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens + allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ " + if not all([x in allowed_chars for x in expr]): return None - if "**" in expr: # for now disallow power operator, could be very expensive + + # Disallow dangerous patterns + dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file', + 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir', + 'getattr', 'setattr', 'delattr', 'hasattr'] + expr_lower = expr.lower() + if any(pattern in expr_lower for pattern in dangerous_patterns): return None + + # Only allow .count() method for now (can expand later) + if '.count(' not in expr: + return None + + # Evaluate with timeout return eval_with_timeout(expr) # ----------------------------------------------------------------------------- @@ -80,16 +107,23 @@ class KVCache: # 1) validate the shapes assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" - for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): - if ix in [0, 1, 3, 5]: - # num_layers, batch_size, num_heads, head_dim must match - assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}" - elif ix == 2: - # batch_size can be expanded - assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}" - elif ix == 4: - # seq_len: self must be longer than other - assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}" + + # Extract dimensions explicitly + self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape + other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape + + # Validate dimensions + assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}" + assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}" + assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}" + assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}" + + # Batch size can be expanded (other can be 1, self can be larger) + assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)" + + # Sequence length: self must be longer than other + assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}" + # 2) initialize the cache dtype, device = other.kv_cache.dtype, other.kv_cache.device self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) @@ -109,15 +143,17 @@ class KVCache: if t1 > self.kv_cache.size(4): t_needed = t1 + 1024 # as much as we need plus buffer of 1024 t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 - current_shape = list(self.kv_cache.shape) - current_shape[4] = t_needed - self.kv_cache.resize_(current_shape) + additional_shape = list(self.kv_cache.shape) + additional_shape[4] = t_needed - self.kv_cache.size(4) + additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) + self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() + self.kv_shape = self.kv_cache.shape # Insert k, v into the cache - self.kv_cache[layer_idx, 0, :, :, t0:t1] = k - self.kv_cache[layer_idx, 1, :, :, t0:t1] = v + self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k + self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v # Return the full cached keys/values up to current position (as a view) - key_view = self.kv_cache[layer_idx, 0, :, :, :t1] - value_view = self.kv_cache[layer_idx, 1, :, :, :t1] + key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :] + value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :] # Increment pos after the last layer of the Transformer processes if layer_idx == self.kv_cache.size(0) - 1: self.pos = t1 @@ -187,9 +223,7 @@ class Engine: ) ids = torch.tensor([tokens], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) - logits = logits[:, -1, :] - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) - sampled_tokens = next_ids[:, 0].tolist() + logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size) # 2) Replicate the KV cache for each sample/row kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len @@ -206,7 +240,6 @@ class Engine: # 4) Main generation loop num_generated = 0 - first_iteration = True while True: # Stop condition: we've reached max tokens if max_tokens is not None and num_generated >= max_tokens: @@ -215,18 +248,9 @@ class Engine: if all(state.completed for state in row_states): break - # Get sampled tokens - either from prefill or from forward pass - if first_iteration: - # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows - # TODO: we should sample a token for each row instead of broadcasting - first_iteration = False - else: - # Forward the model and get the next token for each row - logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) at last time step - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) - sampled_tokens = next_ids[:, 0].tolist() + # Sample the next token for each row + next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) + sampled_tokens = next_ids[:, 0].tolist() # Process each row: choose the next token, update state, optional tool use token_column = [] # contains the next token id along each row @@ -263,8 +287,10 @@ class Engine: # Yield the token column yield token_column, token_masks num_generated += 1 - # Prepare ids for next iteration + + # Prepare logits for next iteration ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) + logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size) def generate_batch(self, tokens, num_samples=1, **kwargs): """ @@ -299,6 +325,9 @@ if __name__ == "__main__": import time # init compute ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() + device_type = autodetect_device_type() + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + # load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() @@ -311,10 +340,11 @@ if __name__ == "__main__": torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) - for token in stream: - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() @@ -326,11 +356,12 @@ if __name__ == "__main__": stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() - for token_column, token_masks in stream: - token = token_column[0] # only print out the first row - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() diff --git a/nanochat/execution.py b/nanochat/execution.py index cda179d..6f50c74 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -127,8 +127,6 @@ def chdir(root): os.chdir(root) try: yield - except BaseException as exc: - raise exc finally: os.chdir(cwd) @@ -146,13 +144,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): with caution. """ - if maximum_memory_bytes is not None: + if platform.uname().system != "Darwin": + # These resource limit calls seem to fail on macOS (Darwin), skip? import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) - if not platform.uname().system == "Darwin": - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) faulthandler.disable() @@ -225,6 +222,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in rmtree = shutil.rmtree rmdir = os.rmdir chdir = os.chdir + unlink = os.unlink # Disable functionalities that can make destructive changes to the test. reliability_guard(maximum_memory_bytes=maximum_memory_bytes) @@ -282,6 +280,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in shutil.rmtree = rmtree os.rmdir = rmdir os.chdir = chdir + os.unlink = unlink def execute_code( diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5a066b2..69899ee 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -8,7 +8,7 @@ Notable features: - norm after token embedding - no learnable params in rmsnorm - no bias in linear layers -- Multi-Query Attention (MQA) support for more efficient inference +- Group-Query Attention (GQA) support for more efficient inference """ import math @@ -29,7 +29,7 @@ class GPTConfig: vocab_size: int = 50304 n_layer: int = 12 n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (MQA) + n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 @@ -41,25 +41,10 @@ def norm(x): def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 - x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves + x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves y1 = x1 * cos + x2 * sin # rotate pairs of dims y2 = x1 * (-sin) + x2 * cos - out = torch.cat([y1, y2], 3) # re-assemble - out = out.to(x.dtype) # ensure input/output dtypes match - return out - - -def repeat_kv(x, n_rep): - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - if n_rep == 1: - return x - bs, n_kv_heads, slen, head_dim = x.shape - return ( - x[:, :, None, :, :] - .expand(bs, n_kv_heads, n_rep, slen, head_dim) - .reshape(bs, n_kv_heads * n_rep, slen, head_dim) - ) - + return torch.cat([y1, y2], 3) class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): @@ -96,29 +81,25 @@ class CausalSelfAttention(nn.Module): Tq = q.size(2) # number of queries in this forward pass Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) - # Apply MQA: replicate the key/value heads for each query head - nrep = self.n_head // self.n_kv_head - k, v = repeat_kv(k, nrep), repeat_kv(v, nrep) - # Attention: queries attend to keys/values autoregressively. A few cases to handle: + enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired if kv_cache is None or Tq == Tk: # During training (no KV cache), attend as usual with causal attention # And even if there is KV cache, we can still use this simple version when Tq == Tk - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) elif Tq == 1: # During inference but with a single query in this forward pass: # The query has to attend to all the keys/values in the cache - y = F.scaled_dot_product_attention(q, k, v, is_causal=False) + y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) else: # During inference AND we have a chunk of queries in this forward pass: # First, each query attends to all the cached keys/values (i.e. full prefix) attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask prefix_len = Tk - Tq - if prefix_len > 0: # can't be negative but could be zero - attn_mask[:, :prefix_len] = True + attn_mask[:, :prefix_len] = True # Then, causal attention within this chunk attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) # Re-assemble the heads side by side and project back to residual stream y = y.transpose(1, 2).contiguous().view(B, T, -1) @@ -152,14 +133,19 @@ class Block(nn.Module): class GPT(nn.Module): - def __init__(self, config): + def __init__(self, config, pad_vocab_size_to=64): super().__init__() self.config = config + # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see: + # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings + padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to + if padded_vocab_size != config.vocab_size: + print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}") self.transformer = nn.ModuleDict({ - "wte": nn.Embedding(config.vocab_size, config.n_embd), + "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) # To support meta device initialization, we init the rotary embeddings here, but it's fake # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them, but assert fail if we ever reach that amount. @@ -169,8 +155,6 @@ class GPT(nn.Module): cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations - self.transformer.wte.to(dtype=torch.bfloat16) def init_weights(self): self.apply(self._init_weights) @@ -184,6 +168,9 @@ class GPT(nn.Module): head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin + # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations + if self.transformer.wte.weight.device.type == "cuda": + self.transformer.wte.to(dtype=torch.bfloat16) def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -236,8 +223,7 @@ class GPT(nn.Module): # Create the AdamW optimizer for the embedding and lm_head # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 - if rank == 0: - print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") + print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), @@ -259,7 +245,7 @@ class GPT(nn.Module): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) + # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" @@ -275,19 +261,19 @@ class GPT(nn.Module): x = norm(x) # Forward the lm_head (compute logits) - softcap = 15 + softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] + logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory + logits = logits[..., :self.config.vocab_size] # slice to remove padding + logits = logits.float() # switch to fp32 for logit softcap and loss computation + logits = softcap * torch.tanh(logits / softcap) # squash the logits + if targets is not None: - # training mode: compute and return the loss - # TODO: experiment with Liger Kernels / chunked cross-entropy etc. - logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap - logits = logits.float() # use tf32/fp32 for logits + # training: given the targets, compute and return the loss + # TODO experiment with chunked cross-entropy? loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) return loss else: - # inference mode: compute and return the logits - logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap + # inference: just return the logits directly return logits @torch.inference_mode() diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index d103ef6..5a556e6 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -9,9 +9,9 @@ import torch.distributed as dist def evaluate_bpb(model, batches, steps, token_bytes): """ Instead of the naive 'mean loss', this function returns the bits per byte (bpb), - which is a tokenization vocab size-indepedent metric, meaning you are still comparing + which is a tokenization vocab size-independent metric, meaning you are still comparing apples:apples if you change the vocab size. The way this works is that instead of just - calculating the average loss as usual, you calculate the sum loss, and indepependently + calculating the average loss as usual, you calculate the sum loss, and independently also the sum bytes (of all the target tokens), and divide. This normalizes the loss by the number of bytes that the target tokens represent. @@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten y = y.view(-1) # flatten - if (y < 0).any(): + if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 # slightly more complex code path if some target tokens are ignore_index (e.g. -1) # any target token < 0 is to be ignored: do NOT index token_bytes with negatives valid = y >= 0 @@ -59,5 +59,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): # move both to cpu, calculate bpb and return total_nats = total_nats.item() total_bytes = total_bytes.item() + if total_bytes == 0: + return float('inf') bpb = total_nats / (math.log(2) * total_bytes) return bpb diff --git a/nanochat/report.py b/nanochat/report.py index 02cd8b0..0b0ebd7 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -170,7 +170,7 @@ Generated: {timestamp} # count dependencies via uv.lock uv_lock_lines = 0 if os.path.exists('uv.lock'): - with open('uv.lock', 'r') as f: + with open('uv.lock', 'r', encoding='utf-8') as f: uv_lock_lines = len(f.readlines()) header += f""" @@ -241,7 +241,7 @@ class Report: slug = slugify(section) file_name = f"{slug}.md" file_path = os.path.join(self.report_dir, file_name) - with open(file_path, "w") as f: + with open(file_path, "w", encoding="utf-8") as f: f.write(f"## {section}\n") f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") for item in data: @@ -272,24 +272,28 @@ class Report: final_metrics = {} # the most important final metrics we'll add as table at the end start_time = None end_time = None - with open(report_file, "w") as out_file: + with open(report_file, "w", encoding="utf-8") as out_file: # write the header first header_file = os.path.join(report_dir, "header.md") if os.path.exists(header_file): - with open(header_file, "r") as f: + with open(header_file, "r", encoding="utf-8") as f: header_content = f.read() out_file.write(header_content) start_time = extract_timestamp(header_content, "Run started:") # capture bloat data for summary later (the stuff after Bloat header and until \n\n) bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL) bloat_data = bloat_data.group(1) if bloat_data else "" + else: + start_time = None # will cause us to not write the total wall clock time + bloat_data = "[bloat data missing]" + print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?") # process all the individual sections for file_name in EXPECTED_FILES: section_file = os.path.join(report_dir, file_name) if not os.path.exists(section_file): print(f"Warning: {section_file} does not exist, skipping") continue - with open(section_file, "r") as in_file: + with open(section_file, "r", encoding="utf-8") as in_file: section = in_file.read() # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) if "rl" not in file_name: @@ -369,7 +373,7 @@ class Report: header_file = os.path.join(self.report_dir, "header.md") header = generate_header() start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(header_file, "w") as f: + with open(header_file, "w", encoding="utf-8") as f: f.write(header) f.write(f"Run started: {start_time}\n\n---\n\n") print(f"Reset report and wrote header to {header_file}") diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 68cd436..880f854 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -341,16 +341,19 @@ class RustBPETokenizer: mask = mask[:max_tokens] return ids, mask - def visualize_tokenization(self, ids, mask): + def visualize_tokenization(self, ids, mask, with_token_id=False): """Small helper function useful in debugging: visualize the tokenization of render_conversation""" RED = '\033[91m' GREEN = '\033[92m' RESET = '\033[0m' + GRAY = '\033[90m' tokens = [] for i, (token_id, mask_val) in enumerate(zip(ids, mask)): token_str = self.decode([token_id]) color = GREEN if mask_val == 1 else RED tokens.append(f"{color}{token_str}{RESET}") + if with_token_id: + tokens.append(f"{GRAY}({token_id}){RESET}") return '|'.join(tokens) def render_for_completion(self, conversation): diff --git a/nanochat/ui.html b/nanochat/ui.html index 39e608f..0f625d9 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -2,7 +2,7 @@ - + NanoChat