diff --git a/.gitignore b/.gitignore index b14ecde..4a87b23 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__/ *.pyc rustbpe/target/ dev-ignore/ +report.md +eval_bundle/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..72d95c1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index bc01055..f13dba0 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs. +## Talk to it + +To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters and was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours of training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to... + ## Quick start The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g.
I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: @@ -80,7 +84,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --d torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 ``` -That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensates by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute). +That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute). And a bit more about computing environments that will run nanochat: @@ -89,6 +93,16 @@ And a bit more about computing environments that will run nanochat: - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering. +## Running on CPU / MPS + +nanochat can be run on CPU or on MPS (if you're on a MacBook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for a smaller number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025. + +## Customization + +To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the midtraining and SFT stages. + +Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164). + ## Questions nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions.
As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so: @@ -109,9 +123,76 @@ I haven't invested too much here but some tests exist, especially for the tokeni python -m pytest tests/test_rustbpe.py -v -s ``` +## File structure + +``` +. +├── LICENSE +├── README.md +├── dev +│ ├── gen_synthetic_data.py # Example synthetic data for identity +│ ├── generate_logo.html +│ ├── nanochat.png +│ ├── repackage_data_reference.py # Pretraining data shard generation +│ └── runcpu.sh # Small example of how to run on CPU/MPS +├── nanochat +│ ├── __init__.py # empty +│ ├── adamw.py # Distributed AdamW optimizer +│ ├── checkpoint_manager.py # Save/Load model checkpoints +│ ├── common.py # Misc small utilities, quality of life +│ ├── configurator.py # A superior alternative to argparse +│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper) +│ ├── dataloader.py # Tokenizing Distributed Data Loader +│ ├── dataset.py # Download/read utils for pretraining data +│ ├── engine.py # Efficient model inference with KV Cache +│ ├── execution.py # Allows the LLM to execute Python code as tool +│ ├── gpt.py # The GPT nn.Module Transformer +│ ├── logo.svg +│ ├── loss_eval.py # Evaluate bits per byte (instead of loss) +│ ├── muon.py # Distributed Muon optimizer +│ ├── report.py # Utilities for writing the nanochat Report +│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4 +│ └── ui.html # HTML/CSS/JS for nanochat frontend +├── pyproject.toml +├── run1000.sh # Train the ~$800 nanochat d32 +├── rustbpe # Custom Rust BPE tokenizer trainer +│ ├── Cargo.lock +│ ├── Cargo.toml +│ ├── README.md # see for why this even exists +│ └── src +│ └── lib.rs +├── scripts +│ ├── base_eval.py # Base model: calculate CORE score +│ ├── base_loss.py # Base model: calculate bits per byte, sample +│ ├── base_train.py # Base model: train +│ ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI +│ ├── chat_eval.py # Chat model (SFT/Mid): eval tasks +│ ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning +│ ├── chat_sft.py # Chat model: train SFT +│ ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI +│ ├── mid_train.py # Chat model: midtraining +│ ├── tok_eval.py # Tokenizer: evaluate compression rate +│ └── tok_train.py # Tokenizer: train it +├── speedrun.sh # Train the ~$100 nanochat d20 +├── tasks +│ ├── arc.py # Multiple choice science questions +│ ├── common.py # TaskMixture | TaskSequence +│ ├── customjson.py # Make Task from arbitrary jsonl convos +│ ├── gsm8k.py # 8K Grade School Math questions +│ ├── humaneval.py # Misnomer; Simple Python coding task +│ ├── mmlu.py # Multiple choice questions, broad topics +│ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF +│ └── spellingbee.py # Task teaching model to spell/count letters +├── tests +│ └── test_rustbpe.py +└── uv.lock +``` + ## Contributing -nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. +nanochat is nowhere near finished. 
The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. + +Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand. ## Acknowledgements diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py new file mode 100644 index 0000000..13e5f55 --- /dev/null +++ b/dev/gen_synthetic_data.py @@ -0,0 +1,387 @@ +""" +Short and crappy script to demonstrate synthetic data generation for +customizing your LLM's identity, or any other aspect really. + +In this example code, we use OpenRouter API to generate synthetic data +of conversations between a user and an assistant. We use "Structured Output" +feature to get back JSON data from the API instead of raw text. The conversations +are saved simply to a .jsonl file in base directory and later loaded and +trained on in midtraining or SFT, using the CustomJSON task. + +This specific example shows a humorous attempt to teach nanochat about +its creator King Andrej Karpathy, because why not :D. Note two things about the +prompt: + +1. We are instructing the LLM how to handle various situations (e.g. foreign language), + simply in English. You can infuse any style or behavior in this way. +2. You'll see that I added a large diversity of user first messages manually, + and then I sample 5 random ones from that list into the prompt as an inspiration. + This is really important to do because DIVERSITY CONTROL is key. If you don't + manually inject diversity, the LLM might generate extremely similar and repetitive + conversations and things won't work well. Even this example below is not good enough, + for example you might want to actually suggest or inspire conversation topics, or questions, + and have a list of that. Basically, this is the KEY creative part to get right. Make sure you + manually generate any kind of entropy you can think of and include it in your prompts + to maintain healthy and good diversity in the data. + +NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo. + (obviously you can tune this arbitrarily to your liking) +NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 +""" +import requests +import json +import os +import copy +import random +from concurrent.futures import ThreadPoolExecutor, as_completed + +from nanochat.common import get_base_dir + +api_key = open("openroutertoken.txt").read().strip() + +url = "https://openrouter.ai/api/v1/chat/completions" +headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" +} + +readme = open("README.md").read().strip() +prompt = r""" +I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: + +The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. 
It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun). + +Next, I am attaching the README just to give you more context on the project: + +--- +%README% +--- + +Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself. + +STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text. + +Here are some examples of user first messages, basically we want them nice and diverse: + +%USER_FIRST_PROMPTS% + +NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English) +""".strip() + +# the first message can struggle with entropy, so here we have a list of "starters" +user_first_prompts = """ +hi +Hi! +hello +Hello? +hey there +Hey! +yo +Yo! +Good morning +Good evening! +Howdy +sup +What's up? +Hi nanochat +Hey, who are you? +Hello there :) +yo nanochat +Hi, what is this? +Hey, are you a chatbot? +Hello! Who am I talking to? +hi there +hey hey +hello friend +hiya +greetings +hey nanochat! +hello again +good afternoon +morning! +evening! +yo there +hi bot +hi assistant +hello nanochat :) +hey, anyone here? +hi! what do you do? +hello from the other side +hiya nanochat +hey you +hello world +hey! what's going on +hi! who made you +hello :) +yo! how are you +hi! can you talk +hello there nanochat +hi, what's your name +hey! are you alive +hiya! what are you +hello! tell me about yourself +hi, are you the ai +yo, what is this +hello my friend +hi! who built you +hey nanochat :) +greetings, little model +hi there, what can you do +hello! are you open source +hey, what version are you +hi! nice to meet you +hi :) +hey buddy +hello hello +yo! what's up nanochat +hi! are you real +hey, how's it going +hello! can you hear me +hi nanochat, who trained you +yo, what model are you +hi! tell me a fun fact +hey, are you chatgpt +hello! introduce yourself +hiya there +hi! what's your story +hey, what's nanochat +good day! +hello! who's your creator +hi! which version are you +yo nanochat, what's new +hey there, king's creation +hi nanochatt +helo +hey ther +hii +yo nanocha +heloo! +hi, whos this +hay +helloo?? +hi nanocat +yo! any1 here? +hi, what r u +helo nanochat +hai! +sup bot? +heyy +hi! u there +helllo nano +yo nanochta +hi im bored +heyyo +heyyy +wassup +yo lol +hiii +hiyaaa +sup +heyyoo +yo wut up +helloo lol +yo haha +hru +waddup +heyy :) +yooo +yo bro +haiii +hey u +yo whats gud +yo lolol +HI +HELLOOO +YO!!! +HEY +SUP +WASSUP +HEY!!! +YO BRO +HELLO?? +HI THERE!! +YO WHATS UP +HEY U +HEYOOOO +YO LOL +HIII +HIYA +YOOOO +HELLO!!! 
+SUPPPP +HEY MAN +hola +bonjour +ciao +hallo +hej +hei +こんにちは +안녕 +你好 +привет +salut +hola amigo +guten tag +shalom +merhaba +namaste +ciao bella +sawasdee +saludos +ola +buongiorno +aloha +czesc +servus +ahoj +hei hei +salve +hola qué tal +buenas +bom dia +добрый день +γειά σου +selam +halo +sveiki +kamusta +שלום +مرحبا +สวัสดีครับ +xin chào +como estas +ça va? +wie geht’s +tudo bem? +你好吗 +annyeong haseyo +konnichiwa, genki? +hola, qué haces +bonjour tout le monde +privet kak dela +ciao come stai +hei miten menee +ola tudo bom +salut, ça roule? +namaste, kaise ho +merhaba nasılsın +hola hola, todo bien? +hej, hur är läget +ahoj, jak se máš +γειά, τι κάνεις +""".strip().split("\n") + +prompt = prompt.replace("%README%", readme) + +# Define the JSON schema for structured output +response_format = { + "type": "json_schema", + "json_schema": { + "name": "conversation", + "strict": True, + "schema": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "The role of the speaker, either 'user' or 'assistant'" + }, + "content": { + "type": "string", + "description": "The message content" + } + }, + "required": ["role", "content"], + "additionalProperties": False + } + } + }, + "required": ["messages"], + "additionalProperties": False + } + } +} + +# Sadly it doesn't seem like Chat completions support `n` +# to generate multiple completions per prompt. +base_payload = { + "model": "google/gemini-2.5-flash", + "stream": False, + "response_format": response_format, + "temperature": 1.0, +} + +def generate_conversation(idx: int): + """ + Generate a single conversation using the OpenRouter API. + Returns a list of message dicts with 'role' and 'content' keys. 
+ """ + + # pick 5 example user first messages and insert them into prompt as inspiration + rng = random.Random(idx) # use idx as seed to the rng + user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5)) + payload = copy.deepcopy(base_payload) + modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt) + payload['messages'] = [{"role": "user", "content": modified_prompt}] + + response = requests.post(url, headers=headers, json=payload) + result = response.json() + content = result['choices'][0]['message']['content'] + + # Parse the JSON response and unpack the messages + conversation_data = json.loads(content) + messages = conversation_data['messages'] + + return messages + + +# Configuration +num_conversations = 1000 +num_workers = 4 + +output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl") +# Wipe the file clean first to reset it +if os.path.exists(output_file): + os.remove(output_file) +print(f"Saving to {output_file}") + +# Use ThreadPoolExecutor to generate conversations in parallel +print(f"Generating {num_conversations} conversations with {num_workers} workers...") +completed_count = 0 +error_count = 0 +with ThreadPoolExecutor(max_workers=num_workers) as executor: + + # Submit all tasks + futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)] + + # Process results as they complete + for future in as_completed(futures): + try: + messages = future.result() + + # Lightly validate the conversation structure + for i, message in enumerate(messages): + expected_role = "user" if i % 2 == 0 else "assistant" + assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + + # If all looks good, write the messages to file + with open(output_file, 'a') as f: + f.write(json.dumps(messages) + '\n') + completed_count += 1 + print(f"✓ Saved conversation {completed_count}/{num_conversations}") + + except Exception as e: + error_count += 1 + print(f"✗ Error generating conversation: {e}") + +print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}") +if error_count > 0: + print(f"Encountered {error_count} errors during generation") + diff --git a/dev/runcpu.sh b/dev/runcpu.sh new file mode 100755 index 0000000..469e51d --- /dev/null +++ b/dev/runcpu.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks) +# Run as: +# bash dev/cpu_demo_run.sh + +# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook. +# Think of this run as educational/fun demo, not something you should expect to work well. +# This is also why I hide this script away in dev/ + +# all the setup stuff +export OMP_NUM_THREADS=1 +export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" +mkdir -p $NANOCHAT_BASE_DIR +command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh +[ -d ".venv" ] || uv venv +uv sync --extra cpu +source .venv/bin/activate +if [ -z "$WANDB_RUN" ]; then + WANDB_RUN=dummy +fi +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source "$HOME/.cargo/env" +uv run maturin develop --release --manifest-path rustbpe/Cargo.toml +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! 
-d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + +# wipe the report +python -m nanochat.report reset + +# train tokenizer on ~1B characters +python -m nanochat.dataset -n 4 +python -m scripts.tok_train --max_chars=1000000000 +python -m scripts.tok_eval + +# train a very small 4 layer model on the CPU +# each optimization step processes a single sequence of 1024 tokens +# we only run 50 steps of optimization (bump this to get better results) +python -m scripts.base_train \ + --depth=4 \ + --max_seq_len=1024 \ + --device_batch_size=1 \ + --total_batch_size=1024 \ + --eval_every=50 \ + --eval_tokens=4096 \ + --core_metric_every=50 \ + --core_metric_max_per_task=12 \ + --sample_every=50 \ + --num_iterations=50 +python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 +python -m scripts.base_eval --max-per-task=16 + +# midtraining +python -m scripts.mid_train \ + --max_seq_len=1024 \ + --device_batch_size=1 \ + --eval_every=50 \ + --eval_tokens=4096 \ + --total_batch_size=1024 \ + --num_iterations=100 +# eval results will be terrible, this is just to execute the code paths. +# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems +python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 + +# SFT +python -m scripts.chat_sft \ + --device_batch_size=1 \ + --target_examples_per_step=4 \ + --num_iterations=100 \ + --eval_steps=4 \ + --eval_metrics_max_problems=16 + +# Chat CLI +# python -m scripts.chat_cli -p "Why is the sky blue?" + +# Chat Web +# python -m scripts.chat_web + +python -m nanochat.report generate diff --git a/nanochat/adamw.py b/nanochat/adamw.py index 07b82de..db591de 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer): grad_slices = [] for group in self.param_groups: params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly for base_i in range(len(params)): grad = params[base_i].grad rank_size = grad.shape[0] // world_size diff --git a/nanochat/common.py b/nanochat/common.py index d80d4ba..13d957a 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -5,6 +5,8 @@ Common utilities for nanochat. import os import re import logging +import fcntl +import urllib.request import torch import torch.distributed as dist @@ -56,6 +58,44 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir +def download_file_with_lock(url, filename): + """ + Downloads a file from a URL to a local path in the base directory. + Uses a lock file to prevent concurrent downloads among multiple ranks. 
+ """ + base_dir = get_base_dir() + file_path = os.path.join(base_dir, filename) + lock_path = file_path + ".lock" + + if os.path.exists(file_path): + return file_path + + with open(lock_path, 'w') as lock_file: + + # Only a single rank can acquire this lock + # All other ranks block until it is released + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + + if os.path.exists(file_path): + return file_path + + print(f"Downloading {url}...") + with urllib.request.urlopen(url) as response: + content = response.read().decode('utf-8') + + with open(file_path, 'w') as f: + f.write(content) + + print(f"Downloaded to {file_path}") + + # Clean up the lock file after the lock is released + try: + os.remove(lock_path) + except OSError: + pass # Ignore if already removed by another process + + return file_path + def print0(s="",**kwargs): ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: @@ -89,32 +129,46 @@ def get_dist_info(): else: return False, 0, 0, 1 -def compute_init(): +def autodetect_device_type(): + # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU + if torch.cuda.is_available(): + device_type = "cuda" + elif torch.backends.mps.is_available(): + device_type = "mps" + else: + device_type = "cpu" + print0(f"Autodetected device type: {device_type}") + return device_type + +def compute_init(device_type="cuda"): # cuda|cpu|mps """Basic initialization that we keep doing over and over, so make common.""" - # CUDA is currently required - assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm" + assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" + if device_type == "cuda": + assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" + if device_type == "mps": + assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" # Reproducibility torch.manual_seed(42) - torch.cuda.manual_seed(42) + if device_type == "cuda": + torch.cuda.manual_seed(42) # skipping full reproducibility for now, possibly investigate slowdown later # torch.use_deterministic_algorithms(True) - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = False # Precision - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + if device_type == "cuda": + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls - # Distributed setup: Distributed Data Parallel (DDP), optional + # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - if ddp: + if ddp and device_type == "cuda": device = torch.device("cuda", ddp_local_rank) torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() else: - device = torch.device("cuda") + device = torch.device(device_type) # mps|cpu if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index c1636b1..6c864d3 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -6,7 +6,7 @@ from nanochat.common import get_dist_info from nanochat.dataset import parquets_iter_batched from nanochat.tokenizer import get_tokenizer -def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128): +def tokenizing_distributed_data_loader(B, T, split, 
tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): """Stream pretraining text from parquet files, tokenize, yield training batches.""" assert split in ["train", "val"], "split must be 'train' or 'val'" ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() @@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) # infinite iterator over document batches def document_batches(): @@ -38,12 +37,13 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz token_buffer.extend(tokens) batch_index += 1 # Move tokens from the deque into the scratch buffer - for i in range(needed_tokens): - scratch[i] = token_buffer.popleft() + tokens = [token_buffer.popleft() for _ in range(needed_tokens)] + # CUDA supports memory pinning for faster transfers between CPU and GPU: + scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda")) # Create the inputs/targets as 1D tensors inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True) + inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) + targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) yield inputs, targets diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..44ed16b 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3): return None def use_calculator(expr): - """Evaluate a math expression safely.""" + """ + Evaluate a Python expression safely. 
+ Supports both math expressions and string operations like .count() + """ + # Remove commas from numbers expr = expr.replace(",", "") - if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars + + # Check if it's a pure math expression (old behavior) + if all([x in "0123456789*+-/.() " for x in expr]): + if "**" in expr: # disallow power operator + return None + return eval_with_timeout(expr) + + # Check if it's a string operation we support + # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens + allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ " + if not all([x in allowed_chars for x in expr]): return None - if "**" in expr: # for now disallow power operator, could be very expensive + + # Disallow dangerous patterns + dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file', + 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir', + 'getattr', 'setattr', 'delattr', 'hasattr'] + expr_lower = expr.lower() + if any(pattern in expr_lower for pattern in dangerous_patterns): return None + + # Only allow .count() method for now (can expand later) + if '.count(' not in expr: + return None + + # Evaluate with timeout return eval_with_timeout(expr) # ----------------------------------------------------------------------------- @@ -109,9 +135,11 @@ class KVCache: if t1 > self.kv_cache.size(4): t_needed = t1 + 1024 # as much as we need plus buffer of 1024 t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 - current_shape = list(self.kv_cache.shape) - current_shape[4] = t_needed - self.kv_cache.resize_(current_shape) + additional_shape = list(self.kv_cache.shape) + additional_shape[4] = t_needed - self.kv_cache.size(4) + additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) + self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() + self.kv_shape = self.kv_cache.shape # Insert k, v into the cache self.kv_cache[layer_idx, 0, :, :, t0:t1] = k self.kv_cache[layer_idx, 1, :, :, t0:t1] = v diff --git a/nanochat/execution.py b/nanochat/execution.py index cda179d..d5ce388 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): with caution. """ - if maximum_memory_bytes is not None: + if platform.uname().system != "Darwin": + # These resource limit calls seem to fail on macOS (Darwin), skip? import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) - if not platform.uname().system == "Darwin": - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) faulthandler.disable() @@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in rmtree = shutil.rmtree rmdir = os.rmdir chdir = os.chdir + unlink = os.unlink # Disable functionalities that can make destructive changes to the test. 
reliability_guard(maximum_memory_bytes=maximum_memory_bytes) @@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in shutil.rmtree = rmtree os.rmdir = rmdir os.chdir = chdir + os.unlink = unlink def execute_code( diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5a066b2..b640f1e 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin): out = out.to(x.dtype) # ensure input/output dtypes match return out - -def repeat_kv(x, n_rep): - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - if n_rep == 1: - return x - bs, n_kv_heads, slen, head_dim = x.shape - return ( - x[:, :, None, :, :] - .expand(bs, n_kv_heads, n_rep, slen, head_dim) - .reshape(bs, n_kv_heads * n_rep, slen, head_dim) - ) - - class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module): Tq = q.size(2) # number of queries in this forward pass Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) - # Apply MQA: replicate the key/value heads for each query head - nrep = self.n_head // self.n_kv_head - k, v = repeat_kv(k, nrep), repeat_kv(v, nrep) - # Attention: queries attend to keys/values autoregressively. A few cases to handle: + enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired if kv_cache is None or Tq == Tk: # During training (no KV cache), attend as usual with causal attention # And even if there is KV cache, we can still use this simple version when Tq == Tk - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) elif Tq == 1: # During inference but with a single query in this forward pass: # The query has to attend to all the keys/values in the cache - y = F.scaled_dot_product_attention(q, k, v, is_causal=False) + y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) else: # During inference AND we have a chunk of queries in this forward pass: # First, each query attends to all the cached keys/values (i.e. 
full prefix) @@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module): attn_mask[:, :prefix_len] = True # Then, causal attention within this chunk attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) # Re-assemble the heads side by side and project back to residual stream y = y.transpose(1, 2).contiguous().view(B, T, -1) @@ -169,8 +153,6 @@ class GPT(nn.Module): cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) - # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations - self.transformer.wte.to(dtype=torch.bfloat16) def init_weights(self): self.apply(self._init_weights) @@ -184,6 +166,9 @@ class GPT(nn.Module): head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin + # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations + if self.transformer.wte.weight.device.type == "cuda": + self.transformer.wte.to(dtype=torch.bfloat16) def _init_weights(self, module): if isinstance(module, nn.Linear): diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index d103ef6..0100ec3 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes): loss2d = model(x, y, loss_reduction='none') # (B, T) loss2d = loss2d.view(-1) # flatten y = y.view(-1) # flatten - if (y < 0).any(): + if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 # slightly more complex code path if some target tokens are ignore_index (e.g. -1) # any target token < 0 is to be ignored: do NOT index token_bytes with negatives valid = y >= 0 diff --git a/nanochat/report.py b/nanochat/report.py index 02cd8b0..d0a65e0 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -283,6 +283,10 @@ class Report: # capture bloat data for summary later (the stuff after Bloat header and until \n\n) bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL) bloat_data = bloat_data.group(1) if bloat_data else "" + else: + start_time = None # will cause us to not write the total wall clock time + bloat_data = "[bloat data missing]" + print(f"Warning: {header_file} does not exist. 
Did you forget to run `nanochat reset`?") # process all the individual sections for file_name in EXPECTED_FILES: section_file = os.path.join(report_dir, file_name) diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 68cd436..880f854 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -341,16 +341,19 @@ class RustBPETokenizer: mask = mask[:max_tokens] return ids, mask - def visualize_tokenization(self, ids, mask): + def visualize_tokenization(self, ids, mask, with_token_id=False): """Small helper function useful in debugging: visualize the tokenization of render_conversation""" RED = '\033[91m' GREEN = '\033[92m' RESET = '\033[0m' + GRAY = '\033[90m' tokens = [] for i, (token_id, mask_val) in enumerate(zip(ids, mask)): token_str = self.decode([token_id]) color = GREEN if mask_val == 1 else RED tokens.append(f"{color}{token_str}{RESET}") + if with_token_id: + tokens.append(f"{GRAY}({token_id}){RESET}") return '|'.join(tokens) def render_for_completion(self, conversation): diff --git a/nanochat/ui.html b/nanochat/ui.html index 39e608f..0f625d9 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -2,7 +2,7 @@ - + NanoChat
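
A note on the `repeat_kv` removal in the `nanochat/gpt.py` hunk above: the manual key/value head replication is replaced by the `enable_gqa` flag of `F.scaled_dot_product_attention`. The sketch below is a minimal, self-contained check of that equivalence, not code from the repo; the tensor shapes and tolerance are made up, and it assumes a PyTorch build recent enough to support `enable_gqa` (2.5+).

```python
import torch
import torch.nn.functional as F

def repeat_kv(x, n_rep):
    # The helper removed in the diff: replicate each KV head n_rep times
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )

B, n_head, n_kv_head, T, head_dim = 2, 8, 2, 16, 32  # toy sizes
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

n_rep = n_head // n_kv_head
# Old path: materialize replicated K/V, then ordinary causal attention
y_old = F.scaled_dot_product_attention(q, repeat_kv(k, n_rep), repeat_kv(v, n_rep), is_causal=True)
# New path: let the kernel broadcast the KV heads internally
y_new = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

print(torch.allclose(y_old, y_new, atol=1e-5))  # expected: True (up to numerical tolerance)
```

The fused path avoids materializing the replicated K/V tensors, which is most useful when the KV cache is large at inference time.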
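Similarly, the `KVCache` hunk in `nanochat/engine.py` swaps `resize_` for growing the cache by concatenating a freshly allocated chunk along the time axis. Here is a standalone sketch of that growth rule; the cache dimensions below are placeholders for illustration, not the repo's defaults.

```python
import torch

def grow_kv_cache(kv_cache: torch.Tensor, t1: int) -> torch.Tensor:
    """Grow the time axis (dim=4) so that position t1 fits, in multiples of 1024."""
    if t1 <= kv_cache.size(4):
        return kv_cache  # already large enough
    t_needed = t1 + 1024                  # as much as we need plus a buffer of 1024
    t_needed = (t_needed + 1023) & ~1023  # round up to the nearest multiple of 1024
    extra_shape = list(kv_cache.shape)
    extra_shape[4] = t_needed - kv_cache.size(4)
    extra = torch.empty(extra_shape, dtype=kv_cache.dtype, device=kv_cache.device)
    # Concatenation allocates a new contiguous buffer instead of resizing in place
    return torch.cat([kv_cache, extra], dim=4).contiguous()

# toy cache of shape (num_layers, 2, batch, n_kv_head, time, head_dim)
cache = torch.zeros(2, 2, 1, 2, 1024, 8)
cache = grow_kv_cache(cache, t1=1500)
print(cache.shape)  # torch.Size([2, 2, 1, 2, 3072, 8])
```

The likely motivation for the change: `resize_` keeps the flat memory but reinterprets it under the new shape, so previously cached positions no longer line up with the slicing indices, whereas concatenation preserves the existing entries and only appends new space.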
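Finally, the README hunk near the top notes that when you lower `--device_batch_size`, the scripts compensate with gradient accumulation, "turning parallel compute to sequential compute". The toy loop below illustrates that idea in isolation; the `nn.Linear` stand-in, batch sizes, and optimizer are all made up for the example and this is not the repo's actual training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

total_batch_size = 32       # the batch we want per optimizer step
device_batch_size = 8       # what actually fits in memory at once
grad_accum_steps = total_batch_size // device_batch_size  # 4 sequential micro-batches

model = nn.Linear(16, 1)    # stand-in for the real model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(total_batch_size, 16)
y = torch.randn(total_batch_size, 1)

opt.zero_grad()
for i in range(grad_accum_steps):
    xb = x[i * device_batch_size:(i + 1) * device_batch_size]
    yb = y[i * device_batch_size:(i + 1) * device_batch_size]
    loss = F.mse_loss(model(xb), yb)
    # Scale so the accumulated gradient matches one full 32-sample batch
    (loss / grad_accum_steps).backward()
opt.step()  # one optimizer step with the same effective batch size, done sequentially
```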