mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
supporting multi-turn for spelling bee tasks
This commit is contained in:
parent
58b38fcd81
commit
ae6dd06489
|
|
@ -30,6 +30,7 @@ import re
|
|||
import random
|
||||
from tasks.common import Task
|
||||
from nanochat.common import download_file_with_lock
|
||||
from datasets import load_dataset
|
||||
|
||||
# Letters of the alphabet
|
||||
LETTERS = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
|
@ -122,6 +123,7 @@ class SpellingBee(Task):
|
|||
with open(word_list_path) as f:
|
||||
words = [line.strip() for line in f]
|
||||
self.words = words
|
||||
self.ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
|
@ -193,10 +195,12 @@ Then count the occurrences of '{letter}':
|
|||
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
row = self.ds[index]
|
||||
ds_messages = row["messages"]
|
||||
messages = ds_messages.extend([
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts}
|
||||
]
|
||||
])
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
@ -243,6 +247,7 @@ class SimpleSpelling(Task):
|
|||
rng = random.Random(42)
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
self.words = words
|
||||
self.ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
|
@ -258,10 +263,12 @@ class SimpleSpelling(Task):
|
|||
word = rng.choice(self.words)
|
||||
word_letters = ",".join(list(word))
|
||||
# return the full conversation
|
||||
messages = [
|
||||
row = self.ds[index]
|
||||
ds_messages = row["messages"]
|
||||
messages = ds_messages.extend([
|
||||
{"role": "user", "content": f"Spell the word: {word}"},
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"}
|
||||
]
|
||||
])
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ python -m tasks.spellingbee
|
|||
import re
|
||||
import random
|
||||
from tasks.common import Task
|
||||
from datasets import load_dataset
|
||||
|
||||
MAX_NUM = 9999999
|
||||
# Letters of the alphabet
|
||||
|
|
@ -120,6 +121,7 @@ class SpellingBeeDigits(Task):
|
|||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
self.ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
|
@ -195,10 +197,13 @@ Then count the occurrences of '{letter}':
|
|||
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
row = self.ds[index]
|
||||
ds_messages = row["messages"]
|
||||
|
||||
messages = ds_messages.extend([
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts}
|
||||
]
|
||||
])
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
@ -238,6 +243,7 @@ class SimpleSpellingDigits(Task):
|
|||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
self.size = size
|
||||
self.split = split
|
||||
self.ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
|
@ -253,10 +259,13 @@ class SimpleSpellingDigits(Task):
|
|||
word = str(rng.randint(0, MAX_NUM))
|
||||
word_letters = ",".join(list(word))
|
||||
# return the full conversation
|
||||
messages = [
|
||||
row = self.ds[index]
|
||||
ds_messages = row["messages"]
|
||||
|
||||
messages = ds_messages.extend([
|
||||
{"role": "user", "content": f"Spell the word: {word}"},
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"}
|
||||
]
|
||||
])
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user