mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)

Merge branch 'master' into master
This commit is contained in: commit 725599b86d

README.md (67)

@@ -101,6 +101,8 @@ nanochat can be run on CPU or on MPS (if you're on a MacBook), and will automatically

To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the midtraining and SFT stages.

Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).

## Questions

nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy-paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:

@@ -121,6 +123,71 @@ I haven't invested too much here but some tests exist, especially for the tokenizer
python -m pytest tests/test_rustbpe.py -v -s
```

## File structure

```
.
├── LICENSE
├── README.md
├── dev
│   ├── gen_synthetic_data.py        # Example synthetic data for identity
│   ├── generate_logo.html
│   ├── nanochat.png
│   ├── repackage_data_reference.py  # Pretraining data shard generation
│   └── runcpu.sh                    # Small example of how to run on CPU/MPS
├── nanochat
│   ├── __init__.py                  # empty
│   ├── adamw.py                     # Distributed AdamW optimizer
│   ├── checkpoint_manager.py        # Save/Load model checkpoints
│   ├── common.py                    # Misc small utilities, quality of life
│   ├── configurator.py              # A superior alternative to argparse
│   ├── core_eval.py                 # Evaluates base model CORE score (DCLM paper)
│   ├── dataloader.py                # Tokenizing Distributed Data Loader
│   ├── dataset.py                   # Download/read utils for pretraining data
│   ├── engine.py                    # Efficient model inference with KV Cache
│   ├── execution.py                 # Allows the LLM to execute Python code as tool
│   ├── gpt.py                       # The GPT nn.Module Transformer
│   ├── logo.svg
│   ├── loss_eval.py                 # Evaluate bits per byte (instead of loss)
│   ├── muon.py                      # Distributed Muon optimizer
│   ├── report.py                    # Utilities for writing the nanochat Report
│   ├── tokenizer.py                 # BPE Tokenizer wrapper in style of GPT-4
│   └── ui.html                      # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh                       # Train the ~$800 nanochat d32
├── rustbpe                          # Custom Rust BPE tokenizer trainer
│   ├── Cargo.lock
│   ├── Cargo.toml
│   ├── README.md                    # see for why this even exists
│   └── src
│       └── lib.rs
├── scripts
│   ├── base_eval.py                 # Base model: calculate CORE score
│   ├── base_loss.py                 # Base model: calculate bits per byte, sample
│   ├── base_train.py                # Base model: train
│   ├── chat_cli.py                  # Chat model (SFT/Mid): talk to over CLI
│   ├── chat_eval.py                 # Chat model (SFT/Mid): eval tasks
│   ├── chat_rl.py                   # Chat model (SFT/Mid): reinforcement learning
│   ├── chat_sft.py                  # Chat model: train SFT
│   ├── chat_web.py                  # Chat model (SFT/Mid): talk to over WebUI
│   ├── mid_train.py                 # Chat model: midtraining
│   ├── tok_eval.py                  # Tokenizer: evaluate compression rate
│   └── tok_train.py                 # Tokenizer: train it
├── speedrun.sh                      # Train the ~$100 nanochat d20
├── tasks
│   ├── arc.py                       # Multiple choice science questions
│   ├── common.py                    # TaskMixture | TaskSequence
│   ├── customjson.py                # Make Task from arbitrary jsonl convos
│   ├── gsm8k.py                     # 8K Grade School Math questions
│   ├── humaneval.py                 # Misnomer; Simple Python coding task
│   ├── mmlu.py                      # Multiple choice questions, broad topics
│   ├── smoltalk.py                  # Conglomerate dataset of SmolTalk from HF
│   └── spellingbee.py               # Task teaching model to spell/count letters
├── tests
│   └── test_rustbpe.py
└── uv.lock
```

## Contributing

nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost, but also about cognitive complexity: nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.

nanochat/common.py

@@ -5,6 +5,8 @@ Common utilities for nanochat.
import os
import re
import logging
import fcntl
import urllib.request
import torch
import torch.distributed as dist

@@ -71,9 +73,46 @@ def get_base_dir():
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

def download_file_with_lock(url, filename):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    if os.path.exists(file_path):
        return file_path

    with open(lock_path, 'w') as lock_file:
        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

        if os.path.exists(file_path):
            return file_path

        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read().decode('utf-8')

        with open(file_path, 'w') as f:
            f.write(content)

        print(f"Downloaded to {file_path}")

    # Clean up the lock file after the lock is released
    try:
        os.remove(lock_path)
    except OSError:
        pass  # Ignore if already removed by another process

    return file_path

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)
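
For context, a minimal usage sketch (not part of this commit): every rank can call the helper with the same arguments, the flock serializes the download so only one rank fetches over the network, and the rest read the cached file. This mirrors how tasks/spellingbee.py below loads its word list.

```
# Hypothetical usage sketch; the URL and filename are taken from tasks/spellingbee.py below.
from nanochat.common import download_file_with_lock

WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"

def load_words():
    # Safe to call from every rank; only one rank performs the network fetch.
    path = download_file_with_lock(WORD_LIST_URL, WORD_LIST_URL.split("/")[-1])
    with open(path) as f:
        return [line.strip() for line in f]
```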

nanochat/execution.py

@@ -49,14 +49,38 @@ def eval_with_timeout(formula, max_time=3):

def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
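
A quick illustration (hypothetical, not part of this commit) of what the updated function accepts and rejects, assuming use_calculator is importable from nanochat/execution.py as the file listing above suggests:

```
from nanochat.execution import use_calculator  # assumed location

print(use_calculator("1,234 + 56"))               # pure math path (old behavior) -> 1290
print(use_calculator("'strawberry'.count('r')"))  # new string path used by SpellingBee -> 3
print(use_calculator("__import__('os')"))         # rejected by the dangerous-pattern check -> None
print(use_calculator("'abc'.upper()"))            # rejected: only .count() is whitelisted -> None
```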

nanochat/tokenizer.py

@@ -341,16 +341,19 @@ class RustBPETokenizer:
        mask = mask[:max_tokens]
        return ids, mask

    def visualize_tokenization(self, ids, mask, with_token_id=False):
        """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
        RED = '\033[91m'
        GREEN = '\033[92m'
        RESET = '\033[0m'
        GRAY = '\033[90m'
        tokens = []
        for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
            token_str = self.decode([token_id])
            color = GREEN if mask_val == 1 else RED
            tokens.append(f"{color}{token_str}{RESET}")
            if with_token_id:
                tokens.append(f"{GRAY}({token_id}){RESET}")
        return '|'.join(tokens)

    def render_for_completion(self, conversation):
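
A small debugging sketch showing how the new with_token_id flag is meant to be used; it mirrors the commented-out preview at the bottom of tasks/spellingbee.py later in this diff:

```
# Hypothetical snippet; see the __main__ block of tasks/spellingbee.py below.
from nanochat.tokenizer import get_tokenizer
from tasks.spellingbee import SpellingBee

ex = SpellingBee().get_example(0)
tokenizer = get_tokenizer()
ids, mask = tokenizer.render_conversation(ex)
# Tokens are joined by '|', colored green/red by mask value, and the numeric
# token id is appended in gray when with_token_id=True.
print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
```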

scripts/base_train.py

@@ -49,6 +49,9 @@ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on

@@ -151,10 +154,6 @@ x, y = next(train_loader) # kick off load of the very first batch of data
# Set up hyperparameter schedulers

# Learning rate scheduler
# TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
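
The hunk cuts off before the body of get_lr_multiplier. As an assumption (not the repo's exact code), the three knobs describe a trapezoidal schedule: an optional linear warmup, a constant plateau, and a linear warmdown that ends at final_lr_frac of the base learning rate. A minimal sketch:

```
# Sketch only; the exact interpolation in scripts/base_train.py may differ.
num_iterations = 10000   # example value; the real script derives this from the token budget
warmup_ratio = 0.0
warmdown_ratio = 0.2
final_lr_frac = 0.0

def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters                       # linear warmup
    if it > num_iterations - warmdown_iters:
        frac = (num_iterations - it) / warmdown_iters        # goes 1 -> 0 over the warmdown
        return final_lr_frac + (1.0 - final_lr_frac) * frac  # linear warmdown
    return 1.0                                               # constant plateau
```

The resulting multiplier is presumably applied to the Adam and Muon learning rates configured above at every step.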

scripts/chat_eval.py

@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee

# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)

@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
    }[task_name]
    task_object = task_module()
    # Run the evaluation

@@ -204,13 +206,14 @@ if __name__ == "__main__":
    engine = Engine(model, tokenizer)

    # Get the tasks to evaluate on
    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
    baseline_accuracies = {
        'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
        'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
        'MMLU': 0.25, # multiple choice 1 of 4 => 25%
        'GSM8K': 0.0, # open-ended => 0%
        'HumanEval': 0.0, # open-ended => 0%
        'SpellingBee': 0.0, # open-ended => 0%
    }
    task_names = all_tasks if args.task_name is None else args.task_name.split('|')

scripts/chat_sft.py

@@ -28,6 +28,7 @@ from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
# SFT Hyperparameters

@@ -86,7 +87,9 @@ train_ds = TaskMixture([
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)

# -----------------------------------------------------------------------------

scripts/mid_train.py

@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)

@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios

tasks/spellingbee.py (new file, 305 lines)

@@ -0,0 +1,305 @@
"""
Task intended to make nanochat better at spelling and counting, for example:

"How many r are in strawberry?" -> 3

An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model, and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model to also learn recoveries. We can do this in future versions.

There are two tasks in this file:
1. SpellingBee: Counting the number of occurrences of a letter in a word
2. SimpleSpelling: Simply spelling words

(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.

To preview a few example conversations, run:
python -m tasks.spellingbee
"""
import re
import random
from tasks.common import Task
from nanochat.common import download_file_with_lock

# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"

# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """
    Extract the numerical answer after the #### marker.
    """
    match = ANSWER_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    return None

# User message templates for data augmentation
USER_MSG_TEMPLATES = [
    "How many {letter} are in the word {word}",
    "How many {letter} are in {word}",
    "Count the number of {letter} in {word}",
    "How many times does {letter} appear in {word}",
    "What's the count of {letter} in {word}",
    "In the word {word}, how many {letter} are there",
    "How many letter {letter} are in the word {word}",
    "Count how many {letter} appear in {word}",
    "Tell me the number of {letter} in {word}",
    "How many occurrences of {letter} are in {word}",
    "Find the count of {letter} in {word}",
    "Can you count the {letter} letters in {word}",
    "What is the frequency of {letter} in {word}",
    "How many {letter}s are in {word}",
    "How many {letter}'s are in {word}",
    "Count all the {letter} in {word}",
    "How many times is {letter} in {word}",
    "Number of {letter} in {word}",
    "Total count of {letter} in {word}",
    "How many {letter} does {word} have",
    "How many {letter} does {word} contain",
    "What's the number of {letter} in {word}",
    "{word} has how many {letter}",
    "In {word}, count the {letter}",
    "How many {letter} appear in {word}",
    "Count the {letter} in {word}",
    "Give me the count of {letter} in {word}",
    "How many instances of {letter} in {word}",
    "Show me how many {letter} are in {word}",
    "Calculate the number of {letter} in {word}",
    # Spanish
    "¿Cuántas {letter} hay en {word}?",
    "¿Cuántas veces aparece {letter} en {word}?",
    "Cuenta las {letter} en {word}",
    "¿Cuántas letras {letter} tiene {word}?",
    # Chinese (Simplified)
    "{word}中有多少个{letter}",
    "{word}里有几个{letter}",
    "数一下{word}中的{letter}",
    "{word}这个词里有多少{letter}",
    # Korean
    "{word}에 {letter}가 몇 개 있나요",
    "{word}에서 {letter}의 개수는",
    "{word}에 {letter}가 몇 번 나오나요",
    "{word}라는 단어에 {letter}가 몇 개",
    # French
    "Combien de {letter} dans {word}",
    "Combien de fois {letter} apparaît dans {word}",
    "Compte les {letter} dans {word}",
    # German
    "Wie viele {letter} sind in {word}",
    "Wie oft kommt {letter} in {word} vor",
    "Zähle die {letter} in {word}",
    # Japanese
    "{word}に{letter}は何個ありますか",
    "{word}の中に{letter}がいくつ",
    "{word}に{letter}が何回出てくる",
]

class SpellingBee(Task):

    def __init__(self, size=1000, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SpellingBee split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path) as f:
            words = [line.strip() for line in f]
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
        rng = random.Random(seed)

        # pick a random word
        word = rng.choice(self.words)
        # pick a letter from it (90%) or a random letter (10%)
        letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)

        # get the correct answer by simply counting
        count = word.count(letter)

        # create a user message, with a bunch of variations as data augmentation
        template = rng.choice(USER_MSG_TEMPLATES)
        # 30% chance to lowercase the template (lazy people don't use shift)
        if rng.random() < 0.3:
            template = template.lower()
        quote_options = ['', "'", '"']
        letter_quote = rng.choice(quote_options) # is the letter quoted?
        word_quote = rng.choice(quote_options) # is the word quoted?
        letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
        word_wrapped = f"{word_quote}{word}{word_quote}"
        user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
        if rng.random() < 0.5: # 50% of people don't even use question marks
            user_msg += "?"

        # Now create the ideal assistant response - build as parts (text + tool calls)
        assistant_parts = []
        word_letters = ",".join(list(word))
        manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first.

First spell the word out:
{word}:{word_letters}

Then count the occurrences of '{letter}':
"""
        # Little simulated loop of the solution process
        # TODO: This is where the fun starts, we could simulate cute little mistakes
        # and get the model to review its work and recover from them.
        # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
        running_count = 0
        for i, char in enumerate(word, 1):
            if char == letter:
                running_count += 1
                # note: there deliberately cannot be a space here between i and char
                # because this would create a different token! (e.g. " a" and "a" are different tokens)
                manual_text += f"{i}:{char} hit! count={running_count}\n"
            else:
                manual_text += f"{i}:{char}\n"

        manual_text += f"\nThis gives us {running_count}."
        assistant_parts.append({"type": "text", "text": manual_text})
        # Part 2: Python verification
        assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
        # Part 3: Python tool call
        python_expr = f"'{word}'.count('{letter}')"
        assistant_parts.append({"type": "python", "text": python_expr})
        # Part 4: Python output
        assistant_parts.append({"type": "python_output", "text": str(count)})
        # Part 5: Final answer
        assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})

        # return the full conversation
        messages = [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_parts}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        """
        Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
        Identical to gsm8k's evaluation.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First extract the ground truth answer from the conversation
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        # The last text part contains the final answer with ####
        last_text_part = assistant_message['content'][-1]['text']
        # Extract both the ground truth answer and the predicted answer
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Compare and return the success as int
        is_correct = int(pred_num == ref_num)
        return is_correct

    def reward(self, conversation, assistant_response):
        """ Use simple 0-1 reward just like gsm8k."""
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float


class SimpleSpelling(Task):
    """Much simpler task designed to get the model to just practice spelling words."""

    def __init__(self, size=1000, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SpellingBee split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path) as f:
            words = [line.strip() for line in f]
        rng = random.Random(42)
        rng.shuffle(words) # use a different word order than the SpellingBee task
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
        rng = random.Random(seed)
        # pick a random word
        word = rng.choice(self.words)
        word_letters = ",".join(list(word))
        # return the full conversation
        messages = [
            {"role": "user", "content": f"Spell the word: {word}"},
            {"role": "assistant", "content": f"{word}:{word_letters}"}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation


if __name__ == "__main__":

    # preview the SpellingBee task, first 10 examples
    task = SpellingBee()
    for i in range(10):
        ex = task.get_example(i)
        print("=" * 100)
        print(ex['messages'][0]['content'])
        print("-" * 100)
        # Assistant content is now a list of parts
        assistant_parts = ex['messages'][1]['content']
        for part in assistant_parts:
            if part['type'] == 'text':
                print(part['text'], end='')
            elif part['type'] == 'python':
                print(f"<<{part['text']}=", end='')
            elif part['type'] == 'python_output':
                print(f"{part['text']}>>", end='')
        print()
        print("-" * 100)

    # # preview the SimpleSpelling task, first 10 examples
    # task = SimpleSpelling()
    # for i in range(10):
    #     ex = task.get_example(i)
    #     print("=" * 100)
    #     print(ex['messages'][0]['content'])
    #     print("-" * 100)
    #     print(ex['messages'][1]['content'])

    # # also scrutinize the tokenization (last example only)
    # from nanochat.tokenizer import get_tokenizer
    # tokenizer = get_tokenizer()
    # ids, mask = tokenizer.render_conversation(ex)
    # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
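
To make the grading contract concrete, here are a few hypothetical inputs and what extract_answer returns for them, which can be checked directly against the regex defined above:

```
from tasks.spellingbee import extract_answer

assert extract_answer("My final answer is:\n\n#### 3") == "3"   # grabs the number after '#### '
assert extract_answer("#### 1,234") == "1234"                   # commas are stripped
assert extract_answer("I think the answer is 3") is None        # no marker -> None
```

evaluate() compares the number extracted from the sampled completion against the one extracted from the reference assistant message, and reward() simply casts that 0/1 outcome to a float, matching GSM8K.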