mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-19 04:29:09 +00:00
Extension - supporting Spelling Bee but for digits (e.g. how many 2 in 6789022?)
This commit is contained in:
parent
c75fe54aa7
commit
6da078295a
|
|
@ -24,6 +24,7 @@ from tasks.mmlu import MMLU
|
|||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.spellingbee import SpellingBee
|
||||
from tasks.spellingbee_digits import SpellingBeeDigits
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
|
|
@ -167,6 +168,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
|||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'SpellingBee': partial(SpellingBee, size=256, split="test"),
|
||||
'SpellingBeeDigits': partial(SpellingBeeDigits, size=100, split = "test")
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
|
|
@ -214,6 +216,7 @@ if __name__ == "__main__":
|
|||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
'SpellingBeeDigits': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from tasks.gsm8k import GSM8K
|
|||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
from tasks.spellingbee_digits import SimpleSpellingDigits, SpellingBeeDigits
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
|
|
@ -89,6 +90,8 @@ train_ds = TaskMixture([
|
|||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
SimpleSpellingDigits(size=300, split="train"), # 300 rows of Simple Spelling for digits (e.g. spell out the digits of '27384920')
|
||||
SpellingBeeDigits(size=300, split="train"), # 300 rows of Spelling Bee for digits (e.g. how many '0' in '38284090'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K + 0.3K + 0.3K = 23.6K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from tasks.mmlu import MMLU
|
|||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
from tasks.spellingbee_digits import SimpleSpellingDigits, SpellingBeeDigits
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
|
|
@ -103,7 +104,9 @@ train_dataset = TaskMixture([
|
|||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
SimpleSpellingDigits(size=200000, split="train"), # 200K rows of Simple Spelling for digits (e.g. spell out the digits of '64642')
|
||||
SpellingBeeDigits(size=80000, split="train"), # 80K rows of Spelling Bee for digits (e.g. how many '0' are in '38284090'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K + 200K + 80K = 1128K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
|
|
|
|||
300
tasks/spellingbee_digits.py
Normal file
300
tasks/spellingbee_digits.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
"""
|
||||
Task intended to make nanochat better in spelling and counting, for example:
|
||||
|
||||
"How many 2 are in 6789022?" -> 2
|
||||
|
||||
An interesting part of this task is that we will get the assistant to
|
||||
solve the problem using a combination of manual counting and Python.
|
||||
This is a good problem solving "instinct" to mix into the model and RL
|
||||
may further refine it to trust one over the other. If we were extra fancy
|
||||
(which we could/should be) we'd add small errors here and there to allow
|
||||
the model also learn recoveries. We can do this in future versions.
|
||||
|
||||
There are two tasks in this file:
|
||||
1. SpellingBeeDigits: Counting the number of occurrences of a digit in a number
|
||||
2. SimpleSpellingDigits: Simply spelling out the digits of a number
|
||||
|
||||
(1) is the goal, but (2) exists as a highly condensed version of the part
|
||||
that makes (1) difficult, which is word spelling. This is non-trivial for an
|
||||
LLM because it has to learn how every token (a little semantic chunk/atom)
|
||||
maps to the sequence of individual characters that make it up. Larger models
|
||||
learn this eventually on their own, but if we want this capability to exist
|
||||
in smaller models, we have to actively encourage it by over-representing it
|
||||
in the training data. Midtraining is a good place to do this.
|
||||
|
||||
To preview a few example conversations, run:
|
||||
python -m tasks.spellingbee_digits
|
||||
"""
|
||||
|
||||
import re
|
||||
import random
|
||||
from tasks.common import Task
|
||||
from nanochat.common import download_file_with_lock
|
||||
|
||||
# Letters of the alphabet
|
||||
# A list of 370K English words of large variety
|
||||
|
||||
# Same answer convention as gsm8k: the final number follows a "#### " marker
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """
    Pull the numerical answer that follows the first "#### " marker.

    Returns the captured text as a string with commas removed, or None
    when the completion contains no such marker.
    """
    found = ANSWER_RE.search(completion)
    if found is None:
        return None
    # strip surrounding whitespace, then drop thousands separators
    return found.group(1).strip().replace(",", "")
|
||||
|
||||
# User message templates for data augmentation.
# Each template is filled with str.format(letter=..., word=...): `letter` is
# the (optionally quoted) digit to count and `word` is the (optionally quoted)
# number it is counted in.
# NOTE(review): the wording still says "letter" / "word" although this file
# fills in digits and numbers - confirm that this phrasing is intended.
# NOTE(review): the Spanish templates already end in "?", and get_example may
# append another "?" - confirm whether a doubled question mark is acceptable.
USER_MSG_TEMPLATES = [
    # English
    "How many {letter} are in the word {word}",
    "How many {letter} are in {word}",
    "Count the number of {letter} in {word}",
    "How many times does {letter} appear in {word}",
    "What's the count of {letter} in {word}",
    "In the word {word}, how many {letter} are there",
    "How many letter {letter} are in the word {word}",
    "Count how many {letter} appear in {word}",
    "Tell me the number of {letter} in {word}",
    "How many occurrences of {letter} are in {word}",
    "Find the count of {letter} in {word}",
    "Can you count the {letter} letters in {word}",
    "What is the frequency of {letter} in {word}",
    "How many {letter}s are in {word}",
    "How many {letter}'s are in {word}",
    "Count all the {letter} in {word}",
    "How many times is {letter} in {word}",
    "Number of {letter} in {word}",
    "Total count of {letter} in {word}",
    "How many {letter} does {word} have",
    "How many {letter} does {word} contain",
    "What's the number of {letter} in {word}",
    "{word} has how many {letter}",
    "In {word}, count the {letter}",
    "How many {letter} appear in {word}",
    "Count the {letter} in {word}",
    "Give me the count of {letter} in {word}",
    "How many instances of {letter} in {word}",
    "Show me how many {letter} are in {word}",
    "Calculate the number of {letter} in {word}",
    # Spanish
    "¿Cuántas {letter} hay en {word}?",
    "¿Cuántas veces aparece {letter} en {word}?",
    "Cuenta las {letter} en {word}",
    "¿Cuántas letras {letter} tiene {word}?",
    # Chinese (Simplified)
    "{word}中有多少个{letter}",
    "{word}里有几个{letter}",
    "数一下{word}中的{letter}",
    "{word}这个词里有多少{letter}",
    # Chinese (Traditional)
    "{word}中有多少個{letter}",
    "{word}裡有幾個{letter}",
    "數一下{word}中的{letter}",
    "{word}這個詞裡有多少{letter}",
    # Korean
    "{word}에 {letter}가 몇 개 있나요",
    "{word}에서 {letter}의 개수는",
    "{word}에 {letter}가 몇 번 나오나요",
    "{word}라는 단어에 {letter}가 몇 개",
    # French
    "Combien de {letter} dans {word}",
    "Combien de fois {letter} apparaît dans {word}",
    "Compte les {letter} dans {word}",
    # German
    "Wie viele {letter} sind in {word}",
    "Wie oft kommt {letter} in {word} vor",
    "Zähle die {letter} in {word}",
    # Japanese
    "{word}に{letter}は何個ありますか",
    "{word}の中に{letter}がいくつ",
    "{word}に{letter}が何回出てくる",
]
|
||||
|
||||
class SpellingBeeDigits(Task):
    """
    Digit-counting variant of Spelling Bee, e.g. "how many 2 are in 6789022?".

    Examples are generated procedurally from their index (there is no backing
    dataset), so the same (split, index) pair always yields the same
    conversation. The ideal assistant response counts the digit manually,
    double checks with a Python tool call and ends with "#### <count>".
    """

    def __init__(self, size=1000, split="train", **kwargs):
        """
        Args:
            size: number of examples this task reports via num_examples().
            split: "train" or "test"; selects a disjoint family of RNG seeds.
            **kwargs: forwarded unchanged to Task.__init__.
        """
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SpellingBeeDigits split must be train|test"
        self.size = size
        self.split = split

    @property
    def eval_type(self):
        """This task is scored on a freely generated completion."""
        return 'generative'

    def num_examples(self):
        """Return the configured number of examples."""
        return self.size

    def get_example(self, index):
        """
        Build the conversation for example `index`.

        Returns a dict with a single key "messages": a user question followed
        by an assistant turn whose content is a list of parts of type
        "text", "python" and "python_output".
        """
        # Train uses even seeds and test uses odd seeds, so the two splits can
        # never share a seed. (Negating the index for the test split is not
        # enough: CPython seeds an int by its absolute value, so -(index + 1)
        # would reproduce train example index + 1.)
        seed = 2 * index if self.split == "train" else 2 * index + 1
        rng = random.Random(seed)

        # pick a random number, between 1 and 15 digits long
        word = str(rng.randint(1, 999999999999999))

        # pick the digit to count; it must come from the seeded rng (not the
        # global random module) so that the example is reproducible
        letter = str(rng.randint(0, 9))

        # get the correct answer by simply counting
        count = word.count(letter)

        # create a user message, with a bunch of variations as data augmentation
        template = rng.choice(USER_MSG_TEMPLATES)
        # 30% chance to lowercase the template (lazy people don't use shift)
        if rng.random() < 0.3:
            template = template.lower()
        quote_options = ['', "'", '"']
        letter_quote = rng.choice(quote_options) # is the digit quoted?
        word_quote = rng.choice(quote_options) # is the number quoted?
        letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
        word_wrapped = f"{word_quote}{word}{word_quote}"
        user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
        if rng.random() < 0.5: # 50% of people don't even use question marks
            user_msg += "?"

        # Now create the ideal assistant response - build as parts (text + tool calls)
        assistant_parts = []
        word_letters = ",".join(word)
        # Part 1: manual counting
        manual_text = (
            f"We are asked to find the number of '{letter}' in the word '{word}'. Let me try a manual approach first.\n"
            "\n"
            "First spell the digits out:\n"
            f"{word}:{word_letters}\n"
            "\n"
            f"Then count the occurrences of '{letter}':\n"
        )
        # Little simulated loop of the solution process
        # TODO: This is where the fun starts, we could simulate cute little mistakes
        # and get the model to review its work and recover from them.
        # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
        running_count = 0
        for i, char in enumerate(word, 1):
            if char == letter:
                running_count += 1
                # note: there deliberately cannot be a space here between i and char
                # because this would create a different token! (e.g. " a" and "a" are different tokens)
                manual_text += f"{i}:{char} hit! count={running_count}\n"
            else:
                manual_text += f"{i}:{char}\n"

        manual_text += f"\nThis gives us {running_count}."
        assistant_parts.append({"type": "text", "text": manual_text})
        # Part 2: Python verification
        assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
        # Part 3: Python tool call
        python_expr = f"'{word}'.count('{letter}')"
        assistant_parts.append({"type": "python", "text": python_expr})
        # Part 4: Python output
        assistant_parts.append({"type": "python_output", "text": str(count)})
        # Part 5: Final answer
        assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})

        # return the full conversation
        messages = [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_parts}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        """
        Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)

        The reference answer is read from the "#### <n>" marker in the last
        part of the reference assistant message and compared, as a string,
        with the answer extracted from `assistant_response`.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First extract the ground truth answer from the conversation
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        # The last text part contains the final answer with ####
        last_text_part = assistant_message['content'][-1]['text']
        # Extract both the ground truth answer and the predicted answer
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Compare and return the success as int
        is_correct = int(pred_num == ref_num)
        return is_correct

    def reward(self, conversation, assistant_response):
        """ Use simple 0-1 reward: 1.0 when evaluate() says correct, else 0.0."""
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float
|
||||
|
||||
|
||||
class SimpleSpellingDigits(Task):
    """
    Much simpler task designed to get the model to just practice spelling
    out the digits of a number, e.g. "1234" -> "1234:1,2,3,4".

    Examples are generated procedurally from their index, so the same
    (split, index) pair always yields the same conversation.
    """

    def __init__(self, size=1000, split="train", **kwargs):
        """
        Args:
            size: number of examples this task reports via num_examples().
            split: "train" or "test"; selects a disjoint family of RNG seeds.
            **kwargs: forwarded unchanged to Task.__init__.
        """
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SimpleSpellingDigits split must be train|test"
        self.size = size
        self.split = split

    @property
    def eval_type(self):
        """This task is declared as a generative task."""
        return 'generative'

    def num_examples(self):
        """Return the configured number of examples."""
        return self.size

    def get_example(self, index):
        """
        Build the conversation for example `index`.

        Returns a dict with a single key "messages": a user request to spell
        a number and a plain-string assistant reply "<number>:<d1>,<d2>,...".
        """
        # Train uses even seeds and test uses odd seeds, so the two splits can
        # never share a seed. (Negating the index for the test split is not
        # enough: CPython seeds an int by its absolute value, so -(index + 1)
        # would reproduce train example index + 1.)
        seed = 2 * index if self.split == "train" else 2 * index + 1
        rng = random.Random(seed)
        # pick a random number; it must come from the seeded rng (not the
        # global random module) so that the example is reproducible
        word = str(rng.randint(0, 7890987689))
        word_letters = ",".join(word)
        # return the full conversation
        messages = [
            {"role": "user", "content": f"Spell the word: {word}"},
            {"role": "assistant", "content": f"{word}:{word_letters}"}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation
|
||||
|
||||
|
||||
if __name__ == "__main__":

    # How each assistant part type is rendered in the preview;
    # parts of any other type are not printed.
    part_formats = {
        'text': "{}",
        'python': "<<{}=",
        'python_output': "{}>>",
    }

    # Preview the SpellingBeeDigits task (first 3 examples)
    bee_task = SpellingBeeDigits()
    for idx in range(3):
        ex = bee_task.get_example(idx)
        print("=" * 100)
        print(ex['messages'][0]['content'])
        print("-" * 100)
        # Assistant content is a list of parts; print them back to back
        for part in ex['messages'][1]['content']:
            fmt = part_formats.get(part['type'])
            if fmt is not None:
                print(fmt.format(part['text']), end='')
        print()
        print("-" * 100)

    # Preview the SimpleSpellingDigits task (first 3 examples)
    spell_task = SimpleSpellingDigits()
    for idx in range(3):
        ex = spell_task.get_example(idx)
        print("=" * 100)
        print(ex['messages'][0]['content'])
        print("-" * 100)
        print(ex['messages'][1]['content'])

    # also scrutinize the tokenization (last example only)
    # from nanochat.tokenizer import get_tokenizer
    # tokenizer = get_tokenizer()
    # ids, mask = tokenizer.render_conversation(ex)
    # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
|
||||
Loading…
Reference in New Issue
Block a user