This commit is contained in:
Richard Hsu 2025-10-26 12:55:33 -07:00
parent 6da078295a
commit 58b38fcd81

View File

@ -29,8 +29,8 @@ python -m tasks.spellingbee
import re import re
import random import random
from tasks.common import Task from tasks.common import Task
from nanochat.common import download_file_with_lock
MAX_NUM = 9999999
# Letters of the alphabet # Letters of the alphabet
# A list of 370K English words of large variety # A list of 370K English words of large variety
@ -130,14 +130,14 @@ class SpellingBeeDigits(Task):
def get_example(self, index): def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed) rng = random
# pick a random digit # pick a random digit
word = str(rng.randint(1, 999999999999999)) word = str(rng.randint(1, MAX_NUM))
# pick a letter from it (90%) or a random letter (10%) # pick a letter from it (90%) or a random letter (10%)
letter = str(random.randint(0, 9)) letter = str(rng.randint(0, 9))
# get the correct answer by simply counting # get the correct answer by simply counting
@ -248,9 +248,9 @@ class SimpleSpellingDigits(Task):
def get_example(self, index): def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed) rng = random
# pick a random word # pick a random word
word = str(random.randint(0, 7890987689)) word = str(rng.randint(0, MAX_NUM))
word_letters = ",".join(list(word)) word_letters = ",".join(list(word))
# return the full conversation # return the full conversation
messages = [ messages = [
@ -286,7 +286,7 @@ if __name__ == "__main__":
# # preview the SimpleSpellingDigits task, first 10 examples # # preview the SimpleSpellingDigits task, first 10 examples
task = SimpleSpellingDigits() task = SimpleSpellingDigits()
for i in range(3): for i in range(5):
ex = task.get_example(i) ex = task.get_example(i)
print("=" * 100) print("=" * 100)
print(ex['messages'][0]['content']) print(ex['messages'][0]['content'])