diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index c77a89e3..40aef9c8 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -24,6 +24,7 @@ from tasks.mmlu import MMLU from tasks.arc import ARC from tasks.gsm8k import GSM8K from tasks.spellingbee import SpellingBee +from tasks.spellingbee_digits import SpellingBeeDigits # ----------------------------------------------------------------------------- # Generative evaluation loop (we go one problem at a time, sample, evaluate) @@ -167,6 +168,7 @@ def run_chat_eval(task_name, model, tokenizer, engine, 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 'GSM8K': partial(GSM8K, subset="main", split="test"), 'SpellingBee': partial(SpellingBee, size=256, split="test"), + 'SpellingBeeDigits': partial(SpellingBeeDigits, size=100, split = "test") }[task_name] task_object = task_module() # Run the evaluation @@ -214,6 +216,7 @@ if __name__ == "__main__": 'GSM8K': 0.0, # open-ended => 0% 'HumanEval': 0.0, # open-ended => 0% 'SpellingBee': 0.0, # open-ended => 0% + 'SpellingBeeDigits': 0.0, # open-ended => 0% } task_names = all_tasks if args.task_name is None else args.task_name.split('|') diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index e6e4565b..a3b70856 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -29,6 +29,7 @@ from tasks.gsm8k import GSM8K from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee +from tasks.spellingbee_digits import SimpleSpellingDigits, SpellingBeeDigits # ----------------------------------------------------------------------------- # SFT Hyperparameters @@ -89,6 +90,8 @@ train_ds = TaskMixture([ CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) + SimpleSpellingDigits(size=300, split="train"), # 300 rows of Spelling Bee for digits (e.g. spell out '0' in '27384920'?) + SpellingBeeDigits(size=300, split="train"), # 300 rows of Spelling Bee for digits (e.g. how many '0' in '38284090'?) ]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) diff --git a/scripts/mid_train.py b/scripts/mid_train.py index eedb2620..49961f6a 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -29,6 +29,7 @@ from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee +from tasks.spellingbee_digits import SimpleSpellingDigits, SpellingBeeDigits # ----------------------------------------------------------------------------- run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) @@ -103,7 +104,9 @@ train_dataset = TaskMixture([ CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows + SimpleSpellingDigits(size=200000, split="train"), # 80K rows of Spelling Bee (e.g. spell out digits in '64642?) + SpellingBeeDigits(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) +]) # total: 460K + 100K + 8K + 200K + 80K + 80K = 928K rows val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios diff --git a/tasks/spellingbee_digits.py b/tasks/spellingbee_digits.py new file mode 100644 index 00000000..2de6c893 --- /dev/null +++ b/tasks/spellingbee_digits.py @@ -0,0 +1,300 @@ +""" +Task intended to make nanochat better in spelling and counting, for example: + +"How many r are in strawberry?" -> 3 + +An interesting part of this task is that we will get the assistant to +solve the problem using a combination of manual counting and Python. +This is a good problem solving "instinct" to mix into the model and RL +may further refine it to trust one over the other. If we were extra fancy +(which we could/should be) we'd add small errors here and there to allow +the model also learn recoveries. We can do this in future versions. + +There are two tasks in this file: +1. SpellingBee: Counting the number of occurrences of a letter in a word +2. SimpleSpelling: Simply spelling words + +(1) is the goal, but (2) exists as a highly condensed version of the part +that makes (1) difficult, which is word spelling. This is non-trivial for an +LLM because it has to learn how every token (a little semantic chunk/atom) +maps to the sequence of individual characters that make it up. Larger models +learn this eventually on their own, but if we want this capability to exist +in smaller models, we have to actively encourage it by over-representing it +in the training data. Midtraining is a good place to do this. + +To preview a few example conversations, run: +python -m tasks.spellingbee +""" + +import re +import random +from tasks.common import Task +from nanochat.common import download_file_with_lock + +# Letters of the alphabet +# A list of 370K English words of large variety + +# Identical to gsm8k's answer extraction +ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)") +def extract_answer(completion): + """ + Extract the numerical answer after #### marker. + """ + match = ANSWER_RE.search(completion) + if match: + match_str = match.group(1).strip() + match_str = match_str.replace(",", "") + return match_str + return None + +# User message templates for data augmentation +USER_MSG_TEMPLATES = [ + "How many {letter} are in the word {word}", + "How many {letter} are in {word}", + "Count the number of {letter} in {word}", + "How many times does {letter} appear in {word}", + "What's the count of {letter} in {word}", + "In the word {word}, how many {letter} are there", + "How many letter {letter} are in the word {word}", + "Count how many {letter} appear in {word}", + "Tell me the number of {letter} in {word}", + "How many occurrences of {letter} are in {word}", + "Find the count of {letter} in {word}", + "Can you count the {letter} letters in {word}", + "What is the frequency of {letter} in {word}", + "How many {letter}s are in {word}", + "How many {letter}'s are in {word}", + "Count all the {letter} in {word}", + "How many times is {letter} in {word}", + "Number of {letter} in {word}", + "Total count of {letter} in {word}", + "How many {letter} does {word} have", + "How many {letter} does {word} contain", + "What's the number of {letter} in {word}", + "{word} has how many {letter}", + "In {word}, count the {letter}", + "How many {letter} appear in {word}", + "Count the {letter} in {word}", + "Give me the count of {letter} in {word}", + "How many instances of {letter} in {word}", + "Show me how many {letter} are in {word}", + "Calculate the number of {letter} in {word}", + # Spanish + "¿Cuántas {letter} hay en {word}?", + "¿Cuántas veces aparece {letter} en {word}?", + "Cuenta las {letter} en {word}", + "¿Cuántas letras {letter} tiene {word}?", + # Chinese (Simplified) + "{word}中有多少个{letter}", + "{word}里有几个{letter}", + "数一下{word}中的{letter}", + "{word}这个词里有多少{letter}", + # Chinese (Traditional) + "{word}中有多少個{letter}", + "{word}裡有幾個{letter}", + "數一下{word}中的{letter}", + "{word}這個詞裡有多少{letter}", + # Korean + "{word}에 {letter}가 몇 개 있나요", + "{word}에서 {letter}의 개수는", + "{word}에 {letter}가 몇 번 나오나요", + "{word}라는 단어에 {letter}가 몇 개", + # French + "Combien de {letter} dans {word}", + "Combien de fois {letter} apparaît dans {word}", + "Compte les {letter} dans {word}", + # German + "Wie viele {letter} sind in {word}", + "Wie oft kommt {letter} in {word} vor", + "Zähle die {letter} in {word}", + # Japanese + "{word}に{letter}は何個ありますか", + "{word}の中に{letter}がいくつ", + "{word}に{letter}が何回出てくる", +] + +class SpellingBeeDigits(Task): + + def __init__(self, size=1000, split="train", **kwargs): + super().__init__(**kwargs) + assert split in ["train", "test"], "SpellingBee split must be train|test" + self.size = size + self.split = split + + @property + def eval_type(self): + return 'generative' + + def num_examples(self): + return self.size + + def get_example(self, index): + seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + rng = random.Random(seed) + + # pick a random digit + word = str(rng.randint(1, 999999999999999)) + + + # pick a letter from it (90%) or a random letter (10%) + letter = str(random.randint(0, 9)) + + + # get the correct answer by simply counting + count = word.count(letter) + + + # create a user message, with a bunch of variations as data augmentation + template = rng.choice(USER_MSG_TEMPLATES) + # 30% chance to lowercase the template (lazy people don't use shift) + if rng.random() < 0.3: + template = template.lower() + quote_options = ['', "'", '"'] + letter_quote = rng.choice(quote_options) # is the letter quoted? + word_quote = rng.choice(quote_options) # is the word quoted? + letter_wrapped = f"{letter_quote}{letter}{letter_quote}" + word_wrapped = f"{word_quote}{word}{word_quote}" + user_msg = template.format(letter=letter_wrapped, word=word_wrapped) + if rng.random() < 0.5: # 50% of people don't even use question marks + user_msg += "?" + + # Now create the ideal assistant response - build as parts (text + tool calls) + assistant_parts = [] + word_letters = ",".join(list(word)) + manual_text = f"""We are asked to find the number of '{letter}' in the word '{word}'. Let me try a manual approach first. + +First spell the digits out: +{word}:{word_letters} + +Then count the occurrences of '{letter}': +""" + # Little simulated loop of the solution process + # TODO: This is where the fun starts, we could simulate cute little mistakes + # and get the model to review its work and recover from them. + # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit. + running_count = 0 + for i, char in enumerate(word, 1): + if char == letter: + running_count += 1 + # note: there deliberately cannot be a space here between i and char + # because this would create a different token! (e.g. " a" and "a" are different tokens) + manual_text += f"{i}:{char} hit! count={running_count}\n" + else: + manual_text += f"{i}:{char}\n" + + manual_text += f"\nThis gives us {running_count}." + assistant_parts.append({"type": "text", "text": manual_text}) + # Part 2: Python verification + assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"}) + # Part 3: Python tool call + python_expr = f"'{word}'.count('{letter}')" + assistant_parts.append({"type": "python", "text": python_expr}) + # Part 4: Python output + assistant_parts.append({"type": "python_output", "text": str(count)}) + # Part 5: Final answer + assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}) + + # return the full conversation + messages = [ + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": assistant_parts} + ] + conversation = { + "messages": messages, + } + return conversation + + def evaluate(self, conversation, assistant_response): + """ + Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct) + Identical to gsm8k's evaluation. + """ + assert isinstance(assistant_response, str), "Assuming simple string response for now" + # First extract the ground truth answer from the conversation + assistant_message = conversation['messages'][-1] + assert assistant_message['role'] == "assistant", "Last message must be from the Assistant" + assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts" + # The last text part contains the final answer with #### + last_text_part = assistant_message['content'][-1]['text'] + # Extract both the ground truth answer and the predicted answer + ref_num = extract_answer(last_text_part) + pred_num = extract_answer(assistant_response) + # Compare and return the success as int + is_correct = int(pred_num == ref_num) + return is_correct + + def reward(self, conversation, assistant_response): + """ Use simple 0-1 reward just like gsm8k.""" + is_correct = self.evaluate(conversation, assistant_response) + is_correct_float = float(is_correct) + return is_correct_float + + +class SimpleSpellingDigits(Task): + """Much simpler task designed to get the model to just practice spelling words.""" + + def __init__(self, size=1000, split="train", **kwargs): + super().__init__(**kwargs) + assert split in ["train", "test"], "SpellingBee split must be train|test" + self.size = size + self.split = split + + @property + def eval_type(self): + return 'generative' + + def num_examples(self): + return self.size + + def get_example(self, index): + seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + rng = random.Random(seed) + # pick a random word + word = str(random.randint(0, 7890987689)) + word_letters = ",".join(list(word)) + # return the full conversation + messages = [ + {"role": "user", "content": f"Spell the word: {word}"}, + {"role": "assistant", "content": f"{word}:{word_letters}"} + ] + conversation = { + "messages": messages, + } + return conversation + + +if __name__ == "__main__": + + # preview the SpellingBee task, first 10 examples + task = SpellingBeeDigits() + for i in range(3): + ex = task.get_example(i) + print("=" * 100) + print(ex['messages'][0]['content']) + print("-" * 100) + # Assistant content is now a list of parts + assistant_parts = ex['messages'][1]['content'] + for part in assistant_parts: + if part['type'] == 'text': + print(part['text'], end='') + elif part['type'] == 'python': + print(f"<<{part['text']}=", end='') + elif part['type'] == 'python_output': + print(f"{part['text']}>>", end='') + print() + print("-" * 100) + + # # preview the SimpleSpellingDigits task, first 10 examples + task = SimpleSpellingDigits() + for i in range(3): + ex = task.get_example(i) + print("=" * 100) + print(ex['messages'][0]['content']) + print("-" * 100) + print(ex['messages'][1]['content']) + + # also scrutinize the tokenization (last example only) + # from nanochat.tokenizer import get_tokenizer + # tokenizer = get_tokenizer() + # ids, mask = tokenizer.render_conversation(ex) + # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))