mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-09 22:02:14 +00:00
minor
This commit is contained in:
parent
6da078295a
commit
58b38fcd81
|
|
@ -29,8 +29,8 @@ python -m tasks.spellingbee
|
||||||
import re
|
import re
|
||||||
import random
|
import random
|
||||||
from tasks.common import Task
|
from tasks.common import Task
|
||||||
from nanochat.common import download_file_with_lock
|
|
||||||
|
|
||||||
|
MAX_NUM = 9999999
|
||||||
# Letters of the alphabet
|
# Letters of the alphabet
|
||||||
# A list of 370K English words of large variety
|
# A list of 370K English words of large variety
|
||||||
|
|
||||||
|
|
@ -130,14 +130,14 @@ class SpellingBeeDigits(Task):
|
||||||
|
|
||||||
def get_example(self, index):
|
def get_example(self, index):
|
||||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||||
rng = random.Random(seed)
|
rng = random
|
||||||
|
|
||||||
# pick a random digit
|
# pick a random digit
|
||||||
word = str(rng.randint(1, 999999999999999))
|
word = str(rng.randint(1, MAX_NUM))
|
||||||
|
|
||||||
|
|
||||||
# pick a letter from it (90%) or a random letter (10%)
|
# pick a letter from it (90%) or a random letter (10%)
|
||||||
letter = str(random.randint(0, 9))
|
letter = str(rng.randint(0, 9))
|
||||||
|
|
||||||
|
|
||||||
# get the correct answer by simply counting
|
# get the correct answer by simply counting
|
||||||
|
|
@ -248,9 +248,9 @@ class SimpleSpellingDigits(Task):
|
||||||
|
|
||||||
def get_example(self, index):
|
def get_example(self, index):
|
||||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||||
rng = random.Random(seed)
|
rng = random
|
||||||
# pick a random word
|
# pick a random word
|
||||||
word = str(random.randint(0, 7890987689))
|
word = str(rng.randint(0, MAX_NUM))
|
||||||
word_letters = ",".join(list(word))
|
word_letters = ",".join(list(word))
|
||||||
# return the full conversation
|
# return the full conversation
|
||||||
messages = [
|
messages = [
|
||||||
|
|
@ -286,7 +286,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# # preview the SimpleSpellingDigits task, first 10 examples
|
# # preview the SimpleSpellingDigits task, first 10 examples
|
||||||
task = SimpleSpellingDigits()
|
task = SimpleSpellingDigits()
|
||||||
for i in range(3):
|
for i in range(5):
|
||||||
ex = task.get_example(i)
|
ex = task.get_example(i)
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
print(ex['messages'][0]['content'])
|
print(ex['messages'][0]['content'])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user