diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index df6a01a..f076e9c 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -164,6 +164,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
         'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
         'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
         'GSM8K': partial(GSM8K, subset="main", split="test"),
+        'GSM8K-Platinum': partial(GSM8K, subset="platinum", split="test"),
     }[task_name]
     task_object = task_module()
     # Run the evaluation
@@ -201,12 +202,13 @@ if __name__ == "__main__":
     engine = Engine(model, tokenizer)
 
     # Get the tasks to evaluate on
-    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
+    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'GSM8K-Platinum', 'HumanEval']
     baseline_accuracies = {
         'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
         'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
         'MMLU': 0.25, # multiple choice 1 of 4 => 25%
         'GSM8K': 0.0, # open-ended => 0%
+        'GSM8K-Platinum': 0.0, # open-ended => 0%
         'HumanEval': 0.0, # open-ended => 0%
     }
     task_names = all_tasks if args.task_name is None else args.task_name.split('|')
diff --git a/tasks/gsm8k.py b/tasks/gsm8k.py
index c05e21c..8462e18 100644
--- a/tasks/gsm8k.py
+++ b/tasks/gsm8k.py
@@ -1,6 +1,7 @@
 """
 GSM8K evaluation.
 https://huggingface.co/datasets/openai/gsm8k
+https://huggingface.co/datasets/madrylab/gsm8k-platinum
 
 Example problem instance:
 
@@ -20,6 +21,24 @@ from tasks.common import Task
 
 GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
 
+DATASET_CONFIGS = {
+    "main": {
+        "path": "openai/gsm8k",
+        "name": "main",
+        "splits": {"train", "test"},
+    },
+    "socratic": {
+        "path": "openai/gsm8k",
+        "name": "socratic",
+        "splits": {"train", "test"},
+    },
+    "platinum": {
+        "path": "madrylab/gsm8k-platinum",
+        "name": "main",
+        "splits": {"test"},
+    },
+}
+
 def extract_answer(completion):
     """
     Extract the numerical answer after #### marker.
@@ -38,9 +57,10 @@ class GSM8K(Task):
 
     def __init__(self, subset, split, **kwargs):
         super().__init__(**kwargs)
-        assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
-        assert split in ["train", "test"], "GSM8K split must be train|test"
-        self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)
+        assert subset in DATASET_CONFIGS, f"GSM8K subset must be one of {sorted(DATASET_CONFIGS)}"
+        config = DATASET_CONFIGS[subset]
+        assert split in config["splits"], f"GSM8K subset '{subset}' does not support split '{split}'"
+        self.ds = load_dataset(config["path"], config["name"], split=split).shuffle(seed=42)
 
     @property
     def eval_type(self):
diff --git a/tests/test_gsm8k.py b/tests/test_gsm8k.py
new file mode 100644
index 0000000..ab58f2f
--- /dev/null
+++ b/tests/test_gsm8k.py
@@ -0,0 +1,42 @@
+import pytest
+
+from tasks.gsm8k import DATASET_CONFIGS, GSM8K
+
+# Simple test to check we are getting the correct rows from the gsm8k datasets.
+# It does not verify the actual content of the dataset itself.
+EXPECTED_COUNTS = {
+    ("main", "train"): 7473,
+    ("main", "test"): 1319,
+    ("socratic", "train"): 7473,
+    ("socratic", "test"): 1319,
+    ("platinum", "test"): 1209,
+}
+
+
+@pytest.mark.parametrize(
+    "subset, split, expected",
+    [
+        (subset, split, count)
+        for (subset, split), count in sorted(EXPECTED_COUNTS.items())
+    ],
+)
+def test_gsm8k_real_dataset_counts(subset, split, expected):
+    task = GSM8K(subset=subset, split=split)
+    assert task.num_examples() == expected
+
+
+def test_gsm8k_conversation_structure():
+    task = GSM8K(subset="main", split="test")
+    conversation = task.get_example(0)
+    assert conversation["messages"][0]["role"] == "user"
+    assert conversation["messages"][1]["role"] == "assistant"
+    assert isinstance(conversation["messages"][0]["content"], str)
+    assert isinstance(conversation["messages"][1]["content"], list)
+
+
+def test_gsm8k_invalid_split_guard():
+    for subset, config in DATASET_CONFIGS.items():
+        disallowed = {"train", "test"} - config["splits"]
+        for split in disallowed:
+            with pytest.raises(AssertionError):
+                GSM8K(subset=subset, split=split)
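
A quick local sanity check for the new subset (a sketch, not part of the diff: it assumes network access to the Hugging Face Hub and relies on the same num_examples() and assertion behavior the tests above exercise):

    from tasks.gsm8k import GSM8K

    # Load the platinum test split added in this change; the size should
    # match EXPECTED_COUNTS in tests/test_gsm8k.py.
    task = GSM8K(subset="platinum", split="test")
    print(task.num_examples())  # expected: 1209

    # The new guard in __init__ rejects splits the subset does not provide:
    # GSM8K(subset="platinum", split="train")  # raises AssertionError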