Merge branch 'master' into master

This commit is contained in:
tillo 2025-10-26 15:12:17 +01:00 committed by GitHub
commit 725599b86d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 461 additions and 15 deletions

View File

@ -101,6 +101,8 @@ nanochat cn be run on CPU or on MPS (if you're on Macbook), and will automatical
To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
@ -121,6 +123,71 @@ I haven't invested too much here but some tests exist, especially for the tokeni
python -m pytest tests/test_rustbpe.py -v -s
```
## File structure
```
.
├── LICENSE
├── README.md
├── dev
│ ├── gen_synthetic_data.py # Example synthetic data for identity
│ ├── generate_logo.html
│ ├── nanochat.png
│ ├── repackage_data_reference.py # Pretraining data shard generation
│ └── runcpu.sh # Small example of how to run on CPU/MPS
├── nanochat
│ ├── __init__.py # empty
│ ├── adamw.py # Distributed AdamW optimizer
│ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life
│ ├── configurator.py # A superior alternative to argparse
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
│ ├── dataloader.py # Tokenizing Distributed Data Loader
│ ├── dataset.py # Download/read utils for pretraining data
│ ├── engine.py # Efficient model inference with KV Cache
│ ├── execution.py # Allows the LLM to execute Python code as tool
│ ├── gpt.py # The GPT nn.Module Transformer
│ ├── logo.svg
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│ ├── muon.py # Distributed Muon optimizer
│ ├── report.py # Utilities for writing the nanochat Report
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── rustbpe # Custom Rust BPE tokenizer trainer
│ ├── Cargo.lock
│ ├── Cargo.toml
│ ├── README.md # see for why this even exists
│ └── src
│ └── lib.rs
├── scripts
│ ├── base_eval.py # Base model: calculate CORE score
│ ├── base_loss.py # Base model: calculate bits per byte, sample
│ ├── base_train.py # Base model: train
│ ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI
│ ├── chat_eval.py # Chat model (SFT/Mid): eval tasks
│ ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning
│ ├── chat_sft.py # Chat model: train SFT
│ ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI
│ ├── mid_train.py # Chat model: midtraining
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
│ └── tok_train.py # Tokenizer: train it
├── speedrun.sh # Train the ~$100 nanochat d20
├── tasks
│ ├── arc.py # Multiple choice science questions
│ ├── common.py # TaskMixture | TaskSequence
│ ├── customjson.py # Make Task from arbitrary jsonl convos
│ ├── gsm8k.py # 8K Grade School Math questions
│ ├── humaneval.py # Misnomer; Simple Python coding task
│ ├── mmlu.py # Multiple choice questions, broad topics
│ ├── smoltalk.py # Conglomarate dataset of SmolTalk from HF
│ └── spellingbee.py # Task teaching model to spell/count letters
├── tests
│ └── test_rustbpe.py
└── uv.lock
```
## Contributing
nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.

View File

@ -5,6 +5,8 @@ Common utilities for nanochat.
import os
import re
import logging
import fcntl
import urllib.request
import torch
import torch.distributed as dist
@ -71,9 +73,46 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir
def download_file_with_lock(url, filename):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
def print0(s="", **kwargs):
ddp_rank = int(os.environ.get("RANK", 0))
if os.path.exists(file_path):
return file_path
with open(lock_path, 'w') as lock_file:
# Only a single rank can acquire this lock
# All other ranks block until it is released
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
if os.path.exists(file_path):
return file_path
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read().decode('utf-8')
with open(file_path, 'w') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Clean up the lock file after the lock is released
try:
os.remove(lock_path)
except OSError:
pass # Ignore if already removed by another process
return file_path
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)

View File

@ -49,14 +49,38 @@ def eval_with_timeout(formula, max_time=3):
def use_calculator(expr):
"""Evaluate a math expression safely."""
"""
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "")
if any(
[x not in "0123456789*+-/.() " for x in expr]
): # for now disallow non-numeric chars
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None
if "**" in expr: # for now disallow power operator, could be very expensive
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr)

View File

@ -341,16 +341,19 @@ class RustBPETokenizer:
mask = mask[:max_tokens]
return ids, mask
def visualize_tokenization(self, ids, mask):
def visualize_tokenization(self, ids, mask, with_token_id=False):
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0m'
GRAY = '\033[90m'
tokens = []
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
token_str = self.decode([token_id])
color = GREEN if mask_val == 1 else RED
tokens.append(f"{color}{token_str}{RESET}")
if with_token_id:
tokens.append(f"{GRAY}({token_id}){RESET}")
return '|'.join(tokens)
def render_for_completion(self, conversation):

View File

@ -49,6 +49,9 @@ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
@ -151,10 +154,6 @@ x, y = next(train_loader) # kick off load of the very first batch of data
# Set up hyperparameter schedulers
# Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)

View File

@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee
# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"),
'SpellingBee': partial(SpellingBee, size=256, split="test"),
}[task_name]
task_object = task_module()
# Run the evaluation
@ -204,13 +206,14 @@ if __name__ == "__main__":
engine = Engine(model, tokenizer)
# Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
baseline_accuracies = {
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
}
task_names = all_tasks if args.task_name is None else args.task_name.split('|')

View File

@ -28,6 +28,7 @@ from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# SFT Hyperparameters
@ -86,7 +87,9 @@ train_ds = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------

View File

@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@ -100,7 +101,9 @@ train_dataset = TaskMixture([
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios

305
tasks/spellingbee.py Normal file
View File

@ -0,0 +1,305 @@
"""
Task intended to make nanochat better in spelling and counting, for example:
"How many r are in strawberry?" -> 3
An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model also learn recoveries. We can do this in future versions.
There are two tasks in this file:
1. SpellingBee: Counting the number of occurrences of a letter in a word
2. SimpleSpelling: Simply spelling words
(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.
To preview a few example conversations, run:
python -m tasks.spellingbee
"""
import re
import random
from tasks.common import Task
from nanochat.common import download_file_with_lock
# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"
# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
"""
match = ANSWER_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
return None
# User message templates for data augmentation
USER_MSG_TEMPLATES = [
"How many {letter} are in the word {word}",
"How many {letter} are in {word}",
"Count the number of {letter} in {word}",
"How many times does {letter} appear in {word}",
"What's the count of {letter} in {word}",
"In the word {word}, how many {letter} are there",
"How many letter {letter} are in the word {word}",
"Count how many {letter} appear in {word}",
"Tell me the number of {letter} in {word}",
"How many occurrences of {letter} are in {word}",
"Find the count of {letter} in {word}",
"Can you count the {letter} letters in {word}",
"What is the frequency of {letter} in {word}",
"How many {letter}s are in {word}",
"How many {letter}'s are in {word}",
"Count all the {letter} in {word}",
"How many times is {letter} in {word}",
"Number of {letter} in {word}",
"Total count of {letter} in {word}",
"How many {letter} does {word} have",
"How many {letter} does {word} contain",
"What's the number of {letter} in {word}",
"{word} has how many {letter}",
"In {word}, count the {letter}",
"How many {letter} appear in {word}",
"Count the {letter} in {word}",
"Give me the count of {letter} in {word}",
"How many instances of {letter} in {word}",
"Show me how many {letter} are in {word}",
"Calculate the number of {letter} in {word}",
# Spanish
"¿Cuántas {letter} hay en {word}?",
"¿Cuántas veces aparece {letter} en {word}?",
"Cuenta las {letter} en {word}",
"¿Cuántas letras {letter} tiene {word}?",
# Chinese (Simplified)
"{word}中有多少个{letter}",
"{word}里有几个{letter}",
"数一下{word}中的{letter}",
"{word}这个词里有多少{letter}",
# Korean
"{word}{letter}가 몇 개 있나요",
"{word}에서 {letter}의 개수는",
"{word}{letter}가 몇 번 나오나요",
"{word}라는 단어에 {letter}가 몇 개",
# French
"Combien de {letter} dans {word}",
"Combien de fois {letter} apparaît dans {word}",
"Compte les {letter} dans {word}",
# German
"Wie viele {letter} sind in {word}",
"Wie oft kommt {letter} in {word} vor",
"Zähle die {letter} in {word}",
# Japanese
"{word}{letter}は何個ありますか",
"{word}の中に{letter}がいくつ",
"{word}{letter}が何回出てくる",
]
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
# pick a letter from it (90%) or a random letter (10%)
letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)
# get the correct answer by simply counting
count = word.count(letter)
# create a user message, with a bunch of variations as data augmentation
template = rng.choice(USER_MSG_TEMPLATES)
# 30% chance to lowercase the template (lazy people don't use shift)
if rng.random() < 0.3:
template = template.lower()
quote_options = ['', "'", '"']
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
word_wrapped = f"{word_quote}{word}{word_quote}"
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
if rng.random() < 0.5: # 50% of people don't even use question marks
user_msg += "?"
# Now create the ideal assistant response - build as parts (text + tool calls)
assistant_parts = []
word_letters = ",".join(list(word))
manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.
First spell the word out:
{word}:{word_letters}
Then count the occurrences of '{letter}':
"""
# Little simulated loop of the solution process
# TODO: This is where the fun starts, we could simulate cute little mistakes
# and get the model to review its work and recover from them.
# You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
running_count = 0
for i, char in enumerate(word, 1):
if char == letter:
running_count += 1
# note: there deliberately cannot be a space here between i and char
# because this would create a different token! (e.g. " a" and "a" are different tokens)
manual_text += f"{i}:{char} hit! count={running_count}\n"
else:
manual_text += f"{i}:{char}\n"
manual_text += f"\nThis gives us {running_count}."
assistant_parts.append({"type": "text", "text": manual_text})
# Part 2: Python verification
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
# Part 3: Python tool call
python_expr = f"'{word}'.count('{letter}')"
assistant_parts.append({"type": "python", "text": python_expr})
# Part 4: Python output
assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
# return the full conversation
messages = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts}
]
conversation = {
"messages": messages,
}
return conversation
def evaluate(self, conversation, assistant_response):
"""
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
Identical to gsm8k's evaluation.
"""
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer from the conversation
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
# The last text part contains the final answer with ####
last_text_part = assistant_message['content'][-1]['text']
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
# Compare and return the success as int
is_correct = int(pred_num == ref_num)
return is_correct
def reward(self, conversation, assistant_response):
""" Use simple 0-1 reward just like gsm8k."""
is_correct = self.evaluate(conversation, assistant_response)
is_correct_float = float(is_correct)
return is_correct_float
class SimpleSpelling(Task):
"""Much simpler task designed to get the model to just practice spelling words."""
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
self.size = size
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f:
words = [line.strip() for line in f]
rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task
self.words = words
@property
def eval_type(self):
return 'generative'
def num_examples(self):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
word_letters = ",".join(list(word))
# return the full conversation
messages = [
{"role": "user", "content": f"Spell the word: {word}"},
{"role": "assistant", "content": f"{word}:{word_letters}"}
]
conversation = {
"messages": messages,
}
return conversation
if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples
task = SpellingBee()
for i in range(10):
ex = task.get_example(i)
print("=" * 100)
print(ex['messages'][0]['content'])
print("-" * 100)
# Assistant content is now a list of parts
assistant_parts = ex['messages'][1]['content']
for part in assistant_parts:
if part['type'] == 'text':
print(part['text'], end='')
elif part['type'] == 'python':
print(f"<<{part['text']}=", end='')
elif part['type'] == 'python_output':
print(f"{part['text']}>>", end='')
print()
print("-" * 100)
# # preview the SimpleSpelling task, first 10 examples
# task = SimpleSpelling()
# for i in range(10):
# ex = task.get_example(i)
# print("=" * 100)
# print(ex['messages'][0]['content'])
# print("-" * 100)
# print(ex['messages'][1]['content'])
# # also scrutinize the tokenization (last example only)
# from nanochat.tokenizer import get_tokenizer
# tokenizer = get_tokenizer()
# ids, mask = tokenizer.render_conversation(ex)
# print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))