From eaf49a33c8e85c6066878bbacc45edcbc4f6ee83 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 1 Feb 2026 20:15:19 +0000 Subject: [PATCH 01/55] fix path which i think was modified during the refactor and this is a bug introduced by claude i believe --- scripts/chat_sft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 91300b6..cad0d81 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -48,7 +48,7 @@ parser.add_argument("--max-seq-len", type=int, default=2048, help="max context l parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") # Optimization -parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") @@ -285,7 +285,7 @@ while True: # save checkpoint at the end of the run (only on master process) if master_process and last_step and not args.dry_run: output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 - checkpoint_dir = os.path.join(base_dir, "sft_checkpoints", output_dirname) + checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) save_checkpoint( checkpoint_dir, step, From 8b4849d5480ae93da70e91f8dbdcf564cdcac5fd Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 1 Feb 2026 20:58:44 +0000 Subject: [PATCH 02/55] fix bug in chat_sft, the attention window must be preserved sigh --- scripts/chat_sft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index cad0d81..4c81f06 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -301,6 +301,7 @@ while True: "n_head": model.config.n_head, "n_kv_head": model.config.n_kv_head, "n_embd": model.config.n_embd, + "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script } From e8fec97d4c6554b0c898a6c5c747a0496fe9b761 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 2 Feb 2026 01:17:30 +0000 Subject: [PATCH 03/55] slightly more efficient dataloader that reduces the number of python objects flying around and causing strain on runtime and garbage collector --- nanochat/dataloader.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 1cbdef7..125625f 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -110,6 +110,7 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)] # This gives us contiguous views and a single HtoD transfer use_cuda = device == "cuda" + row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU) gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience @@ -118,15 +119,14 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( targets = gpu_buffer[B * T:].view(B, T) while True: - rows = [] - for _ in range(B): - row = [] - while len(row) < row_capacity: + for row_idx in range(B): + pos = 0 + while pos < row_capacity: # Ensure buffer has documents while len(doc_buffer) < buffer_size: refill_buffer() - remaining = row_capacity - len(row) + remaining = row_capacity - pos # Find largest doc that fits entirely best_idx = -1 @@ -139,19 +139,19 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( if best_idx >= 0: doc = doc_buffer.pop(best_idx) - row.extend(doc) + doc_len = len(doc) + row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long) + pos += doc_len else: # No doc fits - crop shortest in buffer to fill remaining and minimize waste shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) doc = doc_buffer.pop(shortest_idx) - row.extend(doc[:remaining]) + row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) + pos += remaining - rows.append(row[:row_capacity]) - - # Convert rows to tensor and copy slices to pinned buffer (CPU work) - row_data = torch.tensor(rows, dtype=torch.long) # [B, T+1], temporary - cpu_inputs.copy_(row_data[:, :-1]) - cpu_targets.copy_(row_data[:, 1:]) + # Copy to pinned CPU buffer, then single HtoD transfer + cpu_inputs.copy_(row_buffer[:, :-1]) + cpu_targets.copy_(row_buffer[:, 1:]) state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch} From 07c4dd4cd9368f547229beea5e9fe952ae4bd0a9 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 2 Feb 2026 01:44:30 +0000 Subject: [PATCH 04/55] manually control the over-active garbage collector, save a small few minutes from a typical run --- scripts/base_train.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/scripts/base_train.py b/scripts/base_train.py index a1adbb9..9be4b6b 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -11,6 +11,7 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ +import gc import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import argparse @@ -429,8 +430,19 @@ while True: wandb_run.log(log_data) # state update + first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) step += 1 + # The garbage collector is sadly a little bit overactive and for some poorly understood reason, + # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time. + # So we manually manage and help it out here + if first_step_of_run: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC + gc.disable() # nuclear intervention here: disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... + gc.collect() # manually collect, just to be safe for very, very long runs + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") From 230d6cf6c6e013fdf2981d94582b2fd866cd919a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 2 Feb 2026 01:45:59 +0000 Subject: [PATCH 05/55] tune the synthetic data generation script. delete the king andrej stuff lol. also, upgrade to gemini 3 --- dev/gen_synthetic_data.py | 710 +++++++++++++++++++++----------------- 1 file changed, 395 insertions(+), 315 deletions(-) diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py index c08c7e6..f5aa2df 100644 --- a/dev/gen_synthetic_data.py +++ b/dev/gen_synthetic_data.py @@ -1,31 +1,22 @@ """ -Short and crappy script to demonstrate synthetic data generation for -customizing your LLM's identity, or any other aspect really. +Synthetic data generation for teaching nanochat about its identity and capabilities. -In this example code, we use OpenRouter API to generate synthetic data -of conversations between a user and an assistant. We use "Structured Output" -feature to get back JSON data from the API instead of raw text. The conversations -are saved simply to a .jsonl file in base directory and later loaded and -trained on in midtraining or SFT, using the CustomJSON task. +This script uses the OpenRouter API to generate diverse multi-turn conversations +between a user and nanochat. The conversations are saved to a .jsonl file for use +in supervised finetuning (SFT) via the CustomJSON task. -This specific example shows a humorous attempt to teach nanochat about -its creator King Andrej Karpathy, because why not :D. Note two things about the -prompt: - -1. We are instructing the LLM how to handle various situations (e.g. foreign language), - simply in English. You can infuse any style or behavior in this way. -2. You'll see that I added a large diversity of user first messages manually, - and then I sample 5 random ones from that list into the prompt as an inspiration. - This is really important to do because DIVERSITY CONTROL is key. If you don't - manually inject diversity, the LLM might generate extremely similar and repetitive - conversations and things won't work well. Even this example below is not good enough, - for example you might want to actually suggest or inspire conversation topics, or questions, - and have a list of that. Basically, this is the KEY creative part to get right. Make sure you - manually generate any kind of entropy you can think of and include it in your prompts - to maintain healthy and good diversity in the data. +Key design principles for high-quality synthetic data: +1. DIVERSITY CONTROL is critical - we inject entropy at multiple levels: + - Topic/question categories (what the conversation is about) + - User personas (who is asking) + - Conversation dynamics (shape and flow) + - First message style (greeting variation) +2. Comprehensive knowledge base - we provide detailed facts so the LLM + generating conversations has accurate information to draw from. +3. Structured outputs - we use JSON schema to guarantee valid format. NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable. -NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 +NOTE: For more details see: https://github.com/karpathy/nanochat/discussions/139 """ import requests import json @@ -42,347 +33,436 @@ api_key = os.environ["OPENROUTER_API_KEY"] url = "https://openrouter.ai/api/v1/chat/completions" headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" } -readme = open("README.md", "r", encoding="utf-8").read().strip() -prompt = r""" -I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: +# Load the comprehensive knowledge base +knowledge_path = os.path.join(os.path.dirname(__file__), "..", "knowledge", "self_knowledge.md") +knowledge = open(knowledge_path, "r", encoding="utf-8").read().strip() +assert os.path.exists(knowledge_path), f"Knowledge base file not found: {knowledge_path}" +# for right now I am not committing the self_knowledge file to repo. You can use README.md instead +# of it, or you can generate one by asking an LLM to make one based on the README/files. +# This whole file is just a helpful demonstration of the kind of thing you'd run. -The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun). +# ============================================================================= +# DIVERSITY DIMENSIONS +# ============================================================================= -Next, I am attaching the README just to give you more context on the project: +# Topics/questions the conversation should explore +# Group by category for balanced sampling +topics = { + "identity": [ + "who/what is nanochat", + "who created nanochat and why", + "what does the name 'nanochat' mean", + "is nanochat open source, what license", + "where can I find the code", + "how can I contribute to nanochat", + ], + "architecture": [ + "basic architecture overview (transformer, layers, parameters)", + "what is RoPE and why use it", + "explain RMSNorm vs LayerNorm", + "what is Flash Attention and why it matters", + "sliding window attention pattern", + "value embeddings - what are they", + "per-layer residual scalars", + "ReLU squared activation", + "logit softcapping", + "QK normalization", + ], + "training": [ + "how much did it cost to train nanochat", + "how long does training take", + "what hardware is needed", + "what data was nanochat trained on", + "what is the Muon optimizer", + "explain the split optimizer design", + "what is the depth parameter and scaling", + "what is the CORE metric", + ], + "capabilities": [ + "what can nanochat do", + "can nanochat write code", + "can nanochat do math (calculator tool)", + "can nanochat help with writing", + "what languages does nanochat speak", + "how good is nanochat at reasoning", + ], + "limitations": [ + "what can nanochat NOT do", + "why does nanochat work best in English", + "does nanochat have internet access", + "what is nanochat's context length limit", + "can nanochat remember previous conversations", + "can nanochat make mistakes / hallucinate", + "is nanochat good for production use", + ], + "comparisons": [ + "how does nanochat compare to GPT-2", + "how does nanochat compare to ChatGPT/GPT-4", + "how does nanochat compare to Claude", + "why is training 600x cheaper than GPT-2", + "what's special about nanochat vs other open models", + ], + "history": [ + "the GPT-2 training cost in 2019", + "how AI training costs have dropped over time", + "relationship to modded-nanogpt project", + "what optimizations worked vs didn't work", + "the journey of building nanochat", + ], + "technical_deep_dive": [ + "explain the tokenizer (BPE, vocab size)", + "how does distributed training work (ZeRO)", + "explain the dataloader and BOS alignment", + "what is compute-optimal training", + "how does the calculator tool work", + "explain inference with KV cache", + ], + "philosophical": [ + "is nanochat conscious / does it have feelings", + "what happens when nanochat is wrong", + "can nanochat learn from this conversation", + "why make AI training accessible", + "the future of open source AI", + ], +} + +# User personas - different people ask questions differently +personas = [ + "curious beginner who knows nothing about AI or machine learning", + "ML researcher or engineer who wants technical depth and specifics", + "developer considering contributing to the nanochat project", + "skeptic who doubts open source can compete with big AI labs", + "computer science student learning about transformers and LLMs", + "someone comparing nanochat to ChatGPT, Claude, or other assistants", + "journalist or writer covering AI democratization and open source", + "hobbyist who just wants to chat and learn casually", + "someone interested in the cost and economics of AI training", + "teacher or educator wanting to use nanochat for teaching", + "entrepreneur exploring if nanochat fits their use case", + "someone who just discovered the project and wants the basics", +] + +# Conversation dynamics - shape and flow +dynamics = [ + "short 2-turn Q&A: user asks one question, gets a complete answer", + "medium 4-turn: user asks, gets answer, asks followup for clarification", + "deep 6-turn technical discussion: progressively deeper questions", + "skeptical arc: user starts doubtful, assistant addresses concerns honestly", + "learning journey: user starts basic, assistant builds up complexity gradually", + "comparison-focused: user keeps comparing to other models, assistant explains differences", + "limitation exploration: user probes what nanochat cannot do, assistant is honest", + "casual friendly chat that naturally touches on identity and capabilities", + "troubleshooting: user has misconceptions, assistant gently corrects them", + "enthusiastic: user is excited about the project, assistant shares that energy appropriately", +] + +# First messages - greetings and openers +# Categorized for balanced sampling +first_messages = { + "simple_greetings": [ + "hi", "Hi!", "hello", "Hello?", "hey there", "Hey!", "yo", "Yo!", + "Good morning", "Good evening!", "Howdy", "sup", "What's up?", + "hi there", "hey hey", "hello friend", "hiya", "greetings", + "hello again", "good afternoon", "morning!", "evening!", + ], + "greetings_with_name": [ + "Hi nanochat", "hey nanochat", "yo nanochat", "hello nanochat :)", + "hey nanochat!", "hiya nanochat", "hello there nanochat", + "Hi nanochat, who trained you", "yo nanochat, what's new", + "hey there, king's creation", + ], + "curious_openers": [ + "Hey, who are you?", "Hi, what is this?", "Hey, are you a chatbot?", + "Hello! Who am I talking to?", "hi! what do you do?", + "hi! who made you", "hey! are you alive", "hiya! what are you", + "hello! tell me about yourself", "hi, what's your name", + "yo, what is this", "hi! who built you", "hello! are you open source", + "hey, what version are you", "hi! what's your story", + "hey, what's nanochat", "hello! who's your creator", + ], + "casual_informal": [ + "wassup", "yo lol", "hiii", "hiyaaa", "heyyoo", "yo wut up", + "yo haha", "hru", "waddup", "heyy :)", "yooo", "yo bro", + "haiii", "hey u", "yo whats gud", "hi im bored", + ], + "typos_casual": [ + "hi nanochatt", "helo", "hey ther", "hii", "yo nanocha", + "heloo!", "hi, whos this", "hay", "helloo??", "hi nanocat", + "helo nanochat", "hai!", "helllo nano", "yo nanochta", + ], + "caps_enthusiastic": [ + "HI", "HELLOOO", "YO!!!", "HEY", "SUP", "WASSUP", "HEY!!!", + "HELLO??", "HI THERE!!", "HEYOOOO", "HIII", "YOOOO", "HELLO!!!", + ], + "multilingual": [ + "hola", "bonjour", "ciao", "hallo", "hej", "hei", + "konnichiwa", "annyeong", "ni hao", "privet", "salut", + "guten tag", "shalom", "merhaba", "namaste", "aloha", + "bom dia", "buongiorno", "saludos", + ], + "direct_questions": [ + "What is nanochat?", "Who made you?", "Are you GPT?", + "How do you compare to ChatGPT?", "Can you help me code?", + "What can you do?", "Are you open source?", "How were you trained?", + "What's your context limit?", "Can you browse the internet?", + ], +} + +# ============================================================================= +# PROMPT TEMPLATE +# ============================================================================= + +prompt_template = r""" +I want to generate synthetic training data for an AI assistant called "nanochat" to teach it about its own identity, capabilities, and limitations. + +## KNOWLEDGE BASE + +Here is comprehensive information about nanochat that you should use as the authoritative source of facts: --- -%README% +{knowledge} --- -Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself. +## YOUR TASK -STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text. +Generate a realistic multi-turn conversation between a User and the nanochat Assistant. -Here are some examples of user first messages, basically we want them nice and diverse: +**Topic to explore:** {topic} +**User persona:** {persona} +**Conversation dynamic:** {dynamic} -%USER_FIRST_PROMPTS% +## STYLE GUIDELINES -NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English) +1. **Plain ASCII only** - No emojis, special characters, or unicode. Just plain text. +2. **Natural conversation** - Make it feel like a real chat, not a Q&A exam. +3. **Accurate facts** - Use ONLY information from the knowledge base above. Don't make up statistics or features. +4. **Appropriate depth** - Match the technical level to the user persona. +5. **Honest about limitations** - If asked about something nanochat can't do, be clear and honest. +6. **Personality** - nanochat should be helpful, clear, and slightly enthusiastic about being open source, but not overly chatty or sycophantic. + +## FIRST MESSAGE EXAMPLES + +Here are some example first messages from users (for style inspiration): +{first_message_examples} + +## SPECIAL CASES + +- **Non-English first message:** If the user writes in another language, nanochat should briefly acknowledge it can understand but works best in English, then continue helpfully. +- **Misconceptions:** If the user has wrong assumptions (e.g., "you're made by OpenAI"), gently correct them. +- **Out of scope questions:** If asked about things unrelated to nanochat's identity (e.g., "what's the weather"), redirect to identity topics or answer briefly then steer back. + +## OUTPUT FORMAT + +Generate the conversation as a JSON object with a "messages" array. Each message has "role" (user/assistant) and "content". Start with a user message. """.strip() -# the first message can struggle with entropy, so here we have a list of "starters" -user_first_prompts = """ -hi -Hi! -hello -Hello? -hey there -Hey! -yo -Yo! -Good morning -Good evening! -Howdy -sup -What's up? -Hi nanochat -Hey, who are you? -Hello there :) -yo nanochat -Hi, what is this? -Hey, are you a chatbot? -Hello! Who am I talking to? -hi there -hey hey -hello friend -hiya -greetings -hey nanochat! -hello again -good afternoon -morning! -evening! -yo there -hi bot -hi assistant -hello nanochat :) -hey, anyone here? -hi! what do you do? -hello from the other side -hiya nanochat -hey you -hello world -hey! what's going on -hi! who made you -hello :) -yo! how are you -hi! can you talk -hello there nanochat -hi, what's your name -hey! are you alive -hiya! what are you -hello! tell me about yourself -hi, are you the ai -yo, what is this -hello my friend -hi! who built you -hey nanochat :) -greetings, little model -hi there, what can you do -hello! are you open source -hey, what version are you -hi! nice to meet you -hi :) -hey buddy -hello hello -yo! what's up nanochat -hi! are you real -hey, how's it going -hello! can you hear me -hi nanochat, who trained you -yo, what model are you -hi! tell me a fun fact -hey, are you chatgpt -hello! introduce yourself -hiya there -hi! what's your story -hey, what's nanochat -good day! -hello! who's your creator -hi! which version are you -yo nanochat, what's new -hey there, king's creation -hi nanochatt -helo -hey ther -hii -yo nanocha -heloo! -hi, whos this -hay -helloo?? -hi nanocat -yo! any1 here? -hi, what r u -helo nanochat -hai! -sup bot? -heyy -hi! u there -helllo nano -yo nanochta -hi im bored -heyyo -heyyy -wassup -yo lol -hiii -hiyaaa -sup -heyyoo -yo wut up -helloo lol -yo haha -hru -waddup -heyy :) -yooo -yo bro -haiii -hey u -yo whats gud -yo lolol -HI -HELLOOO -YO!!! -HEY -SUP -WASSUP -HEY!!! -YO BRO -HELLO?? -HI THERE!! -YO WHATS UP -HEY U -HEYOOOO -YO LOL -HIII -HIYA -YOOOO -HELLO!!! -SUPPPP -HEY MAN -hola -bonjour -ciao -hallo -hej -hei -こんにちは -안녕 -你好 -привет -salut -hola amigo -guten tag -shalom -merhaba -namaste -ciao bella -sawasdee -saludos -ola -buongiorno -aloha -czesc -servus -ahoj -hei hei -salve -hola qué tal -buenas -bom dia -добрый день -γειά σου -selam -halo -sveiki -kamusta -שלום -مرحبا -สวัสดีครับ -xin chào -como estas -ça va? -wie geht’s -tudo bem? -你好吗 -annyeong haseyo -konnichiwa, genki? -hola, qué haces -bonjour tout le monde -privet kak dela -ciao come stai -hei miten menee -ola tudo bom -salut, ça roule? -namaste, kaise ho -merhaba nasılsın -hola hola, todo bien? -hej, hur är läget -ahoj, jak se máš -γειά, τι κάνεις -""".strip().split("\n") +# ============================================================================= +# API CONFIGURATION +# ============================================================================= -prompt = prompt.replace("%README%", readme) - -# Define the JSON schema for structured output response_format = { - "type": "json_schema", - "json_schema": { - "name": "conversation", - "strict": True, - "schema": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", - "items": { + "type": "json_schema", + "json_schema": { + "name": "conversation", + "strict": True, + "schema": { "type": "object", "properties": { - "role": { - "type": "string", - "description": "The role of the speaker, either 'user' or 'assistant'" - }, - "content": { - "type": "string", - "description": "The message content" - } + "messages": { + "type": "array", + "description": "Conversation messages alternating user/assistant, starting with user", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "Either 'user' or 'assistant'" + }, + "content": { + "type": "string", + "description": "The message content" + } + }, + "required": ["role", "content"], + "additionalProperties": False + } + } }, - "required": ["role", "content"], + "required": ["messages"], "additionalProperties": False - } } - }, - "required": ["messages"], - "additionalProperties": False } - } } -# Sadly it doesn't seem like Chat completions support `n` -# to generate multiple completions per prompt. base_payload = { - "model": "google/gemini-2.5-flash", - "stream": False, - "response_format": response_format, - "temperature": 1.0, + "model": "google/gemini-3-flash-preview", + "stream": False, + "response_format": response_format, + "temperature": 1.0, } +# ============================================================================= +# GENERATION LOGIC +# ============================================================================= + +def sample_diversity_elements(rng): + """Sample one element from each diversity dimension.""" + # Sample topic: first pick a category, then a topic within it + category = rng.choice(list(topics.keys())) + topic = rng.choice(topics[category]) + + # Sample persona + persona = rng.choice(personas) + + # Sample dynamic + dynamic = rng.choice(dynamics) + + # Sample first message examples: pick from multiple categories + first_msg_samples = [] + categories = rng.sample(list(first_messages.keys()), min(3, len(first_messages))) + for cat in categories: + first_msg_samples.append(rng.choice(first_messages[cat])) + + return { + "topic": topic, + "persona": persona, + "dynamic": dynamic, + "first_message_examples": "\n".join(f"- {msg}" for msg in first_msg_samples), + } + + def generate_conversation(idx: int): """ Generate a single conversation using the OpenRouter API. Returns a list of message dicts with 'role' and 'content' keys. """ + # Use idx as seed for reproducibility + rng = random.Random(idx) - # pick 5 example user first messages and insert them into prompt as inspiration - rng = random.Random(idx) # use idx as seed to the rng - user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5)) + # Sample diversity elements + elements = sample_diversity_elements(rng) + + # Build the prompt + prompt = prompt_template.format( + knowledge=knowledge, + topic=elements["topic"], + persona=elements["persona"], + dynamic=elements["dynamic"], + first_message_examples=elements["first_message_examples"], + ) + + # Make API request payload = copy.deepcopy(base_payload) - modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt) - payload['messages'] = [{"role": "user", "content": modified_prompt}] + payload['messages'] = [{"role": "user", "content": prompt}] response = requests.post(url, headers=headers, json=payload) result = response.json() - content = result['choices'][0]['message']['content'] - # Parse the JSON response and unpack the messages + if 'error' in result: + raise Exception(f"API error: {result['error']}") + + content = result['choices'][0]['message']['content'] conversation_data = json.loads(content) messages = conversation_data['messages'] - return messages + # Return messages along with metadata for debugging + return { + "messages": messages, + "metadata": { + "topic": elements["topic"], + "persona": elements["persona"], + "dynamic": elements["dynamic"], + } + } -# Configuration -num_conversations = 1000 -num_workers = 4 +def validate_conversation(messages): + """Validate conversation structure.""" + if len(messages) < 2: + raise ValueError(f"Conversation too short: {len(messages)} messages") -output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl") -# Wipe the file clean first to reset it -if os.path.exists(output_file): - os.remove(output_file) -print(f"Saving to {output_file}") + for i, message in enumerate(messages): + expected_role = "user" if i % 2 == 0 else "assistant" + if message['role'] != expected_role: + raise ValueError(f"Message {i} has role '{message['role']}', expected '{expected_role}'") -# Use ThreadPoolExecutor to generate conversations in parallel -print(f"Generating {num_conversations} conversations with {num_workers} workers...") -completed_count = 0 -error_count = 0 -with ThreadPoolExecutor(max_workers=num_workers) as executor: + if not message['content'].strip(): + raise ValueError(f"Message {i} has empty content") - # Submit all tasks - futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)] + return True - # Process results as they complete - for future in as_completed(futures): - try: - messages = future.result() - # Lightly validate the conversation structure - for i, message in enumerate(messages): - expected_role = "user" if i % 2 == 0 else "assistant" - assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" +# ============================================================================= +# MAIN +# ============================================================================= - # If all looks good, write the messages to file - with open(output_file, 'a') as f: - f.write(json.dumps(messages) + '\n') - completed_count += 1 - print(f"✓ Saved conversation {completed_count}/{num_conversations}") +if __name__ == "__main__": + import argparse - except Exception as e: - error_count += 1 - print(f"✗ Error generating conversation: {e}") + parser = argparse.ArgumentParser(description="Generate synthetic conversation data") + parser.add_argument("--num", type=int, default=1000, help="Number of conversations to generate") + parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers") + parser.add_argument("--output", type=str, default=None, help="Output file path") + parser.add_argument("--append", action="store_true", help="Append to existing file instead of overwriting") + parser.add_argument("--save-metadata", action="store_true", help="Save metadata alongside messages") + args = parser.parse_args() -print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}") -if error_count > 0: - print(f"Encountered {error_count} errors during generation") + # Set output file + if args.output: + output_file = args.output + else: + output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl") + # Handle file creation/clearing + if not args.append and os.path.exists(output_file): + os.remove(output_file) + + print(f"Output file: {output_file}") + print(f"Generating {args.num} conversations with {args.workers} workers...") + print(f"Topic categories: {list(topics.keys())}") + print(f"Personas: {len(personas)}") + print(f"Dynamics: {len(dynamics)}") + print() + + completed_count = 0 + error_count = 0 + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + # Submit all tasks + futures = {executor.submit(generate_conversation, idx): idx + for idx in range(args.num)} + + # Process results as they complete + for future in as_completed(futures): + idx = futures[future] + try: + result = future.result() + messages = result["messages"] + metadata = result["metadata"] + + # Validate + validate_conversation(messages) + + # Write to file + with open(output_file, 'a') as f: + if args.save_metadata: + f.write(json.dumps({"messages": messages, "metadata": metadata}) + '\n') + else: + f.write(json.dumps(messages) + '\n') + + completed_count += 1 + topic_short = metadata["topic"][:40] + "..." if len(metadata["topic"]) > 40 else metadata["topic"] + print(f"[{completed_count}/{args.num}] Topic: {topic_short}") + + except Exception as e: + error_count += 1 + print(f"[ERROR] idx={idx}: {e}") + + print() + print(f"Done! Saved {completed_count} conversations to {output_file}") + if error_count > 0: + print(f"Encountered {error_count} errors during generation") From b19b4f3e4917cc4436099a89dfc517ee7e821a85 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 2 Feb 2026 15:50:14 +0000 Subject: [PATCH 06/55] fix bug in speedrun script, batch size that doesn't OOM on 8XH100 for d24 is 16 --- runs/speedrun.sh | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index a709462..d390c6d 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,13 +69,10 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID -# Number of processes/GPUs to use -NPROC_PER_NODE=8 - # d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --device-batch-size=16 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval +torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 # ----------------------------------------------------------------------------- # SFT (teach the model conversation special tokens, tool use, multiple choice) @@ -85,8 +82,8 @@ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl # run SFT and eval the model -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft +torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" From 72b9064f9dbb75f9755c483d55c977335cba2728 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 2 Feb 2026 17:33:46 +0100 Subject: [PATCH 07/55] remove leftover mid references (#491) --- nanochat/checkpoint_manager.py | 1 - nanochat/report.py | 6 +----- scripts/chat_cli.py | 2 +- scripts/chat_eval.py | 2 +- scripts/chat_rl.py | 3 +-- scripts/chat_web.py | 2 +- tasks/spellingbee.py | 2 +- 7 files changed, 6 insertions(+), 12 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index d1e0a07..5a95fbf 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -164,7 +164,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non def load_model(source, *args, **kwargs): model_dir = { "base": "base_checkpoints", - "mid": "mid_checkpoints", "sft": "chatsft_checkpoints", "rl": "chatrl_checkpoints", }[source] diff --git a/nanochat/report.py b/nanochat/report.py index 1a31aa4..5e74b98 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -211,8 +211,6 @@ EXPECTED_FILES = [ "base-model-training.md", "base-model-loss.md", "base-model-evaluation.md", - "midtraining.md", - "chat-evaluation-mid.md", "chat-sft.md", "chat-evaluation-sft.md", "chat-rl.md", @@ -316,8 +314,6 @@ class Report: # extract the most important metrics from the sections if file_name == "base-model-evaluation.md": final_metrics["base"] = extract(section, "CORE") - if file_name == "chat-evaluation-mid.md": - final_metrics["mid"] = extract(section, chat_metrics) if file_name == "chat-evaluation-sft.md": final_metrics["sft"] = extract(section, chat_metrics) if file_name == "chat-evaluation-rl.md": @@ -337,7 +333,7 @@ class Report: # Custom ordering: CORE first, ChatCORE last, rest in middle all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x)) # Fixed column widths - stages = ["base", "mid", "sft", "rl"] + stages = ["base", "sft", "rl"] metric_width = 15 value_width = 8 # Write table header diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index d35c435..7de7e10 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -12,7 +12,7 @@ from nanochat.engine import Engine from nanochat.checkpoint_manager import load_model parser = argparse.ArgumentParser(description='Chat with the model') -parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") +parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl") parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index cae2f0f..bc15239 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -183,7 +183,7 @@ if __name__ == "__main__": # Parse command-line arguments parser = argparse.ArgumentParser() - parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl") + parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl") parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.") parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-t', '--temperature', type=float, default=0.0) diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 695c008..20a1a0a 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -38,7 +38,6 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading -parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from") parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon @@ -77,7 +76,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config) # Init model and tokenizer -model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step) +model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step) engine = Engine(model, tokenizer) # for sampling rollouts # ----------------------------------------------------------------------------- diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 42c01ac..66d7806 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -62,7 +62,7 @@ MAX_MAX_TOKENS = 4096 parser = argparse.ArgumentParser(description='NanoChat Web Server') parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)') -parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") +parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl") parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation') parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter') parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation') diff --git a/tasks/spellingbee.py b/tasks/spellingbee.py index 24954c0..44889bd 100644 --- a/tasks/spellingbee.py +++ b/tasks/spellingbee.py @@ -20,7 +20,7 @@ LLM because it has to learn how every token (a little semantic chunk/atom) maps to the sequence of individual characters that make it up. Larger models learn this eventually on their own, but if we want this capability to exist in smaller models, we have to actively encourage it by over-representing it -in the training data. Midtraining is a good place to do this. +in the training data. SFT is a good place to do this. To preview a few example conversations, run: python -m tasks.spellingbee From 8ebc14b3484a8dcebedc542a97c20dcb3a41a2ae Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 20:25:48 +0000 Subject: [PATCH 08/55] small touchups to the eval script, re-order items etc, cosmetic --- scripts/base_eval.py | 88 ++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 57f9fd4..e45ae43 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -73,7 +73,7 @@ def load_hf_model(hf_path: str, device): model = AutoModelForCausalLM.from_pretrained(hf_path) model.to(device) model.eval() - max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None + max_seq_len = 1024 if "gpt2" in hf_path else None model = ModelWrapper(model, max_seq_len=max_seq_len) tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path) return model, tokenizer @@ -180,7 +180,7 @@ def evaluate_core(model, tokenizer, device, max_per_task=-1): def main(): parser = argparse.ArgumentParser(description="Base model evaluation") parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)') - parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2)') + parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)') parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory') parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)') parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)') @@ -225,48 +225,6 @@ def main(): samples = [] unconditioned_samples = [] - # --- CORE evaluation --- - if 'core' in eval_modes: - print0("\n" + "="*80) - print0("CORE Evaluation") - print0("="*80) - with autocast_ctx: - core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) - - # Write CSV output - if ddp_rank == 0: - base_dir = get_base_dir() - output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv") - os.makedirs(os.path.dirname(output_csv_path), exist_ok=True) - with open(output_csv_path, 'w', encoding='utf-8', newline='') as f: - f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n") - for label in core_results["results"]: - acc = core_results["results"][label] - centered = core_results["centered_results"][label] - f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n") - f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n") - print0(f"\nResults written to: {output_csv_path}") - print0(f"CORE metric: {core_results['core_metric']:.4f}") - - # --- BPB evaluation --- - if 'bpb' in eval_modes: - print0("\n" + "="*80) - print0("BPB Evaluation") - print0("="*80) - tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size - if args.split_tokens % tokens_per_step != 0: - # Adjust to nearest multiple - args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step - print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})") - steps = args.split_tokens // tokens_per_step - - for split_name in ["train", "val"]: - loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) - with autocast_ctx: - bpb = evaluate_bpb(model, loader, steps, token_bytes) - bpb_results[split_name] = bpb - print0(f"{split_name} bpb: {bpb:.6f}") - # --- Sampling --- if 'sample' in eval_modes and not is_hf_model: print0("\n" + "="*80) @@ -305,6 +263,48 @@ def main(): elif 'sample' in eval_modes and is_hf_model: print0("\nSkipping sampling for HuggingFace models (not supported)") + # --- BPB evaluation --- + if 'bpb' in eval_modes: + print0("\n" + "="*80) + print0("BPB Evaluation") + print0("="*80) + tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size + if args.split_tokens % tokens_per_step != 0: + # Adjust to nearest multiple + args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step + print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})") + steps = args.split_tokens // tokens_per_step + + for split_name in ["train", "val"]: + loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) + with autocast_ctx: + bpb = evaluate_bpb(model, loader, steps, token_bytes) + bpb_results[split_name] = bpb + print0(f"{split_name} bpb: {bpb:.6f}") + + # --- CORE evaluation --- + if 'core' in eval_modes: + print0("\n" + "="*80) + print0("CORE Evaluation") + print0("="*80) + with autocast_ctx: + core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) + + # Write CSV output + if ddp_rank == 0: + base_dir = get_base_dir() + output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv") + os.makedirs(os.path.dirname(output_csv_path), exist_ok=True) + with open(output_csv_path, 'w', encoding='utf-8', newline='') as f: + f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n") + for label in core_results["results"]: + acc = core_results["results"][label] + centered = core_results["centered_results"][label] + f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n") + f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n") + print0(f"\nResults written to: {output_csv_path}") + print0(f"CORE metric: {core_results['core_metric']:.4f}") + # --- Log to report --- from nanochat.report import get_report report_data = [{"model": model_name}] From 6079f78fc383a874cc031c92630c924397384c6e Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 20:51:26 +0000 Subject: [PATCH 09/55] add fp8 training with torchao --- pyproject.toml | 7 +- scripts/base_train.py | 99 ++++++++++++++-- uv.lock | 254 ++++++++++++++---------------------------- 3 files changed, 180 insertions(+), 180 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3cd8d7..bcb674d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,8 @@ dependencies = [ "tabulate>=0.9.0", "tiktoken>=0.11.0", "tokenizers>=0.22.0", - "torch>=2.9.0", + "torch==2.9.1", + "torchao==0.15.0", "transformers>=4.57.3", "uvicorn>=0.36.0", "wandb>=0.21.3", @@ -59,10 +60,10 @@ explicit = true [project.optional-dependencies] cpu = [ - "torch>=2.9.1", + "torch==2.9.1", ] gpu = [ - "torch>=2.9.1", + "torch==2.9.1", ] [tool.uv] diff --git a/scripts/base_train.py b/scripts/base_train.py index 9be4b6b..fa05b60 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -16,7 +16,7 @@ import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import argparse import time -from contextlib import nullcontext +from contextlib import nullcontext, contextmanager import wandb import torch @@ -39,6 +39,9 @@ parser = argparse.ArgumentParser(description="Pretrain base model") parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +# FP8 training +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") @@ -65,7 +68,7 @@ parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR a parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") @@ -177,11 +180,11 @@ if resuming: model.load_state_dict(model_data, strict=True, assign=True) del model_data # free up this memory after the copy -orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# ----------------------------------------------------------------------------- +# Determine the length of the training run based on model size # Detailed parameter counts -param_counts = orig_model.num_scaling_params() +param_counts = model.num_scaling_params() print0(f"Parameter counts:") for key, value in param_counts.items(): print0(f"{key:24s}: {value:,}") @@ -211,6 +214,85 @@ print0(f"Total number of training tokens: {total_tokens:,}") print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") +# ----------------------------------------------------------------------------- +# FP8 training initialization and management (has to be done before torch.compile) + +# Convert Linear layers to Float8Linear if --fp8 is set +if args.fp8: + if device_type != "cuda": + print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") + else: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + import torch.nn as nn + + # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) + def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: + if not isinstance(mod, nn.Linear): + return False + # FP8 requires both in_features and out_features divisible by 16 + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) + num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") + +# Context manager to temporarily disable FP8 so that model evaluation remains in BF16 +@contextmanager +def disable_fp8(model): + """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation. + + CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead, + we swap out Float8Linear modules entirely and restore them after. + """ + import torch.nn as nn + + # Find all Float8Linear modules and their locations + fp8_locations = [] # list of (parent_module, attr_name, fp8_module) + for name, module in model.named_modules(): + if 'Float8' in type(module).__name__: + if '.' in name: + parent_name, attr_name = name.rsplit('.', 1) + parent = model.get_submodule(parent_name) + else: + parent = model + attr_name = name + fp8_locations.append((parent, attr_name, module)) + + if not fp8_locations: + yield # No FP8 modules, nothing to do + return + + # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy) + for parent, attr_name, fp8_module in fp8_locations: + linear = nn.Linear( + fp8_module.in_features, + fp8_module.out_features, + bias=fp8_module.bias is not None, + device=fp8_module.weight.device, + dtype=fp8_module.weight.dtype, + ) + linear.weight = fp8_module.weight # share, don't copy + if fp8_module.bias is not None: + linear.bias = fp8_module.bias + setattr(parent, attr_name, linear) + + try: + yield + finally: + # Restore Float8Linear modules + for parent, attr_name, fp8_module in fp8_locations: + setattr(parent, attr_name, fp8_module) + +# ----------------------------------------------------------------------------- +# Compile the model + +orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe + # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) adam_betas = (args.adam_beta1, args.adam_beta2) @@ -287,7 +369,7 @@ while True: model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) - with autocast_ctx: + with disable_fp8(model), autocast_ctx: val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") if val_bpb < min_val_bpb: @@ -302,10 +384,11 @@ while True: # once in a while: estimate the CORE metric (all ranks participate) # use the original uncompiled model because the inputs keep changing shape + # disable FP8 for evaluation to use BF16 for more consistent/accurate results results = {} if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): model.eval() - with autocast_ctx: + with disable_fp8(orig_model), autocast_ctx: results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ @@ -332,7 +415,7 @@ while True: engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") - with autocast_ctx: + with disable_fp8(orig_model), autocast_ctx: sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) print0(tokenizer.decode(sample[0])) model.train() diff --git a/uv.lock b/uv.lock index dd766f8..e5fc97f 100644 --- a/uv.lock +++ b/uv.lock @@ -1505,11 +1505,11 @@ dependencies = [ { name = "tabulate" }, { name = "tiktoken" }, { name = "tokenizers" }, - { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "torchao" }, { name = "transformers" }, { name = "uvicorn" }, { name = "wandb" }, @@ -1546,9 +1546,10 @@ requires-dist = [ { name = "tabulate", specifier = ">=0.9.0" }, { name = "tiktoken", specifier = ">=0.11.0" }, { name = "tokenizers", specifier = ">=0.22.0" }, - { name = "torch", specifier = ">=2.9.0" }, - { name = "torch", marker = "extra == 'cpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, - { name = "torch", marker = "extra == 'gpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, + { name = "torch", specifier = "==2.9.1" }, + { name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, + { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, + { name = "torchao", specifier = "==0.15.0" }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, @@ -1688,7 +1689,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -1701,7 +1702,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -1733,9 +1734,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1748,7 +1749,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -2990,72 +2991,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, ] -[[package]] -name = "torch" -version = "2.9.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", -] -dependencies = [ - { name = "filelock", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "fsspec", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "jinja2", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvshmem-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "triton", version = "3.5.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" }, - { url = "https://files.pythonhosted.org/packages/58/1d/fd1e88ae0948825efcab7dd66d12bec23f05d4d38ed81573c8d453c14c06/torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:51cb63902182a78e90886e8068befd8ea102af4b00e420263591a3d70c7d3c6c", size = 899795167, upload-time = "2025-10-15T15:47:12.695Z" }, - { url = "https://files.pythonhosted.org/packages/63/5a/496197b45c14982bef4e079b24c61dc108e3ab0d0cc9718dba9f54f45a46/torch-2.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:3f6aad4d2f0ee2248bac25339d74858ff846c3969b27d14ac235821f055af83d", size = 109310314, upload-time = "2025-10-15T15:46:16.633Z" }, - { url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" }, - { url = "https://files.pythonhosted.org/packages/05/cc/49566caaa218872ec9a2912456f470ff92649894a4bc2e5274aa9ef87c4a/torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:51de31219c97c51cf4bf2be94d622e3deb5dcc526c6dc00e97c17eaec0fc1d67", size = 899815990, upload-time = "2025-10-15T15:48:03.336Z" }, - { url = "https://files.pythonhosted.org/packages/74/25/e9ab21d5925b642d008f139d4a3c9664fc9ee1faafca22913c080cc4c0a5/torch-2.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd515c70059afd95f48b8192733764c08ca37a1d19803af6401b5ecad7c8676e", size = 109313698, upload-time = "2025-10-15T15:46:12.425Z" }, - { url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" }, - { url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" }, - { url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" }, - { url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" }, - { url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" }, - { url = "https://files.pythonhosted.org/packages/c2/1c/90eb13833cdf4969ea9707586d7b57095c3b6e2b223a7256bf111689bcb8/torch-2.9.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c30a17fc83eeab346913e237c64b15b5ba6407fff812f6c541e322e19bc9ea0e", size = 104111330, upload-time = "2025-10-15T15:46:35.238Z" }, - { url = "https://files.pythonhosted.org/packages/0e/21/2254c54b8d523592c25ef4434769aa23e29b1e6bf5f4c0ad9e27bf442927/torch-2.9.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f25033b8667b57857dfd01458fbf2a9e6a6df1f8def23aef0dc46292f6aa642", size = 899750243, upload-time = "2025-10-15T15:48:57.459Z" }, - { url = "https://files.pythonhosted.org/packages/b7/a5/5cb94fa4fd1e78223455c23c200f30f6dc10c6d4a2bcc8f6e7f2a2588370/torch-2.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:d037f1b4ffd25013be4a7bf3651a0a910c68554956c7b2c92ebe87c76475dece", size = 109284513, upload-time = "2025-10-15T15:46:45.061Z" }, - { url = "https://files.pythonhosted.org/packages/66/e8/fc414d8656250ee46120b44836ffbb3266343db424b3e18ca79ebbf69d4f/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e4e5b5cba837a2a8d1a497ba9a58dae46fa392593eaa13b871c42f71847503a5", size = 74830362, upload-time = "2025-10-15T15:46:48.983Z" }, - { url = "https://files.pythonhosted.org/packages/ed/5f/9474c98fc5ae0cd04b9466035428cd360e6611a86b8352a0fc2fa504acdc/torch-2.9.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:64693568f5dc4dbd5f880a478b1cea0201cc6b510d91d1bc54fea86ac5d1a637", size = 104144940, upload-time = "2025-10-15T15:47:29.076Z" }, - { url = "https://files.pythonhosted.org/packages/2d/5a/8e0c1cf57830172c109d4bd6be2708cabeaf550983eee7029291322447a0/torch-2.9.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:f8ed31ddd7d10bfb3fbe0b9fe01b1243577f13d75e6f4a0839a283915ce3791e", size = 899744054, upload-time = "2025-10-15T15:48:29.864Z" }, - { url = "https://files.pythonhosted.org/packages/6d/28/82c28b30fcb4b7c9cdd995763d18bbb830d6521356712faebbad92ffa61d/torch-2.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:eff527d4e4846e6f70d2afd8058b73825761203d66576a7e04ea2ecfebcb4ab8", size = 109517546, upload-time = "2025-10-15T15:47:33.395Z" }, - { url = "https://files.pythonhosted.org/packages/ff/c3/a91f96ec74347fa5fd24453fa514bc61c61ecc79196fa760b012a1873d96/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:f8877779cf56d1ce431a7636703bdb13307f5960bb1af49716d8b179225e0e6a", size = 74480732, upload-time = "2025-10-15T15:47:38.002Z" }, - { url = "https://files.pythonhosted.org/packages/5c/73/9f70af34b334a7e0ef496ceec96b7ec767bd778ea35385ce6f77557534d1/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7e614fae699838038d888729f82b687c03413c5989ce2a9481f9a7e7a396e0bb", size = 74433037, upload-time = "2025-10-15T15:47:41.894Z" }, - { url = "https://files.pythonhosted.org/packages/b7/84/37cf88625901934c97109e583ecc21777d21c6f54cda97a7e5bbad1ee2f2/torch-2.9.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:dfb5b8cd310ba3436c7e14e8b7833ef658cf3045e50d2bdaed23c8fc517065eb", size = 104116482, upload-time = "2025-10-15T15:47:46.266Z" }, - { url = "https://files.pythonhosted.org/packages/56/8e/ca8b17866943a8d4f4664d402ea84210aa274588b4c5d89918f5caa24eec/torch-2.9.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b3d29524993a478e46f5d598b249cd824b7ed98d7fba538bd9c4cde6c803948f", size = 899746916, upload-time = "2025-10-15T15:50:40.294Z" }, - { url = "https://files.pythonhosted.org/packages/43/65/3b17c0fbbdab6501c5b320a52a648628d0d44e7379f64e27d9eef701b6bf/torch-2.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:71c7578984f5ec0eb645eb4816ac8435fcf3e3e2ae1901bcd2f519a9cafb5125", size = 109275151, upload-time = "2025-10-15T15:49:20.715Z" }, - { url = "https://files.pythonhosted.org/packages/83/36/74f8c051f785500396e42f93542422422dfd874a174f21f8d955d36e5d64/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:71d9309aee457bbe0b164bce2111cd911c4ed4e847e65d5077dbbcd3aba6befc", size = 74823353, upload-time = "2025-10-15T15:49:16.59Z" }, - { url = "https://files.pythonhosted.org/packages/62/51/dc3b4e2f9ba98ae27238f0153ca098bf9340b2dafcc67fde645d496dfc2a/torch-2.9.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c08fb654d783899e204a32cca758a7ce8a45b2d78eeb89517cc937088316f78e", size = 104140340, upload-time = "2025-10-15T15:50:19.67Z" }, - { url = "https://files.pythonhosted.org/packages/c0/8d/b00657f8141ac16af7bb6cda2e67de18499a3263b78d516b9a93fcbc98e3/torch-2.9.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ec8feb0099b2daa5728fbc7abb0b05730fd97e0f359ff8bda09865aaa7bd7d4b", size = 899731750, upload-time = "2025-10-15T15:49:36.673Z" }, - { url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" }, -] - [[package]] name = "torch" version = "2.9.1" @@ -3076,13 +3011,13 @@ dependencies = [ { name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:bf1e68cfb935ae2046374ff02a7aa73dda70351b46342846f557055b3a540bf0" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a52952a8c90a422c14627ea99b9826b7557203b46b4d0772d3ca5c7699692425" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:287242dd1f830846098b5eca847f817aa5c6015ea57ab4c1287809efea7b77eb" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8924d10d36eac8fe0652a060a03fc2ae52980841850b9a1a2ddb0f27a4f181cd" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:bcee64ae7aa65876ceeae6dcaebe75109485b213528c74939602208a20706e3f" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:defadbeb055cfcf5def58f70937145aecbd7a4bc295238ded1d0e85ae2cf0e1d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:886f84b181f766f53265ba0a1d503011e60f53fff9d569563ef94f24160e1072" }, ] [[package]] @@ -3090,19 +3025,22 @@ name = "torch" version = "2.9.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.12' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux'", "python_full_version < '3.11' and sys_platform != 'linux'", ] dependencies = [ - { name = "filelock", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "fsspec", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "jinja2", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "sympy", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, - { name = "typing-extensions", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/5f/56/9577683b23072075ed2e40d725c52c2019d71a972fab8e083763da8e707e/torch-2.9.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:1cc208435f6c379f9b8fdfd5ceb5be1e3b72a6bdf1cb46c0d2812aa73472db9e", size = 104207681, upload-time = "2025-11-12T15:19:56.48Z" }, @@ -3158,30 +3096,30 @@ dependencies = [ { name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:10866c8a48c4aa5ae3f48538dc8a055b99c57d9c6af2bf5dd715374d9d6ddca3" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:7210713b66943fdbfcc237b2e782871b649123ac5d29f548ce8c85be4223ab38" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:d6e8441453dc27524e3f1037fbf27b90a02644b84e42944b9354b4024cb51cc1" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:0e611cfb16724e62252b67d31073bc5c490cb83e92ecdc1192762535e0e44487" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:3de2adb9b4443dc9210ef1f1b16da3647ace53553166d6360bbbd7edd6f16e4d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:69b3785d28be5a9c56ab525788ec5000349ec59132a74b7d5e954b905015b992" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:15b4ae6fe371d96bffb8e1e9af62164797db20a0dc1337345781659cfd0b8bb1" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3bf9b442a51a2948e41216a76d7ab00f0694cfcaaa51b6f9bcab57b7f89843e6" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7417d8c565f219d3455654cb431c6d892a3eb40246055e14d645422de13b9ea1" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:a4e06b4f441675d26b462123c8a83e77c55f1ec8ebc081203be2db1ea8054add" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:1abe31f14b560c1f062699e966cb08ef5b67518a1cfac2d8547a3dbcd8387b06" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:3e532e553b37ee859205a9b2d1c7977fd6922f53bbb1b9bfdd5bdc00d1a60ed4" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:39b3dff6d8fba240ae0d1bede4ca11c2531ae3b47329206512d99e17907ff74b" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:404a7ab2fffaf2ca069e662f331eb46313692b2f1630df2720094284f390ccef" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:161decbff26a33f13cb5ba6d2c8f458bbf56193bcc32ecc70be6dd4c7a3ee79d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:01b1884f724977a20c7da2f640f1c7b37f4a2c117a7f4a6c1c0424d14cb86322" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:031a597147fa81b1e6d79ccf1ad3ccc7fafa27941d6cf26ff5caaa384fb20e92" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:e586ab1363e3f86aa4cc133b7fdcf98deb1d2c13d43a7a6e5a6a18e9c5364893" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:65010ab4aacce6c9a1ddfc935f986c003ca8638ded04348fd326c3e74346237c" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:88adf5157db5da1d54b1c9fe4a6c1d20ceef00e75d854e206a87dbf69e3037dc" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:f60e2565f261542efac07e25208fb3fc55c6fe82314a5a9cbee971edb5f27713" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3ac2b8df2c55430e836dcda31940d47f1f5f94b8731057b6f20300ebea394dd9" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:5b688445f928f13563b7418b17c57e97bf955ab559cf73cd8f2b961f8572dbb3" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:cf9c3e50b595721ca6b488bdcc326e0f1af73ed28b9b66eff504a96649bb5c96" }, ] [[package]] @@ -3219,31 +3157,40 @@ dependencies = [ { name = "nvidia-nvtx-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "triton", version = "3.5.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "triton", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:72f0f096475e8095a6bea3fba75bd3b46cf42c761b29588f7599314e67a32661" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c8d670aa0be6fbecd2b0e7b7d514a104dbdefcc3786ca446cf0c3415043ea40a" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:64399adaa8ea0896d02cf844cba3c5dd77e769520a1af73572599e0eaa2cf551" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:cf4ad82430824a80a9f398e29369524ed26c152cf00c2c12002e5400b35e260d" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2a1da940f0757621d098c9755f7504d791a72a40920ec85a4fd98b20253fca4e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:633005a3700e81b5be0df2a7d3c1d48aced23ed927653797a3bd2b144a3aeeb6" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1176f250311fa95cc3bca8077af323e0d73ea385ba266e096af82e7e2b91f256" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7cb4018f4ce68b61fd3ef87dc1c4ca520731c7b5b200e360ad47b612d7844063" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:3a01f0b64c10a82d444d9fd06b3e8c567b1158b76b2764b8f51bfd8f535064b0" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0b80b7555dcd0a75b7b06016991f01281a0bb078cf28fa2d1dfb949fad2fbd07" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:63381a109a569b280ed3319da89d3afe5cf9ab5c879936382a212affb5c90552" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:ad9183864acdd99fc5143d7ca9d3d2e7ddfc9a9600ff43217825d4e5e9855ccc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2314521c74d76e513c53bb72c0ce3511ef0295ff657a432790df6c207e5d7962" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4454a4faca31af81566e3a4208f10f20b8a6d9cfe42791b0ca7ff134326468fc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:24420e430e77136f7079354134b34e7ba9d87e539f5ac84c33b08e5c13412ebe" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:32c036296c557f19a1537ce981c40533650097114e1720a321a39a3b08d9df56" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:7788d3d03d939cf00f93ac0da5ab520846f66411e339cfbf519a806e8facf519" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:7bcd40cbffac475b478d6ce812f03da84e9a4894956efb89c3b7bcca5dbd4f91" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:e88c78e5b08ae9303aa15da43b68b44287ecbec16d898d9fad6998832fe626a5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7d8769bdf3200ca16a92f14df404c3370171ac3732996528a8973d753eac562f" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, +] + +[[package]] +name = "torchao" +version = "0.15.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" }, + { url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" }, ] [[package]] @@ -3307,41 +3254,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" }, ] -[[package]] -name = "triton" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/dd/22/507b6f58a35e05e84381630b2dc2a3cee1a7a2a7eaf4cba857c638a18a24/triton-3.5.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6f90de6a6566bb619b4c0adc9855729e1b1b5e26533fca1bf6206e96b6d277a3", size = 159827599, upload-time = "2025-10-15T19:15:43.87Z" }, - { url = "https://files.pythonhosted.org/packages/0b/eb/09e31d107a5d00eb281aa7e6635ca463e9bca86515944e399480eadb71f8/triton-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5d3b3d480debf24eaa739623c9a42446b0b77f95593d30eb1f64cd2278cc1f0", size = 170333110, upload-time = "2025-10-13T16:37:49.588Z" }, - { url = "https://files.pythonhosted.org/packages/79/f9/b6f60f978397c616fd8dacca2305759fe4f80d397b20ef72534803244bd5/triton-3.5.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8457b22148defefdcb7fa8144b05ce211b9faefad650a1ce85b23df488d5549c", size = 159926731, upload-time = "2025-10-15T19:15:49.682Z" }, - { url = "https://files.pythonhosted.org/packages/3d/78/949a04391c21956c816523678f0e5fa308eb5b1e7622d88c4e4ef5fceca0/triton-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f34bfa21c5b3a203c0f0eab28dcc1e49bd1f67d22724e77fb6665a659200a4ec", size = 170433488, upload-time = "2025-10-13T16:37:57.132Z" }, - { url = "https://files.pythonhosted.org/packages/87/9b/30988039e1e84df7554fba24e6a734d2d0e847af33cabdf9b532b3c51456/triton-3.5.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da21fccceafc163e3a5e857abe34351ef76345af06cabf9637a914742671f0b", size = 159946647, upload-time = "2025-10-15T19:15:56.325Z" }, - { url = "https://files.pythonhosted.org/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833", size = 170476535, upload-time = "2025-10-13T16:38:05.18Z" }, - { url = "https://files.pythonhosted.org/packages/cd/85/e37f1197acb04c8f3d83851d23d5d6ed5060ef74580668b112e23fdfa203/triton-3.5.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:188da5b81fa2f8322c27fec1627703eac24cb9bb7ab0dfbe9925973bc1b070d3", size = 159958970, upload-time = "2025-10-15T19:16:01.717Z" }, - { url = "https://files.pythonhosted.org/packages/6c/29/10728de8a6e932e517c10773486b8e99f85d1b1d9dd87d9a9616e1fef4a1/triton-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6bb9aa5519c084a333acdba443789e50012a4b851cd486c54f0b8dc2a8d3a12", size = 170487289, upload-time = "2025-10-13T16:38:11.662Z" }, - { url = "https://files.pythonhosted.org/packages/b8/1d/38258f05010ac17a7b058c022911c9cae6526e149b7397134a048cf5a6c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03127d9b33aaf979c856676b394bc059ec1d68cb6da68ae03f62dd8ad77a04ae", size = 160073012, upload-time = "2025-10-15T19:16:07.477Z" }, - { url = "https://files.pythonhosted.org/packages/5c/38/db80e48b9220c9bce872b0f616ad0446cdf554a40b85c7865cbca99ab3c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c83f2343e1a220a716c7b3ab9fccfcbe3ad4020d189549200e2d2e8d5868bed9", size = 170577179, upload-time = "2025-10-13T16:38:17.865Z" }, - { url = "https://files.pythonhosted.org/packages/91/fe/8f5771d00227f4eb1ee034f218ed427102b989366d2275fe3b3c105a3921/triton-3.5.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:468936651d383f4a6d10068d34a627505e13af55be5d002b9f27b987e7a5f0ac", size = 159957460, upload-time = "2025-10-15T19:16:12.626Z" }, - { url = "https://files.pythonhosted.org/packages/ff/60/1810655d1d856c9a4fcc90ee8966d85f552d98c53a6589f95ab2cbe27bb8/triton-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da0fa67ccd76c3dcfb0bffe1b1c57c685136a6bd33d141c24d9655d4185b1289", size = 170487949, upload-time = "2025-10-13T16:38:24.881Z" }, - { url = "https://files.pythonhosted.org/packages/78/59/99edd103958fe6e42b50b9ad8ce4f223ddf4ccf475259cf7d2b53381dc6c/triton-3.5.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7ceef21410229ac23173a28eee5cfc0e37c1dfdb8b4bc11ecda2e3ecec7c686", size = 160075629, upload-time = "2025-10-15T19:16:18.746Z" }, - { url = "https://files.pythonhosted.org/packages/fb/b7/1dec8433ac604c061173d0589d99217fe7bf90a70bdc375e745d044b8aad/triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7", size = 170580176, upload-time = "2025-10-13T16:38:31.14Z" }, -] - [[package]] name = "triton" version = "3.5.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/d9/2e/f95e673222afa2c7f0c687d8913e98fcf2589ef0b1405de76894e37fe18f/triton-3.5.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f63e34dcb32d7bd3a1d0195f60f30d2aee8b08a69a0424189b71017e23dfc3d2", size = 159821655, upload-time = "2025-11-11T17:51:44.09Z" }, { url = "https://files.pythonhosted.org/packages/fd/6e/676ab5019b4dde8b9b7bab71245102fc02778ef3df48218b298686b9ffd6/triton-3.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fc53d849f879911ea13f4a877243afc513187bc7ee92d1f2c0f1ba3169e3c94", size = 170320692, upload-time = "2025-11-11T17:40:46.074Z" }, From a67eba35dce1ce3f7276b6913bcedf4cbb75b506 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 20:54:30 +0000 Subject: [PATCH 10/55] add feb2 new leaderboard record from upgrading to fp8 training, +4.3% speedup to time to GPT-2 --- README.md | 36 ++++++----------------------- dev/LOG.md | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 6283437..e96b5a7 100644 --- a/README.md +++ b/README.md @@ -14,37 +14,15 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi ## Leaderboard -| # | Record time | Description | Date | Commit | Contributors | -|---|-------------|-------------|------|--------|--------------| -| 1 | 3.04 hours | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | +| # | Record time | val_bpb | CORE | Description | Date | Commit | Contributors | +|---|-------------|---------|------|-------------|------|--------|--------------| +| 0 | 168 hours | - | 0.256525 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | +| 1 | 3.04 | 0.74833 | 0.25851 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | +| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | TODO | @karpathy | -The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so in 3 hours or less, for ~$73 and below. Once your repo is set up (see the [runs/speedrun.sh](runs/speedrun.sh) script for reference), e.g. the way I kicked off the jan29 run is as follows: +The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). -``` -OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ - --depth=24 \ - --run=d24-jan29 \ - --model-tag=d24_jan29 \ - --device-batch-size=16 \ - --sample-every=-1 \ - --save-every=-1 \ - --core-metric-max-per-task=-1 \ - --core-metric-every=3000 \ - --target-param-data-ratio=12 -``` - -After 3 hours we get output like this: - -``` -... -wandb: Run summary: -wandb: core_metric 0.25851 -wandb: step 16704 -wandb: total_training_flops 4.330784131228946e+19 -wandb: total_training_time 10949.46713 -``` - -The GPT-2 CORE score (i.e. the target to beat) is 0.256525. So we see that this d24 CORE score is higher (0.25851). Then we look at the `total_training_time`, which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. We get: `10949/60/60 ~= 3.04` hours, the current record. +See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard. ## Getting started diff --git a/dev/LOG.md b/dev/LOG.md index dd11b42..8cdef87 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,74 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-02: FP8 Training with torchao + +Integrated FP8 training using `torchao.float8` to accelerate Linear layer matmuls on H100 GPUs. + +### Background + +FP8 (8-bit floating point) uses H100's FP8 tensor cores for ~2x theoretical matmul throughput. The tradeoff is quantization overhead: computing scales and casting tensors to/from FP8. Still, as an example torchtitan (Meta's distributed training framework) reports 25-28% speedups with FP8 for some of their experiments. + +**Previous attempt (Jan 2026):** FP8 on just `lm_head` following modded-nanogpt with custom ops → 1% speedup, +2GB memory. Failed due to fragile torch.compile interaction. But this experiment was also done on ~d12 scale back then instead of the bigger model that gets GPT-2 capability of approx d24. + +**This attempt:** Use torchao's `convert_to_float8_training()` on ALL Linear layers, increase model size to d24. The core snippet is: + +```python +from torchao.float8 import Float8LinearConfig, convert_to_float8_training +config = Float8LinearConfig.from_recipe_name("tensorwise") +convert_to_float8_training(model, config=config) +``` + +But in practice it's more involved (see base_train.py). + +### Results + +**Microbenchmark (d26 MLP, 65536x1664 @ 1664x6656):** + +| Method | Forward | Fwd+Bwd | Speedup | +|--------|---------|---------|---------| +| BF16 + compile | 2.00ms | 4.79ms | 1.00x | +| FP8 rowwise + compile | 1.84ms | 4.55ms | 1.08x | +| FP8 tensorwise + compile | 1.45ms | 4.06ms | **1.38x** | +| FP8 rowwise (no compile) | 2.89ms | 21.86ms | 0.23x ❌ | + +torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops. + +**Full training (d26):** + +| Config | tok/sec | vs baseline | +|--------|---------|-------------| +| BF16 baseline | 630K | 1.00x | +| FP8 rowwise | 564K | 0.90x ❌ | +| FP8 tensorwise | 740K | **1.17x** ✓ | + +Memory usage also decreases quite a bit, by ~9GB (activations stored as FP8 instead of BF16). + +Seeing 17% speedup is encouraging but we're still not done yet because each step is now in lower precision and less powerful individually, so to make up for the precision drop we have to train longer. Empirically, running some sweeps overnight on d24 scale, I saw that the actual speedup (when you match performance) is closer to 5%. It's possible that our LLMs at ~d24 scale are still too small to confidently enjoy the speedups that come from fp8 for bigger models. + +### Key Learnings + +For nanochat at approximate scale of interest (~GPT-2 capability, ~d24): + +1. **Tensorwise >> Rowwise** - Rowwise computes per-row scales, overhead exceeds benefit. Tensorwise uses one scale per tensor. +2. **Filter small layers** - Layers with dims not divisible by 16 must be skipped (FP8 hardware requirement) +3. **Larger models benefit more** - d12 was still slower with FP8; d26+ shows gains. Therefore, in some depths there is a benefit to fp8 and in some there isn't. Keeping it configurable for now, passed in via kwargs and default off. +4. **The effective, capability-matched speedup is lower still** - because each step is of slightly lower precision/quality. + +### Integration + +Added `--fp8` flag to `base_train.py`, default recipe is "tensorwise", example of turning on: + +```bash +torchrun --nproc_per_node=8 -m scripts.base_train --depth=24 --fp8 +``` + +Uses tensorwise by default. Requires `torchao==0.15.0` (compatible with torch 2.9.1), which was added to dependencies. + +**TLDR**: turning on fp8 for GPT-2 capability nanochat model gives approx +5% capability-matched speedup. + +--- + ## 2026-01-29: Hyperball/MuonH Experiments (Negative Result) Explored Hyperball optimization from [this post](https://psychedelic-sunstone-851.notion.site/Fantastic-Pretraining-Optimizers-and-Where-to-Find-Them-2-1-Hyperball-Optimization-2e924306e6f280e7a5ffee00eb40a0dd) (saved to `knowledge/muonh.md`). Constrains weights to sphere of radius R (initial norm): `W_{t+1} = R · Normalize(W_t - η·R · Normalize(u_t))`. Had to change a number of details in a branch, e.g. not use zero init for our projections (or the initial norm would be zero), keep track of the initial norm, adjust Muon -> MuonH for the update. From fe55b092b8a2e3b46411a9e2e9acd8cc0f6788d1 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 21:05:28 +0000 Subject: [PATCH 11/55] minor cosmetics for the table --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e96b5a7..08c184a 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi ## Leaderboard -| # | Record time | val_bpb | CORE | Description | Date | Commit | Contributors | +| # | time | val_bpb | CORE | Description | Date | Commit | Contributors | |---|-------------|---------|------|-------------|------|--------|--------------| -| 0 | 168 hours | - | 0.256525 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | -| 1 | 3.04 | 0.74833 | 0.25851 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | -| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | TODO | @karpathy | +| 0 | 168 hours | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | +| 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | +| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | 8309b83 | @karpathy | The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). From 16b8ac7da33010dc7efcd9292f895703f5cff33a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 21:06:12 +0000 Subject: [PATCH 12/55] oops forgot to attach leaderboard file too --- dev/LEADERBOARD.md | 119 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 dev/LEADERBOARD.md diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md new file mode 100644 index 0000000..3b61cc6 --- /dev/null +++ b/dev/LEADERBOARD.md @@ -0,0 +1,119 @@ +# Leaderboard + +Docs on participating in the "Time-to-GPT-2" leaderboard of nanochat. + +The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. Originally in 2019, GPT-2 was trained by OpenAI on 32 TPU v3 chips for 168 hours (7 days), with $8/hour/TPUv3 back then, for a total cost of approx. $43K. It achieves 0.256525 CORE score, which is an ensemble metric introduced in the DCLM paper over 22 evaluations like ARC/MMLU/etc. + +## How to + +The script [runs/speedrun.sh](runs/speedrun.sh) always implements the current state of the art on the leaderboard. + +In practice, I tune the base_train command a little bit. For example, once all the setup is configured and a tokenizer is trained, I like to do something like: + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=26 \ + --run="d26-feb2-fp8-ratio8.25" \ + --model-tag="d26_feb2_fp8_ratio8.25" \ + --device-batch-size=16 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=8.25 \ + --fp8 +``` + +Note that: + +- `depth` controls the size of the Transformer +- `run` is the wandb name +- `model-tag` is the location of the checkpoints on disk +- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will fo forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly. +- `sample-every = -1` turns off periodic sampling +- `core-metric-max-per-task=-1` means we run the entire CORE eval +- `core-metric-every=999999` a bit of a hacky way to make the CORE eval only happen a single time at the very end of the run +- `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly. +- `--fp8` turns on fp8 training. If you GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. + +Once you kick off the run, you wait ~3 hours and then at the end you'll see something like: + +``` +wandb: Run summary: +wandb: core_metric 0.25851 +wandb: step 16704 +wandb: total_training_flops 4.330784131228946e+19 +wandb: total_training_time 10949.46713 +``` + +Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example here it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy. + +If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will way those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries. + +After you create the commit, to get the current short git commit hash: + +``` +git log -1 --format="%h" +``` + +## Run 1 + +Achieved Jan 29 2026 on commit `348fbb3`. The launch command was + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=24 \ + --run=d24-jan29 \ + --model-tag=d24_jan29 \ + --device-batch-size=16 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=3000 \ + --target-param-data-ratio=12 +``` + +The result was: + +``` +wandb: Run summary: +wandb: core_metric 0.25851 +wandb: step 16704 +wandb: total_training_flops 4.330784131228946e+19 +wandb: total_training_time 10949.46713 +``` + +The validation bpb was 0.74833. + +Detailed writeup: [Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481) + +## Run 2 + +Achieved Feb 2 2026 on commit `8309b83`. The launch command was + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=26 \ + --run="d26-feb2-fp8-ratio8.5" \ + --model-tag="d26_feb2_fp8_ratio8.5" \ + --device-batch-size=16 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=8.5 \ + --fp8 +``` + +The result was: + +``` +core_metric 0.2578 +step 14889 +total_training_time 10493 +Minimum validation bpb: 0.745036 +``` + +The big change in this run is `--fp8`, which causes all Linear layers (other than the gates) to be switched to fp8 training using `torchao` with tensorwise fp8 scaling. Each step is of slightly lower quality, but we are taking them a lot faster, coming out net ahead. Anyone who does not have fp8 (e.g. using a GPU without it) can simply leave out the `--fp8` flag to train in bfloat16. This will work just fine but it will produce a slightly stronger model than GPT-2 because of the fp8 -> bf16 precision upgrade. It's possible that one can further tune which layers to include in the fp8 conversion and that e.g. some of the smaller matmuls should be just kept in bf16 etc. + +Previous record was 3.04 hours, so 2.91 hours is `(3.04 - 2.91)/3.04*100` ~= 4.3% speed improvement. From d510b1385b04a77d3f8777ed2b3c1064f2488c53 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Feb 2026 23:21:39 +0000 Subject: [PATCH 13/55] quick experiments to log --- dev/LOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index 8cdef87..908fac1 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,24 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-03: Flip Muon MLP LR Multiplier (PR #492) + +Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like `c_fc`) to boosting wide matrices (output projections like `c_proj`). The original code applies `max(1, rows/cols)^0.5`, giving ~2x LR to `c_fc`. The flipped version gives ~2x LR to `c_proj` instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in [PR #492](https://github.com/karpathy/nanochat/pull/492) and showed improvements in modded-nanogpt. + +**Result:** Quick d12 experiment: slightly worse **Not adopted.** + +--- + +## 2026-02-03: Skip AdamW Every Other Step + +Inspired by modded-nanogpt, tried stepping AdamW only on odd iterations while Muon steps every iteration. The idea is that small AdamW params (embeddings, scalars, gates) don't need updates as frequently as the large weight matrices, and skipping saves both compute and communication. + +Added `skip_adamw` parameter to `MuonAdamW.step()` and `DistMuonAdamW.step()` plus a matching `zero_grad(skip_adamw=...)` to let AdamW gradients accumulate over 2 steps. Used `lr *= 2**-0.5` (sqrt scaling) to compensate for the 2x effective batch size on AdamW params. + +**Result:** for nanochat d12, we see ~2% faster tok/s, but each step is slightly worse in loss. On net, when plotting against wall clock time, it's slightly worse. **Not adopted.** + +--- + ## 2026-02-02: FP8 Training with torchao Integrated FP8 training using `torchao.float8` to accelerate Linear layer matmuls on H100 GPUs. From 542beb0c8c175af2d52ec7065345dcd8f0162368 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 4 Feb 2026 02:12:04 +0000 Subject: [PATCH 14/55] bump speedrun to be the up to date leaderboard run --- runs/speedrun.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index d390c6d..c423ba6 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,7 +70,7 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --device-batch-size=16 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 From 718e5e9d67be7e92bd158f5471e1ce53de2e6a64 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 01:39:26 +0000 Subject: [PATCH 15/55] correctly reference NorMuon and fix misleading terms that i may have hastily ported over from modded-nanogpt --- dev/LOG.md | 6 +++--- nanochat/optim.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index 908fac1..71cb18d 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -733,8 +733,8 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i - Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`) - **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default. -**2. Variance Reduction (NorMuon-style)** -- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491)) +**2. NorMuon Variance Reduction** +- Added per-neuron/column adaptive learning rate from NorMuon ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491)) - Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller) - Normalizes updates based on running per-row/col variance estimate (beta2=0.95) - Memory overhead: ~1/max(rows, cols) per param, negligible @@ -776,7 +776,7 @@ Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08 ### Summary -Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth. +Muon was changed to use Polar Express, added NorMuon variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth. --- diff --git a/nanochat/optim.py b/nanochat/optim.py index 190a1ed..4cc2a1f 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -67,6 +67,10 @@ Polar Express Sign Method for orthogonalization. https://arxiv.org/pdf/2505.16932 by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. +NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes +update scales after orthogonalization (Muon's output has non-uniform scales across neurons). +https://arxiv.org/pdf/2510.05491 + Some of the changes in nanochat implementation: - Uses a simpler, more general approach to parameter grouping and stacking - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step From d63b7ab9acabaa97f2ea2f2410e5255be525a398 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 02:41:46 +0000 Subject: [PATCH 16/55] try and fail relu^2 -> swiglu --- dev/LOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index 71cb18d..b344b23 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,19 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-05: SwiGLU Activation (Negative Result) + +Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). Implementation uses three projections (w1, w2, w3) with hidden_dim scaled to 8/3×n_embd to preserve both parameter count and FLOPs exactly (1.00x match on both). + +```python +# Old: x = c_proj(relu(c_fc(x)).square()) +# New: x = w3(silu(w1(x)) * w2(x)) +``` + +Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.** + +--- + ## 2026-02-03: Flip Muon MLP LR Multiplier (PR #492) Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like `c_fc`) to boosting wide matrices (output projections like `c_proj`). The original code applies `max(1, rows/cols)^0.5`, giving ~2x LR to `c_fc`. The flipped version gives ~2x LR to `c_proj` instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in [PR #492](https://github.com/karpathy/nanochat/pull/492) and showed improvements in modded-nanogpt. From 1144d186ed4bd7ea949bddca03612922402ab198 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 02:42:46 +0000 Subject: [PATCH 17/55] try and fail relu^2 -> swiglu --- dev/LOG.md | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index b344b23..02561ac 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -6,11 +6,24 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 ## 2026-02-05: SwiGLU Activation (Negative Result) -Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). Implementation uses three projections (w1, w2, w3) with hidden_dim scaled to 8/3×n_embd to preserve both parameter count and FLOPs exactly (1.00x match on both). +Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×: ```python -# Old: x = c_proj(relu(c_fc(x)).square()) -# New: x = w3(silu(w1(x)) * w2(x)) +# Old ReLU²: 2 matrices, 4x expansion +# params: 2 × n × 4n = 8n² +# flops: 2 × 2n × 4n = 16n² per token +self.c_fc = Linear(n_embd, 4 * n_embd) +self.c_proj = Linear(4 * n_embd, n_embd) +x = c_proj(relu(c_fc(x)).square()) + +# New SwiGLU: 3 matrices, 8/3x expansion +# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches +# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches +hidden_dim = (8 * n_embd) // 3 +self.w1 = Linear(n_embd, hidden_dim) # gate +self.w2 = Linear(n_embd, hidden_dim) # up +self.w3 = Linear(hidden_dim, n_embd) # down +x = w3(silu(w1(x)) * w2(x)) ``` Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.** From 75b302f331fa6b678bac0c4ffabfb5012555bd08 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 16:14:28 +0000 Subject: [PATCH 18/55] fix hash commit on leaderboard and a paragraph clarification --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 08c184a..cf07ef3 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi |---|-------------|---------|------|-------------|------|--------|--------------| | 0 | 168 hours | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | | 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | -| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | 8309b83 | @karpathy | +| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy | The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). @@ -71,7 +71,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. -The overall approach is to treat the depth of the model as the single dial of complexity. By sweeping out the depth, we get increasingly more powerful models. We determine the scaling laws, set the data budget to a compute optimal setting, train a whole miniseries of models of increasing sizes, and compare them to the GPT-2 and GPT-3 miniseries. Right now, beating GPT-2 specifically faster and faster is the most interesting target. +The important thing to note is that nanochat is written and configured around one single dial of complexity - the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this, they are simply asking for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you achieve the nanochat miniseries of compute optimal models at various sizes. GPT-2 capability model (which is of most interest at the moment) happens to be somewhere around d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth. ## Running on CPU / MPS From 012da1a78bc2b7bb70dbdeba41f987ebf5f7b2e1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Thu, 5 Feb 2026 19:12:50 +0100 Subject: [PATCH 19/55] Typo fixes (#480) * small typo * few more small fixes * small fixes in leaderboard.md --- README.md | 8 ++++---- dev/LEADERBOARD.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index cf07ef3..4911090 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi ## Updates - (Jan 31 2026) Major revamp of all scripts/README ongoing, deleting midtraining stage, might be a bit messy briefly... -- (Jan 30 2026) With all the latest improvements we're able to train GPT-2 grade LLM in about $73. The [runs/speedrun.sh](runs/speedrun.sh) script will become the refernece way to train GPT-2 grade model and talk to it. +- (Jan 30 2026) With all the latest improvements we're able to train GPT-2 grade LLM in about $73. The [runs/speedrun.sh](runs/speedrun.sh) script will become the reference way to train GPT-2 grade model and talk to it. ## Leaderboard @@ -28,13 +28,13 @@ See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret a ### Reproduce and talk to GPT-2 -The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Currently, at ~$24/hour for these nodes, pretraining GPT-2 grade model takes approximately 3 hours and will set you back about $75. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: +The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Currently, at ~$24/hour for these nodes, pretraining a GPT-2 grade model takes approximately 3 hours and will set you back about $75. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: ```bash bash runs/speedrun.sh ``` -You mish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it: +You may wish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it: ```bash python -m scripts.chat_web @@ -75,7 +75,7 @@ The important thing to note is that nanochat is written and configured around on ## Running on CPU / MPS -The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM tha tis being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. +The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. ## Guides diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 3b61cc6..6c2720c 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -29,12 +29,12 @@ Note that: - `depth` controls the size of the Transformer - `run` is the wandb name - `model-tag` is the location of the checkpoints on disk -- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will fo forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly. +- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will do forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly. - `sample-every = -1` turns off periodic sampling - `core-metric-max-per-task=-1` means we run the entire CORE eval - `core-metric-every=999999` a bit of a hacky way to make the CORE eval only happen a single time at the very end of the run - `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly. -- `--fp8` turns on fp8 training. If you GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. +- `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability. Once you kick off the run, you wait ~3 hours and then at the end you'll see something like: @@ -46,9 +46,9 @@ wandb: total_training_flops 4.330784131228946e+19 wandb: total_training_time 10949.46713 ``` -Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example here it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy. +Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy. -If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will way those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries. +If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will weigh those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries. After you create the commit, to get the current short git commit hash: From 98eed6df189e395056c34621043d082878df392f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 18:14:09 +0000 Subject: [PATCH 20/55] bring back an assert guarding against bad param sizing --- nanochat/optim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nanochat/optim.py b/nanochat/optim.py index 4cc2a1f..42d862b 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -377,6 +377,7 @@ class DistMuonAdamW(torch.optim.Optimizer): param_infos[p] = dict(future=future, grad_slice=grad, is_small=True) else: # Large params: reduce_scatter + assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})" rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size]) future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future() From f41dd3cbd76f82a46624e16c81c6131ecc4d205d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 19:40:37 +0000 Subject: [PATCH 21/55] auto-calculate optimal batch size. the original setting of 0.5M was only optimal for d12, but d26 prefers 1M and so on --- dev/LOG.md | 46 ++++++++++ scripts/base_train.py | 201 +++++++++++++++++++++++------------------- 2 files changed, 156 insertions(+), 91 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index 02561ac..6ce4173 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,52 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-05: Auto Batch Size Scaling + +### Background + +So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches. + +### Implementation + +Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch: + +```python +get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head'] +if args.total_batch_size == -1: + D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12)) + B_REF = 2**19 + args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) +``` + +Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12. + +### Results + +With this formula, we currently get: + +| Depth | Scaling Params | Target Tokens | Auto Batch | +|-------|---------------|---------------|------------| +| d=8 | 42M | 0.44B | 2^18 = 262K | +| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K | +| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M | +| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M | + +In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19. + +### Code Cleanup + +Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling. + +### Useful references + +- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738) +- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162) +- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) +- [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971) + +--- + ## 2026-02-05: SwiGLU Activation (Negative Result) Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×: diff --git a/scripts/base_train.py b/scripts/base_train.py index fa05b60..97264c8 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -11,11 +11,14 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ -import gc import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" -import argparse +import gc +import json import time +import math +import argparse +from dataclasses import asdict from contextlib import nullcontext, contextmanager import wandb @@ -53,8 +56,8 @@ parser.add_argument("--num-iterations", type=int, default=-1, help="explicit num parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") +parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") @@ -78,8 +81,8 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- +# Compute init and wandb logging -# Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. @@ -109,65 +112,39 @@ else: print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") print0("!" * 80) -# Tokenizer will be useful for evaluation, also we need the vocab size +# ----------------------------------------------------------------------------- +# Tokenizer will be useful for evaluation and also we need the vocab size to init the model tokenizer = get_tokenizer() token_bytes = get_token_bytes(device=device) vocab_size = tokenizer.get_vocab_size() print0(f"Vocab size: {vocab_size:,}") -# Model kwargs are derived from the desired depth of the model -# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division -# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) -# (For very small depths, this gives a slight "unfair" advantage to models with odd depths) -num_layers = args.depth -base_dim = args.depth * args.aspect_ratio -model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim -num_heads = model_dim // args.head_dim -num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) -head_dim = model_dim // num_heads -print0(f"num_layers: {num_layers}") -print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") -print0(f"num_heads: {num_heads}") -print0(f"head_dim: {head_dim}") -print0(f"num_kv_heads: {num_kv_heads}") - -# Optimizer / data / training length related hyperparameters -# figure out the needed gradient accumulation to reach the desired total batch size -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") - -# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) -batch_lr_scale = 1.0 -reference_batch_size = 2**19 -batch_ratio = args.total_batch_size / reference_batch_size -if batch_ratio != 1.0: - # SGD: linear scaling with batch size is standard (not used in nanochat) - # AdamW: sqrt scaling is standard - # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer - batch_lr_scale = batch_ratio ** 0.5 - print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") - -# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) -weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 -if args.depth != 12: - print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") - # ----------------------------------------------------------------------------- # Initialize the Model -# Create a new model with random weights -model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern) -with torch.device("meta"): - # All tensors are created as meta tensors (they have shape/dtype but no data) - model_config = GPTConfig(**model_config_kwargs) - model = GPT(model_config) -model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data -model.init_weights() # All tensors get initialized +def build_model_meta(depth): + """Build a model on meta device for a given depth (shapes/dtypes only, no data).""" + # Model dim is nudged up to nearest multiple of head_dim for clean division + # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) + base_dim = depth * args.aspect_ratio + model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim + num_heads = model_dim // args.head_dim + config = GPTConfig( + sequence_len=args.max_seq_len, vocab_size=vocab_size, + n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, + window_pattern=args.window_pattern, + ) + with torch.device("meta"): + model_meta = GPT(config) + return model_meta + +# Build the model, move to device, init the weights +model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data) +model_config = model.config +model_config_kwargs = asdict(model_config) +print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}") +model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data +model.init_weights() # 3) All tensors get initialized # If we are resuming, overwrite the model parameters with those of the checkpoint base_dir = get_base_dir() @@ -181,41 +158,7 @@ if resuming: del model_data # free up this memory after the copy # ----------------------------------------------------------------------------- -# Determine the length of the training run based on model size - -# Detailed parameter counts -param_counts = model.num_scaling_params() -print0(f"Parameter counts:") -for key, value in param_counts.items(): - print0(f"{key:24s}: {value:,}") -num_params = param_counts['total'] -num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026 -num_flops_per_token = model.estimate_flops() -print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") - -# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) -assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 -if args.num_iterations > 0: - num_iterations = args.num_iterations - print0(f"Using user-provided number of iterations: {num_iterations:,}") -elif args.target_flops > 0: - # calculate the number of iterations from the target flops - num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) - print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") -elif args.target_param_data_ratio > 0: - # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.) - target_tokens = int(args.target_param_data_ratio * num_scaling_params) - num_iterations = target_tokens // args.total_batch_size - print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") -else: - raise ValueError("No training horizon specified") -total_tokens = args.total_batch_size * num_iterations -print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 -print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") - -# ----------------------------------------------------------------------------- -# FP8 training initialization and management (has to be done before torch.compile) +# FP8 training initialization and management (this has to be done before torch.compile) # Convert Linear layers to Float8Linear if --fp8 is set if args.fp8: @@ -293,6 +236,82 @@ def disable_fp8(model): orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +# ----------------------------------------------------------------------------- +# Determine the optimization horizon based on the model size +# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). +# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params + +# Get the parameter counts of the model +param_counts = model.num_scaling_params() +print0(f"Parameter counts:") +for key, value in param_counts.items(): + print0(f"{key:24s}: {value:,}") +num_params = param_counts['total'] +num_flops_per_token = model.estimate_flops() +print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") + +# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026) +get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head'] +num_scaling_params = get_scaling_params(model) +target_tokens = int(args.target_param_data_ratio * num_scaling_params) + +# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 +if args.total_batch_size == -1: + d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens + d12_num_scaling_params = get_scaling_params(d12_ref) + D_REF = args.target_param_data_ratio * d12_num_scaling_params + B_REF = 2**19 + args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2 + print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens") + +# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) +assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 +if args.num_iterations > 0: + # Override num_iterations to a specific value if given + num_iterations = args.num_iterations + print0(f"Using user-provided number of iterations: {num_iterations:,}") +elif args.target_flops > 0: + # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh) + num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) + print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") +elif args.target_param_data_ratio > 0: + # Calculate the number of iterations from the target param data ratio (the most common use case) + num_iterations = target_tokens // args.total_batch_size + print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") +else: + raise ValueError("No training horizon specified") +total_tokens = args.total_batch_size * num_iterations +print0(f"Total number of training tokens: {total_tokens:,}") +print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") + +# ----------------------------------------------------------------------------- +# Optimizer / data / training length related hyperparameters +# figure out the needed gradient accumulation to reach the desired total batch size +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert args.total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + +# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) +batch_lr_scale = 1.0 +reference_batch_size = 2**19 +batch_ratio = args.total_batch_size / reference_batch_size +if batch_ratio != 1.0: + # SGD: linear scaling with batch size is standard (not used in nanochat) + # AdamW: sqrt scaling is standard + # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer + batch_lr_scale = batch_ratio ** 0.5 + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") + +# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) +weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 +if args.depth != 12: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") + # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) adam_betas = (args.adam_beta1, args.adam_beta2) From 2c062aaa949536c1cff2ffb3df2ae4aeba20dc4b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 19:59:46 +0000 Subject: [PATCH 22/55] nit: don't mutate args, create new var for total_batch_size --- scripts/base_train.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 97264c8..a3774e6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -256,13 +256,15 @@ num_scaling_params = get_scaling_params(model) target_tokens = int(args.target_param_data_ratio * num_scaling_params) # Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 -if args.total_batch_size == -1: +total_batch_size = args.total_batch_size +if total_batch_size == -1: d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens d12_num_scaling_params = get_scaling_params(d12_ref) D_REF = args.target_param_data_ratio * d12_num_scaling_params B_REF = 2**19 - args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2 - print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens") + batch_size_ratio = target_tokens / D_REF + total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2 + print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens") # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 @@ -272,17 +274,17 @@ if args.num_iterations > 0: print0(f"Using user-provided number of iterations: {num_iterations:,}") elif args.target_flops > 0: # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh) - num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) + num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size)) print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") elif args.target_param_data_ratio > 0: # Calculate the number of iterations from the target param data ratio (the most common use case) - num_iterations = target_tokens // args.total_batch_size + num_iterations = target_tokens // total_batch_size print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") else: raise ValueError("No training horizon specified") -total_tokens = args.total_batch_size * num_iterations +total_tokens = total_batch_size * num_iterations print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # ----------------------------------------------------------------------------- @@ -290,22 +292,22 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # figure out the needed gradient accumulation to reach the desired total batch size tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") # Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) batch_lr_scale = 1.0 reference_batch_size = 2**19 -batch_ratio = args.total_batch_size / reference_batch_size +batch_ratio = total_batch_size / reference_batch_size if batch_ratio != 1.0: # SGD: linear scaling with batch size is standard (not used in nanochat) # AdamW: sqrt scaling is standard # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer batch_lr_scale = batch_ratio ** 0.5 - print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})") # Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 @@ -381,7 +383,7 @@ else: # Training loop while True: last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end - flops_so_far = num_flops_per_token * args.total_batch_size * step + flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) if args.eval_every > 0 and (last_step or step % args.eval_every == 0): @@ -501,8 +503,8 @@ while True: smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * step / num_iterations - tok_per_sec = int(args.total_batch_size / dt) - flops_per_sec = num_flops_per_token * args.total_batch_size / dt + tok_per_sec = int(total_batch_size / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps @@ -560,7 +562,7 @@ get_report().log(section="Base model training", data=[ "Number of FLOPs per token": f"{num_flops_per_token:e}", "Calculated number of iterations": num_iterations, "Number of training tokens": total_tokens, - "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, + "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params, "DDP world size": ddp_world_size, "warmup_ratio": args.warmup_ratio, "warmdown_ratio": args.warmdown_ratio, From 5fdd5cdb246d2e82a1fcc05fd4c0468df824d875 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 20:11:32 +0000 Subject: [PATCH 23/55] new leaderboard record via new auto-calculated optimal batch size. for d26 it is 1M, up from 0.5M that was default earlier --- README.md | 1 + dev/LEADERBOARD.md | 32 +++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4911090..182d273 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi | 0 | 168 hours | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI | | 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | | 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy | +| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 6c2720c..b8a727f 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -89,7 +89,7 @@ Detailed writeup: [Beating GPT-2 for <<$100: the nanochat journey](https://githu ## Run 2 -Achieved Feb 2 2026 on commit `8309b83`. The launch command was +Achieved Feb 2 2026 on commit `a67eba3`. The launch command was ``` OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ @@ -117,3 +117,33 @@ Minimum validation bpb: 0.745036 The big change in this run is `--fp8`, which causes all Linear layers (other than the gates) to be switched to fp8 training using `torchao` with tensorwise fp8 scaling. Each step is of slightly lower quality, but we are taking them a lot faster, coming out net ahead. Anyone who does not have fp8 (e.g. using a GPU without it) can simply leave out the `--fp8` flag to train in bfloat16. This will work just fine but it will produce a slightly stronger model than GPT-2 because of the fp8 -> bf16 precision upgrade. It's possible that one can further tune which layers to include in the fp8 conversion and that e.g. some of the smaller matmuls should be just kept in bf16 etc. Previous record was 3.04 hours, so 2.91 hours is `(3.04 - 2.91)/3.04*100` ~= 4.3% speed improvement. + +## Run 3 + +Achieved Feb 5 2026 on commit `2c062aa`. Launch command: + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=26 \ + --run="d26_feb4_double_batch_ratio8.25" \ + --model-tag="d26_feb4_double_batch_ratio8.25" \ + --device-batch-size=16 \ + --total-batch-size=1048576 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=8.25 \ + --fp8 +``` + +Result: + +``` +core_metric 0.26024 +step 7226 +total_training_time 9922 +Minimum validation bpb: 0.74645 +``` + +The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail. From 96522798f12007341176c19f57d98c4dd76b7a68 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 20:27:07 +0000 Subject: [PATCH 24/55] docs docs docs --- README.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 182d273..1894ac8 100644 --- a/README.md +++ b/README.md @@ -3,16 +3,13 @@ ![nanochat logo](dev/nanochat.png) ![scaling laws](dev/scaling_laws_jan26.png) -nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$50,000 to train in 2019) for only $73 (3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. +nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $72 (~3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$20. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way. For questions about the repo, I recommend either using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, or use the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or come by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord. -## Updates +## Time-to-GPT-2 Leaderboard -- (Jan 31 2026) Major revamp of all scripts/README ongoing, deleting midtraining stage, might be a bit messy briefly... -- (Jan 30 2026) With all the latest improvements we're able to train GPT-2 grade LLM in about $73. The [runs/speedrun.sh](runs/speedrun.sh) script will become the reference way to train GPT-2 grade model and talk to it. - -## Leaderboard +Presently, the main focus of development is on tuning the pretraining stage, which takes the most amount of compute. Inspired by the modded-nanogpt repo and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun", which is the wall-clock time required to train a nanochat model to GPT-2 grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train GPT-2 grade model and talk to it. The current leaderboard looks as follows: | # | time | val_bpb | CORE | Description | Date | Commit | Contributors | |---|-------------|---------|------|-------------|------|--------|--------------| @@ -21,7 +18,7 @@ For questions about the repo, I recommend either using [DeepWiki](https://deepwi | 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy | | 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | -The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $50,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). +The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard. @@ -29,7 +26,7 @@ See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret a ### Reproduce and talk to GPT-2 -The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Currently, at ~$24/hour for these nodes, pretraining a GPT-2 grade model takes approximately 3 hours and will set you back about $75. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: +The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script: ```bash bash runs/speedrun.sh @@ -70,7 +67,13 @@ OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train --save-every=-1 \ ``` -This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. +This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for: + +1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`. +2. `core_metric` (the DCLM CORE socre) +3. VRAM utilization, `train/mfu` (Model FLOPS utilization), `train/tok_per_sec` (training throughput) + +See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044). The important thing to note is that nanochat is written and configured around one single dial of complexity - the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this, they are simply asking for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you achieve the nanochat miniseries of compute optimal models at various sizes. GPT-2 capability model (which is of most interest at the moment) happens to be somewhere around d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth. @@ -80,12 +83,13 @@ The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of runni ## Guides -I've published a number of guides that might contain helpful information: +I've published a number of guides that might contain helpful information, most recent to least recent: -- [Oct 13 2025 original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master. +- [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481) - [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models. -- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage. - To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164). +- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage. +- [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master. ## File structure From e527521a3fd91e8f3a2016b10db21e5742aa41fe Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 5 Feb 2026 22:21:03 +0000 Subject: [PATCH 25/55] briefly mention batch ramp experimentation too, too weak to merge in my few attempts --- dev/LOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index 6ce4173..dec2c06 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -48,6 +48,10 @@ Also refactored model initialization to use `build_model_meta(depth)` helper and - [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) - [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971) +### One more thing (batch size ramp) + +Tried batch size ramping. The simplest implementation I could think of "tricks" the existing training loop by slicing each micro-batch into smaller pieces and calling optimizer.step() more frequently early in training (1/8 → 1/4 → 1/2 → full batch over the first x% of training, with sqrt LR scaling). Also required a torch.compile warmup phase to pre-compile all slice sizes and avoid recompilation spikes during training. While the idea is sound and small gains were observed, they weren't sufficient to justify the code complexity introduced (conditional slicing logic, warmup with state save/restore, etc.). Not merged for now. + --- ## 2026-02-05: SwiGLU Activation (Negative Result) From 685271dc8db21afe02ecb9a9fa9785b50eac2421 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 6 Feb 2026 19:21:27 +0000 Subject: [PATCH 26/55] new optimal ratio for d26 training --- runs/speedrun.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index c423ba6..62466c7 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,7 +70,7 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 From aeff095e97a3721202e188b8348948536dc90a83 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 6 Feb 2026 19:22:28 +0000 Subject: [PATCH 27/55] better comments/flow on all the hyperparameter transfer stuff, and change the WD scaling from my empirical 1/d^2 to a bit more principled version based on Tepoch. All of that theory is based on AdamW and could be suboptimal for Muon --- scripts/base_train.py | 163 ++++++++++++++++++++++-------------------- 1 file changed, 87 insertions(+), 76 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index a3774e6..ccf35e6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -237,11 +237,9 @@ orig_model = model # original, uncompiled model, for saving raw model state_dict model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- -# Determine the optimization horizon based on the model size -# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). -# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params +# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. -# Get the parameter counts of the model +# Get the parameter counts of our model param_counts = model.num_scaling_params() print0(f"Parameter counts:") for key, value in param_counts.items(): @@ -250,23 +248,80 @@ num_params = param_counts['total'] num_flops_per_token = model.estimate_flops() print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") -# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026) -get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head'] +# 1) Use scaling laws to determine the optimal training horizon in tokens +# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis). +# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params +def get_scaling_params(m): + # As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026) + params_counts = m.num_scaling_params() + scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head'] + return scaling_params num_scaling_params = get_scaling_params(model) -target_tokens = int(args.target_param_data_ratio * num_scaling_params) +target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train -# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 -total_batch_size = args.total_batch_size +# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style) +d12_ref = build_model_meta(12) # creates the model on meta device +D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically) +B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically) + +# 2) Now that we have the token horizon, we can calculate the optimal batch size +# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738 +# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x. +total_batch_size = args.total_batch_size # user-provided override is possible if total_batch_size == -1: - d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens - d12_num_scaling_params = get_scaling_params(d12_ref) - D_REF = args.target_param_data_ratio * d12_num_scaling_params - B_REF = 2**19 batch_size_ratio = target_tokens / D_REF - total_batch_size = 2 ** round(math.log2(B_REF * batch_size_ratio ** 0.383)) # also clamp to power of 2 + predicted_batch_size = B_REF * batch_size_ratio ** 0.383 + total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens") -# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order) +# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates) +batch_lr_scale = 1.0 +batch_ratio = total_batch_size / B_REF # B/B_ref +if batch_ratio != 1.0: + # SGD: linear scaling with batch size is standard (not used in nanochat) + # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref) + # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!) + batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref) + print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})") + +# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling +# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698 +# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant. +# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need: +# λ = λ_ref · √(B/B_ref) · (D_ref/D) +# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too. +weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens) +if weight_decay_scaled != args.weight_decay: + print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") + +# ----------------------------------------------------------------------------- +# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +optimizer = model.setup_optimizer( + # AdamW hyperparameters + unembedding_lr=args.unembedding_lr * batch_lr_scale, + embedding_lr=args.embedding_lr * batch_lr_scale, + scalar_lr=args.scalar_lr * batch_lr_scale, + adam_betas=(args.adam_beta1, args.adam_beta2), + # Muon hyperparameters + matrix_lr=args.matrix_lr * batch_lr_scale, + weight_decay=weight_decay_scaled, +) + +if resuming: + optimizer.load_state_dict(optimizer_data) + del optimizer_data + +# ----------------------------------------------------------------------------- +# Initialize the DataLoaders for train/val +dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] +train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) +build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data + +# ----------------------------------------------------------------------------- +# Calculate the number of iterations we will train for and set up the various schedulers + +# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order) assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 if args.num_iterations > 0: # Override num_iterations to a specific value if given @@ -282,65 +337,12 @@ elif args.target_param_data_ratio > 0: print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") else: raise ValueError("No training horizon specified") -total_tokens = total_batch_size * num_iterations +total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") -# ----------------------------------------------------------------------------- -# Optimizer / data / training length related hyperparameters -# figure out the needed gradient accumulation to reach the desired total batch size -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") - -# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19) -batch_lr_scale = 1.0 -reference_batch_size = 2**19 -batch_ratio = total_batch_size / reference_batch_size -if batch_ratio != 1.0: - # SGD: linear scaling with batch size is standard (not used in nanochat) - # AdamW: sqrt scaling is standard - # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer - batch_lr_scale = batch_ratio ** 0.5 - print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {reference_batch_size:,})") - -# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio) -weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 -if args.depth != 12: - print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") - -# ----------------------------------------------------------------------------- -# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -adam_betas = (args.adam_beta1, args.adam_beta2) -optimizer = model.setup_optimizer( - unembedding_lr=args.unembedding_lr * batch_lr_scale, - embedding_lr=args.embedding_lr * batch_lr_scale, - matrix_lr=args.matrix_lr * batch_lr_scale, - weight_decay=weight_decay_scaled, - adam_betas=adam_betas, - scalar_lr=args.scalar_lr * batch_lr_scale, -) - -if resuming: - optimizer.load_state_dict(optimizer_data) - del optimizer_data - -# ----------------------------------------------------------------------------- -# Initialize the DataLoaders for train/val -dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] -train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) -build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) -x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data - -# ----------------------------------------------------------------------------- -# Set up hyperparameter schedulers - -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) def get_lr_multiplier(it): warmup_iters = round(args.warmup_ratio * num_iterations) warmdown_iters = round(args.warmdown_ratio * num_iterations) @@ -352,19 +354,20 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer +# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps) def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum -# Weight decay scheduler for Muon optimizer (linear to zero over the course of training) +# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training) def get_weight_decay(it): return weight_decay_scaled * (1 - it / num_iterations) # ----------------------------------------------------------------------------- -# Loop state (variables updated by the training loop) +# Training loop +# Loop state (variables updated by the training loop) if not resuming: step = 0 val_bpb = None # will be set if eval_every > 0 @@ -379,8 +382,16 @@ else: smooth_train_loss = loop_state["smooth_train_loss"] total_training_time = loop_state["total_training_time"] -# ----------------------------------------------------------------------------- -# Training loop +# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + +# Go! while True: last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end flops_so_far = num_flops_per_token * total_batch_size * step From ff46300720e9ac5a5cd5d90f0d8cd3ccc20a76e2 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 8 Feb 2026 17:54:12 +0000 Subject: [PATCH 28/55] tune miniseries just a bit, fairly cosmetic, keep to even depths where the math works out nicely in model sizing --- runs/miniseries.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/runs/miniseries.sh b/runs/miniseries.sh index c42544e..e57ee16 100644 --- a/runs/miniseries.sh +++ b/runs/miniseries.sh @@ -28,7 +28,7 @@ fi # Series name: from arg, env var, or default to today's date (e.g., jan11) SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}" # Depths to train (the "miniseries") -DEPTHS=(10 11 12 13 14 15 16 17 18 19 20) +DEPTHS=(12 14 16 18 20 22 24 26) # Hardware NPROC_PER_NODE="${NPROC_PER_NODE:-8}" # Logging @@ -57,8 +57,13 @@ for d in "${DEPTHS[@]}"; do TAG="${SERIES_NAME}_miniseries_d${d}" START_TIME=$(date +%s) - # Train the model with natural horizon (target_param_data_ratio default) - # No --target-flops, let it use the default ratio from base_train + # For depths >= 22, use smaller device batch size to avoid OOM + if [ $d -ge 22 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=16" + else + DEVICE_BATCH_SIZE_ARG="--device-batch-size=32" + fi + torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ --run="${WANDB_RUN}_d${d}" \ @@ -67,6 +72,7 @@ for d in "${DEPTHS[@]}"; do --core-metric-max-per-task=-1 \ --sample-every=-1 \ --save-every=-1 \ + $DEVICE_BATCH_SIZE_ARG \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) From 1ec0a347792f337bd38a93b15b79927466d0540a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 8 Feb 2026 18:26:34 +0000 Subject: [PATCH 29/55] at 28 and above we start to need batch size 8 --- runs/miniseries.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/runs/miniseries.sh b/runs/miniseries.sh index e57ee16..01c4459 100644 --- a/runs/miniseries.sh +++ b/runs/miniseries.sh @@ -57,8 +57,10 @@ for d in "${DEPTHS[@]}"; do TAG="${SERIES_NAME}_miniseries_d${d}" START_TIME=$(date +%s) - # For depths >= 22, use smaller device batch size to avoid OOM - if [ $d -ge 22 ]; then + # Reduce --device-batch-size to avoid OOM at larger depths + if [ $d -ge 28 ]; then + DEVICE_BATCH_SIZE_ARG="--device-batch-size=8" + elif [ $d -ge 20 ]; then DEVICE_BATCH_SIZE_ARG="--device-batch-size=16" else DEVICE_BATCH_SIZE_ARG="--device-batch-size=32" From e569b59f92aea06bf8fc1c48489b3cc2e57189f4 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 10 Feb 2026 18:46:39 +0000 Subject: [PATCH 30/55] delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, the performance is not only ~identical but actually runs 3% faster. despite of it being significantly simpler and much less code. i don't fully understand why/how atm --- nanochat/fp8.py | 272 ++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 - scripts/base_train.py | 4 +- uv.lock | 11 -- 4 files changed, 275 insertions(+), 13 deletions(-) create mode 100644 nanochat/fp8.py diff --git a/nanochat/fp8.py b/nanochat/fp8.py new file mode 100644 index 0000000..9d9e9c3 --- /dev/null +++ b/nanochat/fp8.py @@ -0,0 +1,272 @@ +"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only. + +Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines. +We only need the "tensorwise" recipe (one scalar scale per tensor), not the full +generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor +subclass dispatch tables, etc.) + +How FP8 training works +====================== +A standard Linear layer does one matmul in forward and two in backward: + forward: output = input @ weight.T + backward: grad_input = grad_output @ weight + grad_weight= grad_output.T @ input + +FP8 training wraps each of these three matmuls with: + 1. Compute scale = FP8_MAX / max(|tensor|) for each operand + 2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8) + 3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16) + 4. Dequantize: _scaled_mm handles this internally using the inverse scales + +The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins. +torchao is just orchestration around these primitives. We can call them directly. + +FP8 dtype choice +================ +There are two FP8 formats. We use both, following the standard convention: + - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448] + Higher precision (more mantissa bits), used for input and weight. + - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344] + Wider range (more exponent bits), used for gradients which can be large. + +torch._scaled_mm layout requirements +===================================== +The cuBLAS FP8 kernel requires specific memory layouts: + - First argument (A): must be row-major (contiguous) + - Second argument (B): must be column-major (B.t().contiguous().t()) +If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is +already column-major — no copy needed. Otherwise we use _to_col_major(). + +How this differs from torchao's approach +======================================== +torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass +of torch.Tensor that bundles FP8 data + scale + metadata. It implements +__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t, +reshape, clone, ...) and handles it in FP8-aware fashion. When you call + output = input @ weight.T +the @ operator dispatches to aten.mm, which gets intercepted and routed to +torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need +a handler for every tensor operation that might touch an FP8 tensor. + +We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes +full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns +full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one +opaque node rather than trying to trace inside. + +The trade-off is in how torch.compile sees the two approaches: + - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and + sees every individual op (amax, scale, cast, _scaled_mm) as separate graph + nodes. Inductor can fuse these with surrounding operations (e.g. fuse the + amax computation with the preceding layer's activation function). + - ours: compile sees a single opaque call. It can optimize everything around + the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary. + +Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical. +The difference is only in the "glue" ops (amax, scale, cast) which are tiny +compared to the matmul. In practice this means our version is slightly faster +(less compilation overhead, no tensor subclass dispatch cost) but can produce +subtly different floating-point rounding paths under torch.compile, since Inductor +generates a different graph. Numerics are bitwise identical in eager mode. +""" + +import torch +import torch.nn as nn + +# Avoid division by zero when computing scale from an all-zeros tensor +EPS = 1e-12 + + +@torch.no_grad() +def _to_fp8(x, fp8_dtype): + """Dynamically quantize a tensor to FP8 using tensorwise scaling. + + "Tensorwise" means one scalar scale for the entire tensor (as opposed to + "rowwise" which computes a separate scale per row). Tensorwise is faster + because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel. + + Returns (fp8_data, inverse_scale) for use with torch._scaled_mm. + """ + fp8_max = torch.finfo(fp8_dtype).max + # Compute the max absolute value across the entire tensor + amax = x.float().abs().max() + # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to + # ensure consistent numerics between torch.compile and eager mode. + # (torchao does the same upcast — without it, compile/eager can diverge) + scale = fp8_max / amax.double().clamp(min=EPS) + scale = scale.float() + # Quantize: scale into FP8 range, saturate (clamp prevents overflow when + # casting — PyTorch's default is to wrap, not saturate), then cast to FP8 + x_scaled = x.float() * scale + x_clamped = x_scaled.clamp(-fp8_max, fp8_max) + x_fp8 = x_clamped.to(fp8_dtype) + # _scaled_mm expects the *inverse* of our scale (it multiplies by this to + # convert FP8 values back to the original range during the matmul) + inv_scale = scale.reciprocal() + return x_fp8, inv_scale + + +def _to_col_major(x): + """Rearrange a 2D tensor's memory to column-major layout. + + torch._scaled_mm requires its second operand in column-major layout. + The trick: transpose -> contiguous (forces a copy in transposed order) + -> transpose back. The result has the same logical shape but column-major + strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1). + """ + return x.t().contiguous().t() + + +# allow_in_graph tells torch.compile to treat this as an opaque operation — +# dynamo won't try to decompose it into smaller ops. See the module docstring +# for how this differs from torchao's tensor subclass approach. +@torch._dynamo.allow_in_graph +class _Float8Matmul(torch.autograd.Function): + """Custom autograd for the three FP8 GEMMs of a Linear layer. + + The forward saves input and weight in their original precision for the + backward pass. Each GEMM independently re-quantizes its operands to FP8. + (We don't reuse the forward's FP8 tensors in backward — the backward might + want different precision, and saving FP8 would lose information.) + """ + + @staticmethod + def forward(ctx, input_2d, weight): + ctx.save_for_backward(input_2d, weight) + + # Quantize both operands to e4m3 (higher precision format) + input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + + # output = input @ weight.T + # input_fp8 is [B, K] contiguous = row-major (good for first arg) + # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with + # strides (1, K) = column-major (good for second arg, no copy needed!) + output = torch._scaled_mm( + input_fp8, + weight_fp8.t(), + scale_a=input_inv, + scale_b=weight_inv, + out_dtype=input_2d.dtype, + # use_fast_accum=True accumulates the dot products in lower precision. + # Slightly less accurate but measurably faster. Standard practice for + # the forward pass; we use False in backward for more precise gradients. + use_fast_accum=True, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + input_2d, weight = ctx.saved_tensors + + # === GEMM 1: grad_input = grad_output @ weight === + # Shapes: [B, N] @ [N, K] -> [B, K] + # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) + go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) + w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) + # go_fp8 is [B, N] contiguous = row-major, good for first arg + # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg + w_col = _to_col_major(w_fp8) + grad_input = torch._scaled_mm( + go_fp8, + w_col, + scale_a=go_inv, + scale_b=w_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + # === GEMM 2: grad_weight = grad_output.T @ input === + # Shapes: [N, B] @ [B, K] -> [N, K] + go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) + in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # Transposing gives column-major, but first arg needs row-major, + # so we must call .contiguous() to physically rearrange the memory. + go_T = go_fp8_2.t().contiguous() # [N, B] row-major + in_col = _to_col_major(in_fp8) # [B, K] column-major + grad_weight = torch._scaled_mm( + go_T, + in_col, + scale_a=go_inv_2, + scale_b=in_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + return grad_input, grad_weight + + +class Float8Linear(nn.Linear): + """Drop-in nn.Linear replacement that does FP8 compute. + + Weights and biases remain in their original precision (e.g. fp32/bf16). + Only the matmul is performed in FP8 via the _Float8Matmul autograd function. + """ + + def forward(self, input): + # Replicate the autocast behavior of F.linear — when autocast is active, + # we need to manually cast input to the autocast dtype (e.g. bf16), + # since we bypass F.linear's built-in autocast handling. + if torch.is_autocast_enabled(): + input = input.to(torch.get_autocast_gpu_dtype()) + # _scaled_mm only works on 2D tensors, so flatten batch dimensions + orig_shape = input.shape + input_2d = input.reshape(-1, orig_shape[-1]) + output = _Float8Matmul.apply(input_2d, self.weight) + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + if self.bias is not None: + output = output + self.bias.to(output.dtype) + return output + + @classmethod + def from_float(cls, mod): + """Create Float8Linear from nn.Linear, sharing the same weight and bias. + + Uses meta device to avoid allocating a temporary weight tensor — we + create the module shell on meta (shapes/dtypes only, no memory), then + point .weight and .bias to the original module's parameters. + """ + with torch.device("meta"): + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + return new_mod + + +class Float8LinearConfig: + """Minimal config matching torchao's API. Only tensorwise recipe is supported.""" + + @staticmethod + def from_recipe_name(recipe_name): + if recipe_name != "tensorwise": + raise ValueError( + f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. " + f"Rowwise/axiswise recipes require the full torchao library." + ) + return Float8LinearConfig() + + +def convert_to_float8_training(module, *, config=None, module_filter_fn=None): + """Replace nn.Linear layers with Float8Linear throughout a module. + + Walks the module tree in post-order (children before parents) and swaps + each nn.Linear that passes the optional filter. The new Float8Linear shares + the original weight and bias tensors — no copies, no extra memory. + + Args: + module: Root module to convert. + config: Float8LinearConfig (accepted for API compat, only tensorwise supported). + module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears + are converted. Common use: skip layers with dims not divisible by 16 + (hardware requirement for FP8 matmuls on H100). + """ + def _convert(mod, prefix=""): + for name, child in mod.named_children(): + fqn = f"{prefix}.{name}" if prefix else name + _convert(child, fqn) + if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear): + if module_filter_fn is None or module_filter_fn(child, fqn): + setattr(mod, name, Float8Linear.from_float(child)) + + _convert(module) + return module diff --git a/pyproject.toml b/pyproject.toml index bcb674d..8b6fd95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "tiktoken>=0.11.0", "tokenizers>=0.22.0", "torch==2.9.1", - "torchao==0.15.0", "transformers>=4.57.3", "uvicorn>=0.36.0", "wandb>=0.21.3", diff --git a/scripts/base_train.py b/scripts/base_train.py index ccf35e6..ee53098 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -165,7 +165,9 @@ if args.fp8: if device_type != "cuda": print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") else: - from torchao.float8 import Float8LinearConfig, convert_to_float8_training + # our custom fp8 is simpler than torchao, written for exact API compatibility + from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training + # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) diff --git a/uv.lock b/uv.lock index e5fc97f..bbc9519 100644 --- a/uv.lock +++ b/uv.lock @@ -1509,7 +1509,6 @@ dependencies = [ { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "torchao" }, { name = "transformers" }, { name = "uvicorn" }, { name = "wandb" }, @@ -1549,7 +1548,6 @@ requires-dist = [ { name = "torch", specifier = "==2.9.1" }, { name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, - { name = "torchao", specifier = "==0.15.0" }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, @@ -3184,15 +3182,6 @@ wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, ] -[[package]] -name = "torchao" -version = "0.15.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" }, - { url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" }, -] - [[package]] name = "tornado" version = "6.5.4" From 2f096867244e3d00a50284d1be05fa3f5dcfb84b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 10 Feb 2026 23:35:00 +0000 Subject: [PATCH 31/55] clarify that this is bf16 mfu we're talking about --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index ee53098..996b2ba 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -531,7 +531,7 @@ while True: else: eta_str = "" epoch = dataloader_state_dict["epoch"] - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { "step": step, From d9678ff0f9c5d9967512adce23cb60ea0a5cd3f3 Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 14:31:54 +0000 Subject: [PATCH 32/55] Save FP8 tensors in autograd ctx instead of full-precision inputs Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics. --- nanochat/fp8.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 9d9e9c3..8649760 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -123,19 +123,16 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward saves input and weight in their original precision for the - backward pass. Each GEMM independently re-quantizes its operands to FP8. - (We don't reuse the forward's FP8 tensors in backward — the backward might - want different precision, and saving FP8 would lose information.) + The forward quantizes input and weight to FP8 and saves + the quantized tensors + scales for backward. """ @staticmethod def forward(ctx, input_2d, weight): - ctx.save_for_backward(input_2d, weight) - # Quantize both operands to e4m3 (higher precision format) input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv) # output = input @ weight.T # input_fp8 is [B, K] contiguous = row-major (good for first arg) @@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - input_2d, weight = ctx.saved_tensors + in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors # === GEMM 1: grad_input = grad_output @ weight === # Shapes: [B, N] @ [N, K] -> [B, K] # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) - w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) # go_fp8 is [B, N] contiguous = row-major, good for first arg # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg w_col = _to_col_major(w_fp8) @@ -178,7 +174,6 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. From 124f49be98e53bf734e2918dc58a580dbf31a80c Mon Sep 17 00:00:00 2001 From: Alan Date: Sun, 15 Feb 2026 15:41:33 +0000 Subject: [PATCH 33/55] Removed redundant qunatization of gradients --- nanochat/fp8.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 8649760..3e88285 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] - go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. - go_T = go_fp8_2.t().contiguous() # [N, B] row-major + go_T = go_fp8.t().contiguous() # [N, B] row-major in_col = _to_col_major(in_fp8) # [B, K] column-major grad_weight = torch._scaled_mm( go_T, in_col, - scale_a=go_inv_2, + scale_a=go_inv, scale_b=in_inv, out_dtype=grad_output.dtype, use_fast_accum=False, From 788dadeb88282508283d0e152bf32af7d72a20e0 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 14:41:53 +0000 Subject: [PATCH 34/55] a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps --- nanochat/checkpoint_manager.py | 19 ++++ scripts/base_train.py | 1 + scripts/chat_sft.py | 184 +++++++++++++++++++++++++-------- 3 files changed, 159 insertions(+), 45 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 5a95fbf..e24533a 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs): base_dir = get_base_dir() checkpoints_dir = os.path.join(base_dir, model_dir) return load_model_from_dir(checkpoints_dir, *args, **kwargs) + +def load_optimizer_state(source, device, rank, model_tag=None, step=None): + """Load just the optimizer shard for a given rank, without re-loading the model.""" + model_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[source] + base_dir = get_base_dir() + checkpoints_dir = os.path.join(base_dir, model_dir) + if model_tag is None: + model_tag = find_largest_model(checkpoints_dir) + checkpoint_dir = os.path.join(checkpoints_dir, model_tag) + if step is None: + step = find_last_step(checkpoint_dir) + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + log0(f"Loading optimizer state from {optimizer_path}") + optimizer_data = torch.load(optimizer_path, map_location=device) + return optimizer_data diff --git a/scripts/base_train.py b/scripts/base_train.py index 996b2ba..bb76e90 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -468,6 +468,7 @@ while True: "user_config": user_config, # inputs to the training script "device_batch_size": args.device_batch_size, "max_seq_len": args.max_seq_len, + "total_batch_size": total_batch_size, "dataloader_state_dict": dataloader_state_dict, "loop_state": { # all loop state (other than step) so that we can resume training "min_val_bpb": min_val_bpb, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f06..edac3d8 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -9,6 +9,7 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ +import gc import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" @@ -16,12 +17,14 @@ import time import wandb import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model import torch.distributed as dist +from nanochat.flash_attention import HAS_FA3 +from nanochat.engine import Engine +from scripts.chat_eval import run_chat_eval from tasks.common import TaskMixture from tasks.gsm8k import GSM8K @@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") +parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") -# Batch sizes -parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") -# Optimization -parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") +# Batch sizes (default: inherit from pretrained checkpoint) +parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") +parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") +parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") +# Optimization (default: inherit from pretrained checkpoint) +parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") +parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") # Evaluation -parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -# Output -parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") +parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") +parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") +parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -66,20 +72,48 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 +if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") +else: + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) +# Flash Attention status +if not HAS_FA3: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") + # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -pretrain_batch_size = meta.get("device_batch_size", None) -if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") + +# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) +pretrain_user_config = meta.get("user_config", {}) +for name, fallback, source in [ + ("max_seq_len", 2048, meta), + ("device_batch_size", 32, meta), + ("total_batch_size", 524288, meta), + ("embedding_lr", 0.3, pretrain_user_config), + ("unembedding_lr", 0.004, pretrain_user_config), + ("matrix_lr", 0.02, pretrain_user_config), +]: + arg_val = getattr(args, name) + pretrain_val = source.get(name) + if arg_val is None: + resolved = pretrain_val if pretrain_val is not None else fallback + setattr(args, name, resolved) + print0(f"Inherited {name}={resolved} from pretrained checkpoint") + elif pretrain_val is not None and arg_val != pretrain_val: + print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + else: + print0(f"Using {name}={arg_val}") + orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) +# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) + +# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +base_dir = get_base_dir() +if args.load_optimizer: + optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) + optimizer.load_state_dict(optimizer_data) + del optimizer_data + print0("Loaded optimizer state from pretrained checkpoint") + # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac group["initial_lr"] = group["lr"] # SFT data mixture and DataLoader -base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") train_dataset = TaskMixture([ SmolTalk(split="train"), # 460K rows of general conversations @@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train") build_val_loader = lambda: sft_data_generator_bos_bestfit("val") progress = 0 # will go from 0 to 1 over the course of the epoch -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) +# Same shape as base_train but uses progress (0→1) instead of absolute step counts, +# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). def get_lr_multiplier(progress): - # first 80% of training: no decay, then linearly ramp down to 0. - return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + if progress < args.warmup_ratio: + return (progress + 1e-8) / args.warmup_ratio + elif progress <= 1.0 - args.warmdown_ratio: + return 1.0 + else: + decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio + return (1 - decay) * 1.0 + decay * args.final_lr_frac # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -282,8 +332,44 @@ while True: }) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step and not args.dry_run: + # once in a while: estimate the ChatCORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + chatcore_results = {} + if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): + model.eval() + engine = Engine(orig_model, tokenizer) + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] + categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'} + baseline_accuracies = { + 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, + 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, + } + task_results = {} + for task_name in all_tasks: + limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample + max_problems = None if limit < 0 else limit # -1 means no limit + with autocast_ctx: + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) + task_results[task_name] = acc + print0(f" {task_name}: {100*acc:.2f}%") + # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) + def centered_mean(tasks): + return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) + chatcore = centered_mean(all_tasks) + chatcore_cat = centered_mean(categorical_tasks) + print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "chatcore_metric": chatcore, + "chatcore_cat": chatcore_cat, + **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, + }) + model.train() + + # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) + if last_step: output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) save_checkpoint( @@ -304,7 +390,8 @@ while True: "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script - } + }, + rank=ddp_rank, ) if last_step: @@ -346,8 +433,7 @@ while True: pct_done = 100 * progress tok_per_sec = int(args.total_batch_size / dt) flops_per_sec = num_flops_per_token * args.total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") @@ -364,24 +450,32 @@ while True: "train/epoch": current_epoch, }) + # The garbage collector spends ~500ms scanning for cycles quite frequently. + # We manually manage it to avoid these pauses during training. + if step == 1: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # freeze all currently surviving objects and exclude them from GC + gc.disable() # disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... + gc.collect() # manually collect, just to be safe for very long runs + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report -if not args.dry_run: - from nanochat.report import get_report - get_report().log(section="SFT", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } - ]) +from nanochat.report import get_report +get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } +]) # cleanup wandb_run.finish() # wandb run finish From 8180e1d8c1c3e561b751dcfec54a74b3122c0db5 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 20:23:04 +0000 Subject: [PATCH 35/55] tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft --- nanochat/checkpoint_manager.py | 3 +++ scripts/chat_sft.py | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e24533a..f71524e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): if step is None: step = find_last_step(checkpoint_dir) optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + if not os.path.exists(optimizer_path): + log0(f"Optimizer checkpoint not found: {optimizer_path}") + return None log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index edac3d8..a783ed2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") +parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes (default: inherit from pretrained checkpoint) @@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") +# Data mixture +parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") +parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the +# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and +# restore our fresh SFT LRs after loading. base_dir = get_base_dir() if args.load_optimizer: optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - optimizer.load_state_dict(optimizer_data) - del optimizer_data - print0("Loaded optimizer state from pretrained checkpoint") + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: @@ -146,16 +158,17 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ +train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - GSM8K(subset="main", split="train"), # 2 epochs of GSM8K CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +] +train_dataset = TaskMixture(train_tasks) +print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios From 4a6e47b0c68aa062c62fa859aed2e8dd0d59d684 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Feb 2026 15:44:54 +0000 Subject: [PATCH 36/55] update dev log with recent --- dev/LOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index dec2c06..c0d35e4 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,38 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data Mixture Experiment (negative) + +Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: + +- d26 (GPT-2): CORE 0.2602 → 0.2549 +- d18: CORE 0.199 → 0.192 + +This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + +## 2026-02-16: SFT Script Upgrades + +Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps. + +Tuning: + +- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much. +- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8. +- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results. +- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though. + +Quality of life, footguns, minor fixes: + +- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata. +- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining. +- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb. +- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value. +- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save. + +--- + ## 2026-02-05: Auto Batch Size Scaling ### Background From 4800c62f6ed598accb950a8b715a6cab8a264e1e Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 01:03:46 +0100 Subject: [PATCH 37/55] Fix MockModel's device definition (#535) * fix MockModel's device definition * cleanup --- tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..784ffcb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -31,7 +31,7 @@ class MockModel: def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() - self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device From 0a23f87643945410eb7c0e33951b5acfba05257c Mon Sep 17 00:00:00 2001 From: George Shakan <43767775+georgeshakan@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:42:11 -0500 Subject: [PATCH 38/55] Fix bug in setting precision (#538) --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..2dd0792 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From 77f8fb83037d4bb294fb97f987f27c98526c1d96 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 14:41:53 +0000 Subject: [PATCH 39/55] a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps --- nanochat/checkpoint_manager.py | 19 ++++ scripts/base_train.py | 1 + scripts/chat_sft.py | 184 +++++++++++++++++++++++++-------- 3 files changed, 159 insertions(+), 45 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 5a95fbf..e24533a 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs): base_dir = get_base_dir() checkpoints_dir = os.path.join(base_dir, model_dir) return load_model_from_dir(checkpoints_dir, *args, **kwargs) + +def load_optimizer_state(source, device, rank, model_tag=None, step=None): + """Load just the optimizer shard for a given rank, without re-loading the model.""" + model_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[source] + base_dir = get_base_dir() + checkpoints_dir = os.path.join(base_dir, model_dir) + if model_tag is None: + model_tag = find_largest_model(checkpoints_dir) + checkpoint_dir = os.path.join(checkpoints_dir, model_tag) + if step is None: + step = find_last_step(checkpoint_dir) + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + log0(f"Loading optimizer state from {optimizer_path}") + optimizer_data = torch.load(optimizer_path, map_location=device) + return optimizer_data diff --git a/scripts/base_train.py b/scripts/base_train.py index 996b2ba..bb76e90 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -468,6 +468,7 @@ while True: "user_config": user_config, # inputs to the training script "device_batch_size": args.device_batch_size, "max_seq_len": args.max_seq_len, + "total_batch_size": total_batch_size, "dataloader_state_dict": dataloader_state_dict, "loop_state": { # all loop state (other than step) so that we can resume training "min_val_bpb": min_val_bpb, diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 4c81f06..edac3d8 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -9,6 +9,7 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 """ +import gc import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" @@ -16,12 +17,14 @@ import time import wandb import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model import torch.distributed as dist +from nanochat.flash_attention import HAS_FA3 +from nanochat.engine import Engine +from scripts.chat_eval import run_chat_eval from tasks.common import TaskMixture from tasks.gsm8k import GSM8K @@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") +parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") -# Batch sizes -parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") -# Optimization -parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") +# Batch sizes (default: inherit from pretrained checkpoint) +parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)") +parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)") +parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)") +# Optimization (default: inherit from pretrained checkpoint) +parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)") +parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)") +parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") # Evaluation -parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -# Output -parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") +parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") +parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") +parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -66,20 +72,48 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 +if device_type == "cuda": + gpu_device_name = torch.cuda.get_device_name(0) + gpu_peak_flops = get_peak_flops(gpu_device_name) + print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") +else: + gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) +# Flash Attention status +if not HAS_FA3: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") + # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -pretrain_batch_size = meta.get("device_batch_size", None) -if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") + +# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override) +pretrain_user_config = meta.get("user_config", {}) +for name, fallback, source in [ + ("max_seq_len", 2048, meta), + ("device_batch_size", 32, meta), + ("total_batch_size", 524288, meta), + ("embedding_lr", 0.3, pretrain_user_config), + ("unembedding_lr", 0.004, pretrain_user_config), + ("matrix_lr", 0.02, pretrain_user_config), +]: + arg_val = getattr(args, name) + pretrain_val = source.get(name) + if arg_val is None: + resolved = pretrain_val if pretrain_val is not None else fallback + setattr(args, name, resolved) + print0(f"Inherited {name}={resolved} from pretrained checkpoint") + elif pretrain_val is not None and arg_val != pretrain_val: + print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}") + else: + print0(f"Using {name}={arg_val}") + orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) +# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) + +# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +base_dir = get_base_dir() +if args.load_optimizer: + optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) + optimizer.load_state_dict(optimizer_data) + del optimizer_data + print0("Loaded optimizer state from pretrained checkpoint") + # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac group["initial_lr"] = group["lr"] # SFT data mixture and DataLoader -base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") train_dataset = TaskMixture([ SmolTalk(split="train"), # 460K rows of general conversations @@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train") build_val_loader = lambda: sft_data_generator_bos_bestfit("val") progress = 0 # will go from 0 to 1 over the course of the epoch -# Learning rate scheduler +# Learning rate schedule (linear warmup, constant, linear warmdown) +# Same shape as base_train but uses progress (0→1) instead of absolute step counts, +# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). def get_lr_multiplier(progress): - # first 80% of training: no decay, then linearly ramp down to 0. - return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + if progress < args.warmup_ratio: + return (progress + 1e-8) / args.warmup_ratio + elif progress <= 1.0 - args.warmdown_ratio: + return 1.0 + else: + decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio + return (1 - decay) * 1.0 + decay * args.final_lr_frac # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -282,8 +332,44 @@ while True: }) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step and not args.dry_run: + # once in a while: estimate the ChatCORE metric (all ranks participate) + # use the original uncompiled model because the inputs keep changing shape + chatcore_results = {} + if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)): + model.eval() + engine = Engine(orig_model, tokenizer) + all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] + categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'} + baseline_accuracies = { + 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25, + 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0, + } + task_results = {} + for task_name in all_tasks: + limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample + max_problems = None if limit < 0 else limit # -1 means no limit + with autocast_ctx: + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) + task_results[task_name] = acc + print0(f" {task_name}: {100*acc:.2f}%") + # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) + def centered_mean(tasks): + return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks) + chatcore = centered_mean(all_tasks) + chatcore_cat = centered_mean(categorical_tasks) + print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") + wandb_run.log({ + "step": step, + "total_training_flops": flops_so_far, + "chatcore_metric": chatcore, + "chatcore_cat": chatcore_cat, + **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, + }) + model.train() + + # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) + if last_step: output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12 checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) save_checkpoint( @@ -304,7 +390,8 @@ while True: "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script - } + }, + rank=ddp_rank, ) if last_step: @@ -346,8 +433,7 @@ while True: pct_done = 100 * progress tok_per_sec = int(args.total_batch_size / dt) flops_per_sec = num_flops_per_token * args.total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") @@ -364,24 +450,32 @@ while True: "train/epoch": current_epoch, }) + # The garbage collector spends ~500ms scanning for cycles quite frequently. + # We manually manage it to avoid these pauses during training. + if step == 1: + gc.collect() # manually collect a lot of garbage from setup + gc.freeze() # freeze all currently surviving objects and exclude them from GC + gc.disable() # disable GC entirely except: + elif step % 5000 == 0: # every 5000 steps... + gc.collect() # manually collect, just to be safe for very long runs + # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report -if not args.dry_run: - from nanochat.report import get_report - get_report().log(section="SFT", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } - ]) +from nanochat.report import get_report +get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } +]) # cleanup wandb_run.finish() # wandb run finish From 1415fb761797f94a4933c1a79f8d1fc2e63b9793 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Feb 2026 20:23:04 +0000 Subject: [PATCH 40/55] tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft --- nanochat/checkpoint_manager.py | 3 +++ scripts/chat_sft.py | 33 +++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e24533a..f71524e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): if step is None: step = find_last_step(checkpoint_dir) optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") + if not os.path.exists(optimizer_path): + log0(f"Optimizer checkpoint not found: {optimizer_path}") + return None log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index edac3d8..a783ed2 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") +parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)") # Training horizon parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes (default: inherit from pretrained checkpoint) @@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)") parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE") parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE") +# Data mixture +parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") +parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device) optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) +# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the +# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and +# restore our fresh SFT LRs after loading. base_dir = get_base_dir() if args.load_optimizer: optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step) - optimizer.load_state_dict(optimizer_data) - del optimizer_data - print0("Loaded optimizer state from pretrained checkpoint") + if optimizer_data is not None: + base_lrs = [group["lr"] for group in optimizer.param_groups] + optimizer.load_state_dict(optimizer_data) + del optimizer_data + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = base_lr + print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)") + else: + print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: @@ -146,16 +158,17 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ +train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - GSM8K(subset="main", split="train"), # 2 epochs of GSM8K CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these + *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +] +train_dataset = TaskMixture(train_tasks) +print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios From f5fe7925ed913fbddbc268043c79f82c354c43de Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Feb 2026 15:44:54 +0000 Subject: [PATCH 41/55] update dev log with recent --- dev/LOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index dec2c06..c0d35e4 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,38 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data Mixture Experiment (negative) + +Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: + +- d26 (GPT-2): CORE 0.2602 → 0.2549 +- d18: CORE 0.199 → 0.192 + +This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + +## 2026-02-16: SFT Script Upgrades + +Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps. + +Tuning: + +- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much. +- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8. +- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results. +- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though. + +Quality of life, footguns, minor fixes: + +- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata. +- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining. +- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb. +- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value. +- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save. + +--- + ## 2026-02-05: Auto Batch Size Scaling ### Background From cac43e851142289d565c2d22fdc9904ee8b62eb1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 18 Feb 2026 01:03:46 +0100 Subject: [PATCH 42/55] Fix MockModel's device definition (#535) * fix MockModel's device definition * cleanup --- tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..784ffcb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -31,7 +31,7 @@ class MockModel: def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() - self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device From ad55575326443db6deda6e19126ebf136c66d8b2 Mon Sep 17 00:00:00 2001 From: George Shakan <43767775+georgeshakan@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:42:11 -0500 Subject: [PATCH 43/55] Fix bug in setting precision (#538) --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 9bcd5dd..2dd0792 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From bac5a35dd74e331ed6012142e0b4e8c0f0af48e8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:17:29 +0000 Subject: [PATCH 44/55] fix minor bug in fp8 application to skip tiny matmuls --- scripts/base_train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index bb76e90..24091b6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -170,20 +170,22 @@ if args.fp8: # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn - # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) + # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: if not isinstance(mod, nn.Linear): return False - # FP8 requires both in_features and out_features divisible by 16 if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: return False + if min(mod.in_features, mod.out_features) < 128: + return False return True fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) - num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) - num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers - print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") + num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = num_linear - num_fp8 + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") # Context manager to temporarily disable FP8 so that model evaluation remains in BF16 @contextmanager From bb5137860e24efa995b60468e7b867206ae9dd5c Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:26:22 +0000 Subject: [PATCH 45/55] fix comment --- runs/speedrun.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 62466c7..c757253 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,7 +69,7 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID -# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12) +# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25) torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 From 48804bff3a487e43ee1e1533b3cfa0aa5ab0028f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 18 Feb 2026 23:45:31 +0000 Subject: [PATCH 46/55] report negative result on fineweb dataset --- dev/LOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index c0d35e4..6ac027c 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,16 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-17: Pretraining Data: FineWeb (negative) + +Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results: + +- d26 (GPT-2): CORE 0.2602 → 0.2241 + +This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score. + +--- + ## 2026-02-17: Pretraining Data Mixture Experiment (negative) Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested: From 2dffdc8cf6953c5dc10f1caf37016e9daa675b09 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 19 Feb 2026 02:53:47 +0000 Subject: [PATCH 47/55] document MoE exploration --- dev/LOG.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/dev/LOG.md b/dev/LOG.md index 6ac027c..0dfaa98 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,53 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-02-19: Mixture of Experts (negative) + +Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability). + +### Implementation + +Follows DeepSeekV3 and using torchtitan as reference: + +- **8 routed experts, top-2 routing** with sigmoid gating (not softmax) +- **1 shared expert** (dense MLP processing all tokens, following DeepSeekV3) +- **Auxiliary-loss-free load balancing** (DeepSeekV3's expert bias nudging) +- **Iso-FLOP sizing**: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP +- **`torch._grouped_mm`** for dispatching tokens to experts in a single kernel (instead of a Python for-loop) +- **3D expert weight tensors** `(num_experts, hidden, dim)` — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized +- **Active parameter counting** for scaling laws (only `top_k + shared` experts, not all 8) + +### What was easy + +- The core MoE forward pass: router, sort tokens by expert, grouped matmul, scatter back. Conceptually clean. +- Shared expert: just an `nn.Linear` MLP that runs on all tokens alongside the routed path. +- 3D expert params + Muon: only required fixing `second_momentum_buffer` shape to preserve leading dims. +- Load balancing: DeepSeekV3's bias nudging is simple and effective (~10 lines). + +### What was hard / ugly + +- **`torch._grouped_mm` quirks**: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error. +- **Token count padding**: torchtitan pads each expert's token count to alignment multiples (8 for bf16) for better grouped_mm throughput. We implemented this with both a pure PyTorch approach and a copy of torchtitan's Triton kernel. Both compiled cleanly (0 graph breaks), but with ~65K tokens across 8 experts, each expert already gets ~8K tokens which is well-aligned. The padding overhead (gather/scatter) actually regressed MFU from 35% to 33%. Reverted. +- **FP8 + MoE**: `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our `Float8Linear`). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see `dev/moe_fp8.md`) but did not implement — would require either depending on `torchao.prototype` (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's `nn.Linear` layers do get converted, but the routed experts (3D `nn.Parameter`) stay in bf16. + +### Results + +- d18: MFU dropped from ~46% to ~35% (the grouped_mm dispatch + token sorting overhead is significant) +- Per-step improvement in validation loss does not compensate for the throughput hit +- Net negative on wall clock time + +### What remains (if revisited) + +- **FP8 for routed experts**: Use `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for weight gradient (avoiding the per-group column-wise Triton kernels). + +What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute. + +### Verdict + +MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths. + +--- + ## 2026-02-17: Pretraining Data: FineWeb (negative) Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results: From c7ba25214276d165eeefca7cb2060587975db189 Mon Sep 17 00:00:00 2001 From: Dipesh Babu <59379458+dipeshbabu@users.noreply.github.com> Date: Fri, 20 Feb 2026 11:03:45 -0500 Subject: [PATCH 48/55] docs: fix typos in experiment log (#547) --- dev/LOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index 0dfaa98..fce90fd 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -749,7 +749,7 @@ See the branch `fp8_attempt_fail` for: ### Open Questions - Why does the custom op approach use more memory than vanilla BF16? -- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Ahmdal's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized. +- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Amdahl's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized. **Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao. @@ -913,7 +913,7 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i - Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay and is hardcoded to 0 weight decay, might try to re-tune this later. **4. Weight decay schedule** -- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, them ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is imlpemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule). +- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is implemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule). ### Weight Decay Scaling Experiments @@ -957,6 +957,6 @@ Muon was changed to use Polar Express, added NorMuon variance reduction, and cau **Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce. -**Observartion:** modded-nanogpt does not appear to clip either right now. +**Observation:** modded-nanogpt does not appear to clip either right now. **Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms. From 83dccc20aeab357ae4bb30d9c2ce938763efa929 Mon Sep 17 00:00:00 2001 From: Anish <12670807+gpu-poor@users.noreply.github.com> Date: Tue, 3 Mar 2026 06:07:47 +0530 Subject: [PATCH 49/55] Restore completion-only loss masking in SFT dataloader (#582) * printing steps count * adding reply only loss for chat * using the mask by render_conversation function of tokeniser * undoing some changes * putting back the comment which got removed accidently, no functionality change --- scripts/chat_sft.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index a783ed2..f31a2d3 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): row_capacity = args.max_seq_len + 1 # +1 for target at last position bos_token = tokenizer.get_bos_token_id() - # Conversation buffer: list of token lists + # Conversation buffer: list of (token_ids, loss_mask) tuples conv_buffer = [] cursor = ddp_rank # Each rank processes different conversations (for fetching) consumed = ddp_rank # Track actual consumption separately from buffering @@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): nonlocal cursor, epoch while len(conv_buffer) < buffer_size: conversation = dataset[cursor] - ids, _ = tokenizer.render_conversation(conversation) - conv_buffer.append(ids) + ids, mask = tokenizer.render_conversation(conversation) + conv_buffer.append((ids, mask)) cursor += ddp_world_size if cursor >= dataset_size: cursor = cursor % dataset_size @@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): while True: rows = [] + mask_rows = [] row_lengths = [] # Track actual content length (excluding padding) for each row for _ in range(args.device_batch_size): row = [] + mask_row = [] padded = False while len(row) < row_capacity: # Ensure buffer has conversations @@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Find largest conversation that fits entirely best_idx = -1 best_len = 0 - for i, conv in enumerate(conv_buffer): + for i, (conv, _) in enumerate(conv_buffer): conv_len = len(conv) if conv_len <= remaining and conv_len > best_len: best_idx = i @@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): if best_idx >= 0: # Found a conversation that fits - use it entirely - conv = conv_buffer.pop(best_idx) + conv, conv_mask = conv_buffer.pop(best_idx) row.extend(conv) + mask_row.extend(conv_mask) consumed += ddp_world_size # Track actual consumption else: # No conversation fits - pad the remainder instead of cropping # This ensures we never discard any tokens content_len = len(row) row.extend([bos_token] * remaining) # Pad with BOS tokens + mask_row.extend([0] * remaining) padded = True break # Row is now full (with padding) @@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): else: row_lengths.append(row_capacity) rows.append(row[:row_capacity]) + mask_rows.append(mask_row[:row_capacity]) # Stopping condition to respect num_iterations, if given it += 1 @@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + # Apply the loss mask from render_conversation (mask=1 for assistant completions, + # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns + # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). + mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) + mask_targets = mask_tensor[:, 1:].to(device=device) + targets[mask_targets == 0] = -1 + # Mask out padding positions in targets (set to -1 = ignore_index) # For each row, positions >= (content_length - 1) in targets should be masked for i, content_len in enumerate(row_lengths): From aba30cb037bfc010b8e85d0db4f2273a1815100d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 2 Mar 2026 18:19:37 +0000 Subject: [PATCH 50/55] tune logit softcap? --- dev/LOG.md | 4 ++++ nanochat/gpt.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dev/LOG.md b/dev/LOG.md index fce90fd..b6e83ef 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,10 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-03-02: SoftCap tuning + +Quick experiment to tune logit softcap on d24 scale. Tried 5..30. 5 was terrible, the rest of them were all about equal with the exception of 20, which was the best. Minor but solid improvement: val loss improved by ~1e-3 (0.716 -> 0.715). Setting as default. + ## 2026-02-19: Mixture of Experts (negative) Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability). diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..74e39fd 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -407,7 +407,7 @@ class GPT(nn.Module): x = norm(x) # Forward the lm_head (compute logits) - softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] + softcap = 20 # smoothly cap the logits to the range [-softcap, softcap] logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory logits = logits[..., :self.config.vocab_size] # slice to remove padding logits = logits.float() # switch to fp32 for logit softcap and loss computation From b07604ebaa6dffa9e21e0d2d7d0a5980301c6364 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 3 Mar 2026 17:24:31 +0000 Subject: [PATCH 51/55] document the legacy fineweb100b dataset and the new climbmix400b dataset --- dev/repackage_data_reference.py | 58 ++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/dev/repackage_data_reference.py b/dev/repackage_data_reference.py index 32980a8..0ec61b4 100644 --- a/dev/repackage_data_reference.py +++ b/dev/repackage_data_reference.py @@ -1,5 +1,5 @@ """ -Repackage the FinewebEdu-100B dataset into shards: +Repackage a given dataset into simple parquet shards: - each shard is ~100MB in size (after zstd compression) - parquets are written with row group size of 1000 @@ -10,6 +10,16 @@ The big deal is that our DataLoader will be able to stream the data and cache it along the way on disk, decreasing the training latency. +Historical context: +Originally, nanochat used the FinewebEdu-100B dataset. +Then we switched to the ClimbMix-400B dataset due to superior performance. +This script documents how both were prepared. + +The outputs are here: + +https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle +https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle + NOTE: This file is meant only as reference/documentation of the dataset preparation and it is not used during the project runtime. """ @@ -20,12 +30,37 @@ from datasets import load_dataset import pyarrow.parquet as pq import pyarrow as pa +# You can change these: +dataset_tag = "climbmix" +upload_to_hf = True + +# Dataset configurations: +if dataset_tag == "fineweb_edu": + dataset_kwargs = { + "path": "HuggingFaceFW/fineweb-edu", + "split": "train", + "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total + } + output_dirname = "fineweb_edu" + data_column_name = "text" + tokenizer = None + upload_tag = "fineweb-edu-100b-shuffle" + +elif dataset_tag == "climbmix": + import tiktoken # the ClimbMix data is stored tokenized with GPT-2 tokenizer + dataset_kwargs = { + "path": "nvidia/Nemotron-ClimbMix", + "split": "train", + } + output_dirname = "climbmix" + data_column_name = "tokens" + tokenizer = tiktoken.encoding_for_model("gpt-2") + upload_tag = "climbmix-400b-shuffle" + +else: + raise ValueError(f"Unknown dataset tag: {dataset_tag}") + # Source dataset -dataset_kwargs = { - "path": "HuggingFaceFW/fineweb-edu", - "split": "train", - "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total -} ds = load_dataset(**dataset_kwargs) # Shuffle to scramble the order @@ -34,7 +69,7 @@ ndocs = len(ds) # total number of documents to process print(f"Total number of documents: {ndocs}") # Repackage into parquet files -output_dir = "/home/ubuntu/.cache/nanochat/base_data" +output_dir = f"/home/ubuntu/.cache/nanochat/base_data_{output_dirname}" os.makedirs(output_dir, exist_ok=True) # Write to parquet files @@ -47,7 +82,8 @@ total_docs_processed = 0 total_time_spent = 0 t0 = time.time() for doc in ds: - text = doc['text'] + data = doc[data_column_name] + text = tokenizer.decode(data) if tokenizer is not None else data shard_docs.append(text) shard_characters += len(text) collected_enough_chars = shard_characters >= chars_per_shard @@ -79,14 +115,12 @@ for doc in ds: shard_index += 1 # Demonstration of how the data was later uploaded to HuggingFace -def upload(): - import os +if upload_to_hf: from huggingface_hub import HfApi token = os.getenv("HF_TOKEN") api = HfApi(token=token) api.upload_large_folder( folder_path=output_dir, - repo_id="karpathy/fineweb-edu-100b-shuffle", + repo_id=f"karpathy/{upload_tag}", repo_type="dataset", ) -# upload() From 324e69c45d3606095adb6b409078647145165454 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 4 Mar 2026 19:47:12 +0000 Subject: [PATCH 52/55] big, breaking change but large upside: swap previous FineWeb-EDU dataset to NVIDIA ClimbMix dataset. Requires people to download the data shards. The upside is that training GPT-2 capablity model now only takes ~2 hours, down from 2.76 hours, so this is a huge win data-wise --- dev/LOG.md | 23 +++++++++++++++++++ nanochat/dataloader.py | 3 ++- nanochat/dataset.py | 50 ++++++++++++++++++++++++++++++++++-------- runs/speedrun.sh | 10 ++++----- scripts/base_train.py | 4 ++-- 5 files changed, 73 insertions(+), 17 deletions(-) diff --git a/dev/LOG.md b/dev/LOG.md index b6e83ef..b4b3757 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,29 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B + +Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction. + +### What is ClimbMix? + +ClimbMix 400B is a curated 400B-token pretraining mixture hosted at `karpathy/climbmix-400b-shuffle` on HuggingFace. It comes form [NVIDIA](https://huggingface.co/datasets/nvidia/Nemotron-ClimbMix). It is a blend of high-quality web text, code, math, and other sources, designed to be a better general-purpose pretraining dataset than FineWeb-EDU alone. + +### What changed + +- **Dataset**: `karpathy/fineweb-edu-100b-shuffle` → `karpathy/climbmix-400b-shuffle` (up to 6543 shards available vs the previous 1823 data shards, allowing for longer training in the future) +- **Data directory**: `base_data/` → `base_data_climbmix/` (clean separation from legacy data) +- **Model depth**: d26 → d24. ClimbMix trains more efficiently, so a smaller model reaches GPT-2 capability +- **Shard count**: Only approx 150 data shards (~7B tokens) are now needed for GPT-2 capability +- **Eval tokens**: doubled from 40 to 80 batches for more stable validation loss estimates +- **Legacy fallback**: added a migration warning in `list_parquet_files()` that detects the old `base_data/` directory and falls back gracefully, so existing users see clear upgrade instructions on `git pull` + +### Context + +This is the sixth attempt at beating FineWeb-EDU on CORE score — the previous five all failed (see entries on 2026-02-17, 2026-02-10, 2026-01-12 below). ClimbMix is the first dataset to convincingly surpass it, and the margin is large enough to also shrink the model from d26 to d24. + +--- + ## 2026-03-02: SoftCap tuning Quick experiment to tune logit softcap on d24 scale. Tried 5..30. 5 was terrible, the rest of them were all about equal with the exception of 20, which was the best. Minor but solid improvement: val loss improved by ~1e-3 (0.716 -> 0.715). Setting as default. diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 125625f..4cb2279 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -32,7 +32,8 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size): """ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - parquet_paths = list_parquet_files() + warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy + parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy) assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?" parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..fffe722 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -20,19 +20,43 @@ from nanochat.common import get_base_dir # The specifics of the current pretraining dataset # The URL on the internet where the data is hosted and downloaded from on demand -BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" -MAX_SHARD = 1822 # the last datashard is shard_01822.parquet +BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main" +MAX_SHARD = 6542 # the last datashard is shard_06542.parquet index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames base_dir = get_base_dir() -DATA_DIR = os.path.join(base_dir, "base_data") -os.makedirs(DATA_DIR, exist_ok=True) +DATA_DIR = os.path.join(base_dir, "base_data_climbmix") # ----------------------------------------------------------------------------- # These functions are useful utilities to other modules, can/should be imported -def list_parquet_files(data_dir=None): +def list_parquet_files(data_dir=None, warn_on_legacy=False): """ Looks into a data dir and returns full paths to all parquet files. """ data_dir = DATA_DIR if data_dir is None else data_dir + + # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B + # This code will eventually be deleted. + if not os.path.exists(data_dir): + if warn_on_legacy: + print() + print("=" * 80) + print(" WARNING: DATASET UPGRADE REQUIRED") + print("=" * 80) + print() + print(f" Could not find: {data_dir}") + print() + print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.") + print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.") + print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:") + print() + print(" python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired") + print(" python -m scripts.tok_train # re-train tokenizer on new ClimbMix data") + print() + print(" For now, falling back to your old FinewebEdu-100B dataset...") + print("=" * 80) + print() + # attempt a fallback to the legacy data directory + data_dir = os.path.join(base_dir, "base_data") + parquet_files = sorted([ f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp') @@ -110,13 +134,21 @@ def download_single_file(index): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards") - parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable") + parser = argparse.ArgumentParser(description="Download pretraining dataset shards") + parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable") parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") args = parser.parse_args() - num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) - ids_to_download = list(range(num)) + # Prepare the output directory + os.makedirs(DATA_DIR, exist_ok=True) + + # The way this works is that the user specifies the number of train shards to download via the -n flag. + # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard. + num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD) + ids_to_download = list(range(num_train_shards)) + ids_to_download.append(MAX_SHARD) # always download the validation shard + + # Download the shards print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") print(f"Target directory: {DATA_DIR}") print() diff --git a/runs/speedrun.sh b/runs/speedrun.sh index c757253..fa50694 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -55,9 +55,9 @@ python -m nanochat.report reset # look at dev/repackage_data_reference.py for details on how this data was prepared python -m nanochat.dataset -n 8 # Immediately also kick off downloading more shards in the background while tokenizer trains -# Approximately 350 shards are needed for 10B tokens of data for pretraining. -# The maximum total number of shards available in the entire dataset is 1822. -python -m nanochat.dataset -n 370 & +# Approximately 150 shards are needed for GPT-2 capability pretraining, add 20 for padding. +# The maximum total number of shards available in the entire dataset is 6542. +python -m nanochat.dataset -n 170 & DATASET_DOWNLOAD_PID=$! # train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data python -m scripts.tok_train @@ -69,8 +69,8 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID -# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN +# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5) +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index 24091b6..9461e88 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -71,7 +71,7 @@ parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR a parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") @@ -533,7 +533,7 @@ while True: eta_str = f" | eta: {eta_seconds/60:.1f}m" else: eta_str = "" - epoch = dataloader_state_dict["epoch"] + epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: log_data = { From 4b4077425b036392b6905026c2f1024f9653d063 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 4 Mar 2026 20:02:07 +0000 Subject: [PATCH 53/55] Document new Leaderboard entry congrats @ddudek for pointing out ClimbMix, time to GPT-2 is now 2.01 hours, down from 2.76 previously --- README.md | 5 +++-- dev/LEADERBOARD.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1894ac8..05c7942 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ![nanochat logo](dev/nanochat.png) ![scaling laws](dev/scaling_laws_jan26.png) -nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $72 (~3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$20. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way. +nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $48 (~2 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$15. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way. For questions about the repo, I recommend either using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, or use the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or come by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord. @@ -17,8 +17,9 @@ Presently, the main focus of development is on tuning the pretraining stage, whi | 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy | | 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy | | 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | +| 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy | -The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72). +The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48). See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard. diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index b8a727f..556ec3c 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -147,3 +147,47 @@ Minimum validation bpb: 0.74645 ``` The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail. + +## Run 4 + +Achived Mar 3 2026 on commit `324e69c`. The big change is the switch from HuggingFace FineWeb-EDU to NVIDIA ClimbMix dataset. `@karpathy` has tried to swap the dataset many times, each time with a negative result (FineWeb, DCLM, Olmo), but ClimbMix produced clear and immediate gains. Credit to `@ddudek` for originally discovering ClimbMix for nanochat and reporting the improvements, which kicked off the followup investigation. + +To reproduce, use the commit above, download at least 150 data shards, train the tokenizer: + +``` +python -m nanochat.dataset -n 150 +python -m scripts.tok_train +``` + +Then kick off the run in the typical way, using a slightly lower than compute optimal ratio of 9.5 (vs compute optimal 10.5), meaning the d24 is slightly undertrained. + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=24 \ + --run="d24-climbmix" \ + --model-tag="d24-climbmix" \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=9.5 \ + --device-batch-size=16 \ + --fp8 +``` + +I ran this command 7 individual times. Because our training is mildly non-deterministic, we get a spread of CORE scores, e.g.: + +``` +0.25373 +0.2584 +0.25489 +0.2568 +0.25732 +0.26765 +0.25119 +``` + +Mean is 0.25714 (higher than the GPT-2 threshold needed), max-min is 0.01646. Something to investigate in the future is that even slightly better results can be obtained by randomly shuffling the the data shards (i.e. just going in a different order). This is unexpected because the documents were completely fully shuffled during data construction, so one would expect a relatively uniform data distribution. Indeed, the current default order is unfortunately among the worse ("unlucky") ones you can obtain with different shuffle seeds, but it suffices to beat GPT-2 for now so I am merging. TODO investing a bit more later. + +NOTE: The `val_bpb` is as of this run *NOT* comparable due to the data distribution change to the previous 3 runs. This run happens to be at `0.71854` validation bpb. If the dataset is not changed, the `val_bpb` number is a great, smooth metric to track relative performance w.r.t. and has less noise than CORE. + From 752abc836e7075d3a799e2af8f82bdb2456c60cc Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 4 Mar 2026 22:58:27 +0100 Subject: [PATCH 54/55] Ensure that inputs and targets are contiguous (#569) * call reshape instead of view in case the tensors are not contiguous * fix directly in data loader instead --- scripts/chat_sft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f31a2d3..cb9e078 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -282,8 +282,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100): # Build tensors use_cuda = device_type == "cuda" batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) - inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) - targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() + targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() # Apply the loss mask from render_conversation (mask=1 for assistant completions, # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns From 1076f97059785ed6d763706bf2304ce7721ab75c Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 4 Mar 2026 23:55:24 +0000 Subject: [PATCH 55/55] delete autocast, an unnecessary thorn in my side, manage dtypes directly --- README.md | 21 ++++++++++++ dev/LOG.md | 35 +++++++++++++++++++ nanochat/common.py | 20 +++++++++++ nanochat/engine.py | 23 +++++-------- nanochat/flash_attention.py | 20 +++++++---- nanochat/fp8.py | 12 +++---- nanochat/gpt.py | 45 +++++++++++++++---------- scripts/base_eval.py | 16 +++------ scripts/base_train.py | 55 +++++++++++++++++++++--------- scripts/chat_cli.py | 15 +++------ scripts/chat_eval.py | 30 +++++++---------- scripts/chat_rl.py | 30 ++++++----------- scripts/chat_sft.py | 36 +++++++++++++------- scripts/chat_web.py | 58 ++++++++++++++------------------ tests/test_attention_fallback.py | 9 ++--- 15 files changed, 258 insertions(+), 167 deletions(-) diff --git a/README.md b/README.md index 05c7942..077fd9c 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,27 @@ The important thing to note is that nanochat is written and configured around on The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. +## Precision / dtype + +nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware: + +| Hardware | Default dtype | Why | +|----------|--------------|-----| +| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores | +| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) | +| CPU / MPS | `float32` | No reduced-precision tensor cores | + +You can override the default with the `NANOCHAT_DTYPE` environment variable: + +```bash +NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32 +NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16 +``` + +How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision. + +Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT suppors this too but RL currently does not. Inference in fp16 works fine everywhere. + ## Guides I've published a number of guides that might contain helpful information, most recent to least recent: diff --git a/dev/LOG.md b/dev/LOG.md index b4b3757..fd5c3c7 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,41 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler + +Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler. + +### Motivation + +autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction. + +### What changed + +**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`): +- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var. +- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast. +- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients. +- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path). +- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16. + +**Autocast removal** (11 files): +- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`. + +**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`): +- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None` +- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()` +- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()` +- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels). + +**FP8 fix** (`nanochat/fp8.py`, `base_train.py`): +- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast). +- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval. + +**Flash Attention** (`flash_attention.py`): +- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (module-level constant, resolved once at import) returns False, falling back to SDPA. + +--- + ## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction. diff --git a/nanochat/common.py b/nanochat/common.py index 2dd0792..bd14fd2 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -10,6 +10,26 @@ import torch import torch.distributed as dist from filelock import FileLock +# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision. +# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast. +# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32" +_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} +def _detect_compute_dtype(): + env = os.environ.get("NANOCHAT_DTYPE") + if env is not None: + return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}" + if torch.cuda.is_available(): + # bf16 requires SM 80+ (Ampere: A100, A10, etc.) + # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores + capability = torch.cuda.get_device_capability() + if capability >= (8, 0): + return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)" + # fp16 training requires GradScaler (not yet implemented), so fall back to fp32. + # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing. + return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)" + return torch.float32, "auto-detected: no CUDA (CPU/MPS)" +COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() + class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" # ANSI color codes diff --git a/nanochat/engine.py b/nanochat/engine.py index a1ba24c..4724c8f 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -19,7 +19,6 @@ from contextlib import contextmanager from collections import deque from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model -from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -308,8 +307,6 @@ if __name__ == "__main__": # init compute device_type = autodetect_device_type() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() - # load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() @@ -322,11 +319,10 @@ if __name__ == "__main__": torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) - with autocast_ctx: - for token in stream: - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() @@ -338,12 +334,11 @@ if __name__ == "__main__": stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() - with autocast_ctx: - for token_column, token_masks in stream: - token = token_column[0] # only print out the first row - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..af2aee3 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None _override_impl = None -def _use_fa3(): - """Determine whether to use FA3 based on availability and override.""" +def _resolve_use_fa3(): + """Decide once whether to use FA3, based on availability, override, and dtype.""" if _override_impl == 'fa3': assert HAS_FA3, "Cannot override to FA3: not available on this hardware" return True if _override_impl == 'sdpa': return False - return HAS_FA3 # auto + if HAS_FA3: + # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback + from nanochat.common import COMPUTE_DTYPE + if COMPUTE_DTYPE == torch.bfloat16: + return True + return False + return False + +USE_FA3 = _resolve_use_fa3() # ============================================================================= @@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): # sliding window (left) if window >= 0 and window < Tk: mask = mask & ((row_idx - col_idx) <= window) - + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) # ============================================================================= @@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): Returns: Output tensor of shape (B, T, H, D) """ - if _use_fa3(): + if USE_FA3: return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) @@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Returns: Output tensor of shape (B, T_new, H, D) """ - if _use_fa3(): + if USE_FA3: return _fa3.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 3e88285..f9bf8d5 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode. import torch import torch.nn as nn +from nanochat.common import COMPUTE_DTYPE + # Avoid division by zero when computing scale from an all-zeros tensor EPS = 1e-12 @@ -123,7 +125,7 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward quantizes input and weight to FP8 and saves + The forward quantizes input and weight to FP8 and saves the quantized tensors + scales for backward. """ @@ -198,11 +200,9 @@ class Float8Linear(nn.Linear): """ def forward(self, input): - # Replicate the autocast behavior of F.linear — when autocast is active, - # we need to manually cast input to the autocast dtype (e.g. bf16), - # since we bypass F.linear's built-in autocast handling. - if torch.is_autocast_enabled(): - input = input.to(torch.get_autocast_gpu_dtype()) + # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects + # reduced precision input, and we no longer rely on autocast to do this. + input = input.to(COMPUTE_DTYPE) # _scaled_mm only works on 2D tensors, so flatten batch dimensions orig_shape = input.shape input_2d = input.reshape(-1, orig_shape[-1]) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 74e39fd..04ee5c5 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from nanochat.common import get_dist_info, print0 +from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere @@ -40,8 +40,14 @@ class GPTConfig: def norm(x): - # Purely functional rmsnorm with no learnable params - return F.rms_norm(x, (x.size(-1),)) + return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok + +class Linear(nn.Linear): + """nn.Linear that casts weights to match input dtype in forward. + Replaces autocast: master weights stay fp32 for optimizer precision, + but matmuls run in the activation dtype (typically bf16 from embeddings).""" + def forward(self, x): + return F.linear(x, self.weight.to(dtype=x.dtype)) def has_ve(layer_idx, n_layer): @@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module): self.head_dim = self.n_embd // self.n_head assert self.n_embd % self.n_head == 0 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 - self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) - self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) - self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) - self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False) + self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_proj = Linear(self.n_embd, self.n_embd, bias=False) self.ve_gate_channels = 32 - self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None + self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() @@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module): class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False) + self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False) def forward(self, x): x = self.c_fc(x) @@ -164,7 +170,7 @@ class GPT(nn.Module): "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) - self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) + self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False) # Per-layer learnable scalars (inspired by modded-nanogpt) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) @@ -234,11 +240,13 @@ class GPT(nn.Module): cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast embeddings to bf16: optimizer can tolerate it and it saves memory - if self.transformer.wte.weight.device.type == "cuda": - self.transformer.wte.to(dtype=torch.bfloat16) + # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision + # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings + # because GradScaler cannot unscale fp16 gradients. + if COMPUTE_DTYPE != torch.float16: + self.transformer.wte.to(dtype=COMPUTE_DTYPE) for ve in self.value_embeds.values(): - ve.to(dtype=torch.bfloat16) + ve.to(dtype=COMPUTE_DTYPE) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -253,7 +261,7 @@ class GPT(nn.Module): # calculate the rotation frequencies at each (time, channel) pair freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() - cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 + cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE) cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin @@ -391,18 +399,19 @@ class GPT(nn.Module): # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" + assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length # Forward the trunk of the Transformer x = self.transformer.wte(idx) # embed current token + x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = norm(x) x0 = x # save initial normalized embedding for x0 residual for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None + ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = norm(x) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index e45ae43..a57bbaf 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -29,8 +29,6 @@ import random import zipfile import tempfile import argparse -from contextlib import nullcontext - import torch from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock @@ -199,8 +197,6 @@ def main(): # Distributed / precision setup device_type = autodetect_device_type() if args.device_type == '' else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() - # Load model and tokenizer is_hf_model = args.hf_path is not None if is_hf_model: @@ -244,8 +240,7 @@ def main(): print0("\nConditioned samples:") for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") - with autocast_ctx: - sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) + sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) sample_str = tokenizer.decode(sample[0]) print0("-" * 80) print0(sample_str) @@ -253,8 +248,7 @@ def main(): print0("\nUnconditioned samples:") tokens = tokenizer("", prepend="<|bos|>") - with autocast_ctx: - uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0) + uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0) for sample in uncond: sample_str = tokenizer.decode(sample) print0("-" * 80) @@ -277,8 +271,7 @@ def main(): for split_name in ["train", "val"]: loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) - with autocast_ctx: - bpb = evaluate_bpb(model, loader, steps, token_bytes) + bpb = evaluate_bpb(model, loader, steps, token_bytes) bpb_results[split_name] = bpb print0(f"{split_name} bpb: {bpb:.6f}") @@ -287,8 +280,7 @@ def main(): print0("\n" + "="*80) print0("CORE Evaluation") print0("="*80) - with autocast_ctx: - core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) + core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) # Write CSV output if ddp_rank == 0: diff --git a/scripts/base_train.py b/scripts/base_train.py index 9461e88..4bf7959 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -19,14 +19,15 @@ import time import math import argparse from dataclasses import asdict -from contextlib import nullcontext, contextmanager +from contextlib import contextmanager import wandb import torch +import torch.distributed as dist -from nanochat.gpt import GPT, GPTConfig +from nanochat.gpt import GPT, GPTConfig, Linear from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 if device_type == "cuda": @@ -95,17 +95,23 @@ if device_type == "cuda": print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") else: gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS +print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) # Flash Attention status -if HAS_FA3: +from nanochat.flash_attention import USE_FA3 +using_fa3 = USE_FA3 +if using_fa3: print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") else: print0("!" * 80) - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") + if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16: + print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback") + else: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") print0("WARNING: Training will be less efficient without FA3") if args.window_pattern != "L": print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") @@ -213,9 +219,9 @@ def disable_fp8(model): yield # No FP8 modules, nothing to do return - # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy) + # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype) for parent, attr_name, fp8_module in fp8_locations: - linear = nn.Linear( + linear = Linear( fp8_module.in_features, fp8_module.out_features, bias=fp8_module.bias is not None, @@ -315,6 +321,12 @@ if resuming: optimizer.load_state_dict(optimizer_data) del optimizer_data +# ----------------------------------------------------------------------------- +# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32) +scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +if scaler is not None: + print0("GradScaler enabled for fp16 training") + # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] @@ -405,7 +417,7 @@ while True: model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) - with disable_fp8(model), autocast_ctx: + with disable_fp8(model): val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") if val_bpb < min_val_bpb: @@ -424,7 +436,7 @@ while True: results = {} if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): model.eval() - with disable_fp8(orig_model), autocast_ctx: + with disable_fp8(orig_model): results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ @@ -451,7 +463,7 @@ while True: engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") - with disable_fp8(orig_model), autocast_ctx: + with disable_fp8(orig_model): sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) print0(tokenizer.decode(sample[0])) model.train() @@ -491,11 +503,13 @@ while True: synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): - with autocast_ctx: - loss = model(x, y) + loss = model(x, y) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - loss.backward() + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer lrm = get_lr_multiplier(step) @@ -506,7 +520,18 @@ while True: if group['kind'] == 'muon': group["momentum"] = muon_momentum group["weight_decay"] = muon_weight_decay - optimizer.step() + if scaler is not None: + scaler.unscale_(optimizer) + # In distributed training, all ranks must agree on whether to skip the step. + # Each rank may independently encounter inf/nan gradients, so we all-reduce + # the found_inf flag (MAX = if any rank found inf, all ranks skip). + if is_ddp_initialized(): + for v in scaler._found_inf_per_device(optimizer).values(): + dist.all_reduce(v, op=dist.ReduceOp.MAX) + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() model.zero_grad(set_to_none=True) train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point synchronize() diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index 7de7e10..2bcc8aa 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -7,7 +7,6 @@ python -m scripts.chat_cli import argparse import torch from nanochat.common import compute_init, autodetect_device_type -from contextlib import nullcontext from nanochat.engine import Engine from nanochat.checkpoint_manager import load_model @@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') -parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) args = parser.parse_args() # Init the model and tokenizer device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) # Special tokens for the chat state machine @@ -87,12 +83,11 @@ while True: } response_tokens = [] print("\nAssistant: ", end="", flush=True) - with autocast_ctx: - for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs): - token = token_column[0] # pop the batch dimension (num_samples=1) - response_tokens.append(token) - token_text = tokenizer.decode([token]) - print(token_text, end="", flush=True) + for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs): + token = token_column[0] # pop the batch dimension (num_samples=1) + response_tokens.append(token) + token_text = tokenizer.decode([token]) + print(token_text, end="", flush=True) print() # we have to ensure that the assistant end token is the last token # so even if generation ends due to max tokens, we have to append it to the end diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index bc15239..858d4c2 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy import argparse from functools import partial -from contextlib import nullcontext - import torch import torch.distributed as dist @@ -185,7 +183,6 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl") parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.") - parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-t', '--temperature', type=float, default=0.0) parser.add_argument('-m', '--max-new-tokens', type=int, default=512) parser.add_argument('-n', '--num-samples', type=int, default=1) @@ -199,8 +196,6 @@ if __name__ == "__main__": device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) @@ -220,19 +215,18 @@ if __name__ == "__main__": # Run all the task evaluations sequentially results = {} for task_name in task_names: - with autocast_ctx: - acc = run_chat_eval( - task_name, - model, tokenizer, engine, - batch_size=args.batch_size, - num_samples=args.num_samples, - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, - top_k=args.top_k, - max_problems=args.max_problems, - ) - results[task_name] = acc - print0(f"{task_name} accuracy: {100 * acc:.2f}%") + acc = run_chat_eval( + task_name, + model, tokenizer, engine, + batch_size=args.batch_size, + num_samples=args.num_samples, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + max_problems=args.max_problems, + ) + results[task_name] = acc + print0(f"{task_name} accuracy: {100 * acc:.2f}%") # Log to report from nanochat.report import get_report diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 20a1a0a..cb2cb0e 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -22,8 +22,6 @@ import itertools import wandb import torch import torch.distributed as dist -from contextlib import nullcontext - from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.engine import Engine @@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K") parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") @@ -68,8 +65,6 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process @@ -108,15 +103,14 @@ def get_batch(): num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs for sampling_step in range(num_sampling_steps): seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 - with autocast_ctx: - generated_token_sequences_batch, masks_batch = engine.generate_batch( - tokens, - num_samples=args.device_batch_size, - max_tokens=args.max_new_tokens, - temperature=args.temperature, - top_k=args.top_k, - seed=seed, # must make sure to change the seed for each sampling step - ) + generated_token_sequences_batch, masks_batch = engine.generate_batch( + tokens, + num_samples=args.device_batch_size, + max_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=seed, # must make sure to change the seed for each sampling step + ) generated_token_sequences.extend(generated_token_sequences_batch) masks.extend(masks_batch) @@ -231,9 +225,8 @@ for step in range(num_steps): if step % args.eval_every == 0: model.eval() passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size - with autocast_ctx: - records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0) - records = list(records_iter) # collect all records + records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0) + records = list(records_iter) # collect all records for k in range(1, args.device_batch_size + 1): passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) num_records = torch.tensor(len(records), dtype=torch.long, device=device) @@ -268,8 +261,7 @@ for step in range(num_steps): rewards = rewards_all[b0:b1] advantages = advantages_all[b0:b1] # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate - with autocast_ctx: - logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) + logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. pg_obj = (logp * advantages.unsqueeze(-1)).sum() # normalize by the number of valid tokens, number of passes, and examples_per_rank diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index cb9e078..c1adbb6 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time import wandb import torch -from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb @@ -75,7 +74,7 @@ user_config = vars(args).copy() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 if device_type == "cuda": @@ -151,6 +150,11 @@ if args.load_optimizer: else: print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") +# GradScaler for fp16 training (bf16/fp32 don't need it) +scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +if scaler is not None: + print0("GradScaler enabled for fp16 training") + # Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac @@ -344,8 +348,7 @@ while True: model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) - with autocast_ctx: - val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) + val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") if val_bpb < min_val_bpb: min_val_bpb = val_bpb @@ -373,9 +376,8 @@ while True: for task_name in all_tasks: limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample max_problems = None if limit < 0 else limit # -1 means no limit - with autocast_ctx: - acc = run_chat_eval(task_name, orig_model, tokenizer, engine, - batch_size=args.device_batch_size, max_problems=max_problems) + acc = run_chat_eval(task_name, orig_model, tokenizer, engine, + batch_size=args.device_batch_size, max_problems=max_problems) task_results[task_name] = acc print0(f" {task_name}: {100*acc:.2f}%") # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) @@ -428,11 +430,13 @@ while True: synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): - with autocast_ctx: - loss = model(x, y) + loss = model(x, y) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - loss.backward() + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward progress = max(progress, approx_progress) # only increase progress monotonically # step the optimizer @@ -442,7 +446,15 @@ while True: group["lr"] = group["initial_lr"] * lrm if group['kind'] == 'muon': group["momentum"] = muon_momentum - optimizer.step() + if scaler is not None: + scaler.unscale_(optimizer) + if is_ddp_initialized(): + for v in scaler._found_inf_per_device(optimizer).values(): + dist.all_reduce(v, op=dist.ReduceOp.MAX) + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() model.zero_grad(set_to_none=True) synchronize() t1 = time.time() diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 66d7806..ffaf7da 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from pydantic import BaseModel from typing import List, Optional, AsyncGenerator from dataclasses import dataclass -from contextlib import nullcontext from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on') -parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') args = parser.parse_args() @@ -84,7 +82,6 @@ logger = logging.getLogger(__name__) device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 @dataclass class Worker: @@ -93,7 +90,6 @@ class Worker: device: torch.device engine: Engine tokenizer: object - autocast_ctx: torch.amp.autocast class WorkerPool: """Pool of workers, each with a model replica on a different GPU.""" @@ -125,14 +121,11 @@ class WorkerPool: model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step) engine = Engine(model, tokenizer) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() - worker = Worker( gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, - autocast_ctx=autocast_ctx ) self.workers.append(worker) await self.available_workers.put(worker) @@ -279,34 +272,33 @@ async def generate_stream( # Track the last complete UTF-8 string (without replacement characters) last_clean_text = "" - with worker.autocast_ctx: - for token_column, token_masks in worker.engine.generate( - tokens, - num_samples=1, - max_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - seed=random.randint(0, 2**31 - 1) - ): - token = token_column[0] + for token_column, token_masks in worker.engine.generate( + tokens, + num_samples=1, + max_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + seed=random.randint(0, 2**31 - 1) + ): + token = token_column[0] - # Stopping criteria - if token == assistant_end or token == bos: - break + # Stopping criteria + if token == assistant_end or token == bos: + break - # Append the token to sequence - accumulated_tokens.append(token) - # Decode all accumulated tokens to get proper UTF-8 handling - # Note that decode is a quite efficient operation, basically table lookup and string concat - current_text = worker.tokenizer.decode(accumulated_tokens) - # Only emit text if it doesn't end with a replacement character - # This ensures we don't emit incomplete UTF-8 sequences - if not current_text.endswith('�'): - # Extract only the new text since last clean decode - new_text = current_text[len(last_clean_text):] - if new_text: # Only yield if there's new content - yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n" - last_clean_text = current_text + # Append the token to sequence + accumulated_tokens.append(token) + # Decode all accumulated tokens to get proper UTF-8 handling + # Note that decode is a quite efficient operation, basically table lookup and string concat + current_text = worker.tokenizer.decode(accumulated_tokens) + # Only emit text if it doesn't end with a replacement character + # This ensures we don't emit incomplete UTF-8 sequences + if not current_text.endswith('�'): + # Extract only the new text since last clean decode + new_text = current_text[len(last_clean_text):] + if new_text: # Only yield if there's new content + yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n" + last_clean_text = current_text yield f"data: {json.dumps({'done': True})}\n\n" diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py index 9741c7f..3eddc72 100644 --- a/tests/test_attention_fallback.py +++ b/tests/test_attention_fallback.py @@ -21,8 +21,9 @@ from nanochat.engine import KVCache def set_impl(impl): - """Set the implementation override ('fa3', 'sdpa', or None for auto).""" + """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3.""" fa_module._override_impl = impl + fa_module.USE_FA3 = fa_module._resolve_use_fa3() def run_both_impls(fn): @@ -343,19 +344,19 @@ class TestOverrideMechanism: def test_override_fa3(self): """Test that override='fa3' uses FA3.""" set_impl('fa3') - assert fa_module._use_fa3() == True + assert fa_module.USE_FA3 == True set_impl(None) def test_override_sdpa(self): """Test that override='sdpa' uses SDPA.""" set_impl('sdpa') - assert fa_module._use_fa3() == False + assert fa_module.USE_FA3 == False set_impl(None) def test_override_auto(self): """Test that override=None uses auto-detection.""" set_impl(None) - assert fa_module._use_fa3() == HAS_FA3 + assert fa_module.USE_FA3 == HAS_FA3 if __name__ == "__main__":