mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
add the SpellingBee task so that nanochat can count the r's in strawberry etc. Along the way we had to add a bunch of new functionality, e.g. extending the calculator to support Python's count function. The current TaskMixture possibly uses way too many synthetic SpellingBee examples, since the eval gives us exactly 100% performance on spelling; we can tune this later to reclaim some wall clock time.
This commit is contained in:
parent 81597cd616
commit 8892470f29
@@ -5,6 +5,8 @@ Common utilities for nanochat.
 import os
 import re
 import logging
+import fcntl
+import urllib.request
 import torch
 import torch.distributed as dist
@@ -56,6 +58,44 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir
 
+def download_file_with_lock(url, filename):
+    """
+    Downloads a file from a URL to a local path in the base directory.
+    Uses a lock file to prevent concurrent downloads among multiple ranks.
+    """
+    base_dir = get_base_dir()
+    file_path = os.path.join(base_dir, filename)
+    lock_path = file_path + ".lock"
+
+    if os.path.exists(file_path):
+        return file_path
+
+    with open(lock_path, 'w') as lock_file:
+        # Only a single rank can acquire this lock
+        # All other ranks block until it is released
+        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
+
+        if os.path.exists(file_path):
+            return file_path
+
+        print(f"Downloading {url}...")
+        with urllib.request.urlopen(url) as response:
+            content = response.read().decode('utf-8')
+
+        with open(file_path, 'w') as f:
+            f.write(content)
+
+        print(f"Downloaded to {file_path}")
+
+    # Clean up the lock file after the lock is released
+    try:
+        os.remove(lock_path)
+    except OSError:
+        pass # Ignore if already removed by another process
+
+    return file_path
 
 def print0(s="",**kwargs):
     ddp_rank = int(os.environ.get('RANK', 0))
     if ddp_rank == 0:
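Usage sketch: every DDP rank can call this helper; one rank downloads while the rest block on the fcntl lock and then return the cached path. This mirrors how the new SpellingBee task below fetches its word list:

    filename = WORD_LIST_URL.split("/")[-1]  # "words_alpha.txt"
    word_list_path = download_file_with_lock(WORD_LIST_URL, filename)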
@@ -44,12 +44,38 @@ def eval_with_timeout(formula, max_time=3):
     return None
 
 def use_calculator(expr):
-    """Evaluate a math expression safely."""
+    """
+    Evaluate a Python expression safely.
+    Supports both math expressions and string operations like .count()
+    """
+    # Remove commas from numbers
     expr = expr.replace(",", "")
-    if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
+
+    # Check if it's a pure math expression (old behavior)
+    if all([x in "0123456789*+-/.() " for x in expr]):
+        if "**" in expr: # disallow power operator
+            return None
+        return eval_with_timeout(expr)
+
+    # Check if it's a string operation we support
+    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
+    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
+    if not all([x in allowed_chars for x in expr]):
         return None
-    if "**" in expr: # for now disallow power operator, could be very expensive
+
+    # Disallow dangerous patterns
+    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
+                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
+                          'getattr', 'setattr', 'delattr', 'hasattr']
+    expr_lower = expr.lower()
+    if any(pattern in expr_lower for pattern in dangerous_patterns):
         return None
+
+    # Only allow .count() method for now (can expand later)
+    if '.count(' not in expr:
+        return None
+
+    # Evaluate with timeout
     return eval_with_timeout(expr)
 
 # -----------------------------------------------------------------------------
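A quick sketch of what the extended calculator now accepts and rejects (illustrative calls; the outcomes follow directly from the checks above):

    use_calculator("1,234 + 5")                # -> 1239 (pure math path, commas stripped)
    use_calculator("2 ** 10")                  # -> None (power operator disallowed)
    use_calculator("'strawberry'.count('r')")  # -> 3 (the new string path)
    use_calculator("'abc'.upper()")            # -> None (only .count() is allowed for now)
    use_calculator("__import__('os')")         # -> None ('__' and 'import' are dangerous patterns)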
@@ -23,6 +23,7 @@ from tasks.humaneval import HumanEval
 from tasks.mmlu import MMLU
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
+from tasks.spellingbee import SpellingBee
 
 # -----------------------------------------------------------------------------
 # Generative evaluation loop (we go one problem at a time, sample, evaluate)
@@ -165,6 +166,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
         'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
         'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
         'GSM8K': partial(GSM8K, subset="main", split="test"),
+        'SpellingBee': partial(SpellingBee, size=256, split="test"),
     }[task_name]
     task_object = task_module()
     # Run the evaluation
@@ -204,13 +206,14 @@ if __name__ == "__main__":
     engine = Engine(model, tokenizer)
 
     # Get the tasks to evaluate on
-    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
+    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
     baseline_accuracies = {
         'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
         'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
         'MMLU': 0.25, # multiple choice 1 of 4 => 25%
         'GSM8K': 0.0, # open-ended => 0%
         'HumanEval': 0.0, # open-ended => 0%
+        'SpellingBee': 0.0, # open-ended => 0%
     }
     task_names = all_tasks if args.task_name is None else args.task_name.split('|')
@@ -28,6 +28,7 @@ from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
 from tasks.smoltalk import SmolTalk
 from tasks.customjson import CustomJSON
+from tasks.spellingbee import SimpleSpelling, SpellingBee
 
 # -----------------------------------------------------------------------------
 # SFT Hyperparameters
@@ -86,7 +87,9 @@ train_ds = TaskMixture([
     GSM8K(subset="main", split="train"), # 8K rows
     SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
     CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
-]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
+    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
+    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
+]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
 val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
 
 # -----------------------------------------------------------------------------
@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
 from tasks.mmlu import MMLU
 from tasks.smoltalk import SmolTalk
 from tasks.customjson import CustomJSON
+from tasks.spellingbee import SimpleSpelling, SpellingBee
 
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -100,7 +101,9 @@ train_dataset = TaskMixture([
     GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
     CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
     CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
-]) # total: 460K + 100K + 8K = 568K rows
+    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
+    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
+]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
 val_dataset = TaskMixture([
     SmolTalk(split="test"), # 24K rows in test set
     MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
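(For scale: the two spelling tasks contribute 200K + 80K = 280K of the 848K midtraining rows, roughly a third of the mixture; this is the over-representation the commit message suggests could later be dialed back to reclaim wall clock time.)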
tasks/spellingbee.py (new file, 296 lines)
@@ -0,0 +1,296 @@
"""
Task intended to make nanochat better at spelling and counting, for example:

"How many r are in strawberry?" -> 3

An interesting part of this task is that we will get the assistant to
solve the problem using a combination of manual counting and Python.
This is a good problem solving "instinct" to mix into the model, and RL
may further refine it to trust one over the other. If we were extra fancy
(which we could/should be) we'd add small errors here and there to allow
the model to also learn recoveries. We can do this in future versions.

There are two tasks in this file:
1. SpellingBee: counting the number of occurrences of a letter in a word
2. SimpleSpelling: simply spelling words

(1) is the goal, but (2) exists as a highly condensed version of the part
that makes (1) difficult, which is word spelling. This is non-trivial for an
LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.

To preview a few example conversations, run:
python -m tasks.spellingbee
"""

import re
import random
from tasks.common import Task
from nanochat.common import download_file_with_lock

# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
# A list of 370K English words of large variety
WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt"

# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """
    Extract the numerical answer after the #### marker.
    """
    match = ANSWER_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    return None

# User message templates for data augmentation
USER_MSG_TEMPLATES = [
    "How many {letter} are in the word {word}",
    "How many {letter} are in {word}",
    "Count the number of {letter} in {word}",
    "How many times does {letter} appear in {word}",
    "What's the count of {letter} in {word}",
    "In the word {word}, how many {letter} are there",
    "How many letter {letter} are in the word {word}",
    "Count how many {letter} appear in {word}",
    "Tell me the number of {letter} in {word}",
    "How many occurrences of {letter} are in {word}",
    "Find the count of {letter} in {word}",
    "Can you count the {letter} letters in {word}",
    "What is the frequency of {letter} in {word}",
    "How many {letter}s are in {word}",
    "How many {letter}'s are in {word}",
    "Count all the {letter} in {word}",
    "How many times is {letter} in {word}",
    "Number of {letter} in {word}",
    "Total count of {letter} in {word}",
    "How many {letter} does {word} have",
    "How many {letter} does {word} contain",
    "What's the number of {letter} in {word}",
    "{word} has how many {letter}",
    "In {word}, count the {letter}",
    "How many {letter} appear in {word}",
    "Count the {letter} in {word}",
    "Give me the count of {letter} in {word}",
    "How many instances of {letter} in {word}",
    "Show me how many {letter} are in {word}",
    "Calculate the number of {letter} in {word}",
    # Spanish
    "¿Cuántas {letter} hay en {word}?",
    "¿Cuántas veces aparece {letter} en {word}?",
    "Cuenta las {letter} en {word}",
    "¿Cuántas letras {letter} tiene {word}?",
    # Chinese (Simplified)
    "{word}中有多少个{letter}",
    "{word}里有几个{letter}",
    "数一下{word}中的{letter}",
    "{word}这个词里有多少{letter}",
    # Korean
    "{word}에 {letter}가 몇 개 있나요",
    "{word}에서 {letter}의 개수는",
    "{word}에 {letter}가 몇 번 나오나요",
    "{word}라는 단어에 {letter}가 몇 개",
    # French
    "Combien de {letter} dans {word}",
    "Combien de fois {letter} apparaît dans {word}",
    "Compte les {letter} dans {word}",
    # German
    "Wie viele {letter} sind in {word}",
    "Wie oft kommt {letter} in {word} vor",
    "Zähle die {letter} in {word}",
    # Japanese
    "{word}に{letter}は何個ありますか",
    "{word}の中に{letter}がいくつ",
    "{word}に{letter}が何回出てくる",
]

class SpellingBee(Task):

    def __init__(self, size=1000, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SpellingBee split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path) as f:
            words = [line.strip() for line in f]
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
        rng = random.Random(seed)

        # pick a random word
        word = rng.choice(self.words)
        # pick a letter from it (90%) or a random letter (10%)
        letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS)

        # get the correct answer by simply counting
        count = word.count(letter)

        # create a user message, with a bunch of variations as data augmentation
        template = rng.choice(USER_MSG_TEMPLATES)
        # 30% chance to lowercase the template (lazy people don't use shift)
        if rng.random() < 0.3:
            template = template.lower()
        quote_options = ['', "'", '"']
        letter_quote = rng.choice(quote_options) # is the letter quoted?
        word_quote = rng.choice(quote_options) # is the word quoted?
        letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
        word_wrapped = f"{word_quote}{word}{word_quote}"
        user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
        if rng.random() < 0.5: # 50% of people don't even use question marks
            user_msg += "?"

        # Now create the ideal assistant response - build as parts (text + tool calls)
        assistant_parts = []
        word_letters = ",".join(list(word))
        manual_text = f"""We are asked to find the number of '{letter}' in the word '{word}'. Let me try a manual approach first.

First spell the word out:
{word}:{word_letters}

Then count the occurrences of '{letter}':
"""
        # Little simulated loop of the solution process
        # TODO: This is where the fun starts, we could simulate cute little mistakes
        # and get the model to review its work and recover from them.
        # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit.
        running_count = 0
        for i, char in enumerate(word, 1):
            if char == letter:
                running_count += 1
                # note: there deliberately cannot be a space here between i and char
                # because this would create a different token! (e.g. " a" and "a" are different tokens)
                manual_text += f"{i}:{char} hit! count={running_count}\n"
            else:
                manual_text += f"{i}:{char}\n"

        manual_text += f"\nThis gives us {running_count}."
        assistant_parts.append({"type": "text", "text": manual_text})
        # Part 2: Python verification
        assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
        # Part 3: Python tool call
        python_expr = f"'{word}'.count('{letter}')"
        assistant_parts.append({"type": "python", "text": python_expr})
        # Part 4: Python output
        assistant_parts.append({"type": "python_output", "text": str(count)})
        # Part 5: Final answer
        assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})

        # return the full conversation
        messages = [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_parts}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        """
        Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
        Identical to gsm8k's evaluation.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First extract the ground truth answer from the conversation
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        # The last text part contains the final answer with ####
        last_text_part = assistant_message['content'][-1]['text']
        # Extract both the ground truth answer and the predicted answer
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Compare and return the success as int
        is_correct = int(pred_num == ref_num)
        return is_correct

    def reward(self, conversation, assistant_response):
        """Use simple 0-1 reward just like gsm8k."""
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float

class SimpleSpelling(Task):
    """Much simpler task designed to get the model to just practice spelling words."""

    def __init__(self, size=1000, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in ["train", "test"], "SimpleSpelling split must be train|test"
        self.size = size
        self.split = split
        filename = WORD_LIST_URL.split("/")[-1]
        word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
        with open(word_list_path) as f:
            words = [line.strip() for line in f]
        rng = random.Random(42)
        rng.shuffle(words) # use a different word order than the SpellingBee task
        self.words = words

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return self.size

    def get_example(self, index):
        seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
        rng = random.Random(seed)
        # pick a random word
        word = rng.choice(self.words)
        word_letters = ",".join(list(word))
        # return the full conversation
        messages = [
            {"role": "user", "content": f"Spell the word: {word}"},
            {"role": "assistant", "content": f"{word}: {word_letters}"}
        ]
        conversation = {
            "messages": messages,
        }
        return conversation


if __name__ == "__main__":

    # preview the SpellingBee task, first 10 examples
    task = SpellingBee()
    for i in range(10):
        ex = task.get_example(i)
        print("=" * 100)
        print(ex['messages'][0]['content'])
        print("-" * 100)
        # Assistant content is now a list of parts
        assistant_parts = ex['messages'][1]['content']
        for part in assistant_parts:
            if part['type'] == 'text':
                print(part['text'], end='')
            elif part['type'] == 'python':
                print(f"<<{part['text']}=", end='')
            elif part['type'] == 'python_output':
                print(f"{part['text']}>>", end='')
        print()
        print("-" * 100)

    # also scrutinize the tokenization (last example only)
    # from nanochat.tokenizer import get_tokenizer
    # tokenizer = get_tokenizer()
    # ids, mask = tokenizer.render_conversation(ex)
    # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
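For reference, running python -m tasks.spellingbee prints conversations shaped roughly like the following (illustrative: the actual word, letter, user template, and quoting are drawn at random per example):

    ============================================================
    How many 'r' are in strawberry?
    ------------------------------------------------------------
    We are asked to find the number of 'r' in the word 'strawberry'. Let me try a manual approach first.

    First spell the word out:
    strawberry:s,t,r,a,w,b,e,r,r,y

    Then count the occurrences of 'r':
    1:s
    2:t
    3:r hit! count=1
    4:a
    5:w
    6:b
    7:e
    8:r hit! count=2
    9:r hit! count=3
    10:y

    This gives us 3.

    Let me double check this using Python:

    <<'strawberry'.count('r')=3>>

    Python gives us 3.

    My final answer is:

    #### 3
    ------------------------------------------------------------

Note how the <<...>> tool call is exactly the kind of expression ('strawberry'.count('r')) that the extended use_calculator above was taught to accept.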