add gsm8k-platinum

This commit is contained in:
Qubitium 2025-10-21 02:24:30 +00:00
parent 74d930de49
commit f696e9ce4c
3 changed files with 68 additions and 4 deletions

View File

@ -164,6 +164,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"),
'GSM8K-Platinum': partial(GSM8K, subset="platinum", split="test"),
}[task_name]
task_object = task_module()
# Run the evaluation
@ -201,12 +202,13 @@ if __name__ == "__main__":
engine = Engine(model, tokenizer)
# Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'GSM8K-Platinum', 'HumanEval']
baseline_accuracies = {
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0%
'GSM8K-Platinum': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0%
}
task_names = all_tasks if args.task_name is None else args.task_name.split('|')

View File

@ -1,6 +1,7 @@
"""
GSM8K evaluation.
https://huggingface.co/datasets/openai/gsm8k
https://huggingface.co/datasets/madrylab/gsm8k-platinum
Example problem instance:
@ -20,6 +21,24 @@ from tasks.common import Task
# Matches the final "#### <number>" answer marker in GSM8K solutions;
# captures an optional leading minus plus digits, dots, and thousands commas.
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

# Maps a user-facing subset name to the Hugging Face dataset path/config it
# loads from, and which splits that subset supports.
DATASET_CONFIGS = {
    "main": {
        "path": "openai/gsm8k",
        "name": "main",
        "splits": {"train", "test"},
    },
    "socratic": {
        "path": "openai/gsm8k",
        "name": "socratic",
        "splits": {"train", "test"},
    },
    "platinum": {
        # GSM8K-Platinum lives under a different HF path but reuses the
        # "main" config name; only a test split is configured here.
        "path": "madrylab/gsm8k-platinum",
        "name": "main",
        "splits": {"test"},
    },
}
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
@ -38,9 +57,10 @@ class GSM8K(Task):
def __init__(self, subset, split, **kwargs):
    """Load a GSM8K-family dataset split, deterministically shuffled.

    Args:
        subset: One of the keys of DATASET_CONFIGS ("main", "socratic",
            "platinum"); selects which Hugging Face dataset is loaded.
        split: Dataset split to load; must be listed in the chosen
            subset's "splits" set (e.g. "platinum" provides only "test").
        **kwargs: Forwarded to the Task base class.

    Raises:
        AssertionError: If subset is unknown or split is not supported
            by that subset.
    """
    super().__init__(**kwargs)
    # Validate against the config table rather than a hard-coded list so
    # newly added subsets (e.g. "platinum") are accepted automatically.
    assert subset in DATASET_CONFIGS, f"GSM8K subset must be one of {sorted(DATASET_CONFIGS)}"
    config = DATASET_CONFIGS[subset]
    assert split in config["splits"], f"GSM8K subset '{subset}' does not support split '{split}'"
    # Fixed seed keeps the example order stable across runs.
    self.ds = load_dataset(config["path"], config["name"], split=split).shuffle(seed=42)
@property
def eval_type(self):

42
tests/test_gsm8k.py Normal file
View File

@ -0,0 +1,42 @@
import pytest

from tasks.gsm8k import DATASET_CONFIGS, GSM8K

# Simple test to check we are getting the correct rows from the gsm8k datasets.
# It does not verify the actual content of the dataset itself.

# Known row counts per (subset, split) pair; a count mismatch signals that
# the upstream Hugging Face dataset changed size or the wrong split loaded.
EXPECTED_COUNTS = {
    ("main", "train"): 7473,
    ("main", "test"): 1319,
    ("socratic", "train"): 7473,
    ("socratic", "test"): 1319,
    ("platinum", "test"): 1209,
}
@pytest.mark.parametrize(
    "subset, split, expected",
    sorted((subset, split, count) for (subset, split), count in EXPECTED_COUNTS.items()),
)
def test_gsm8k_real_dataset_counts(subset, split, expected):
    """Each configured (subset, split) pair must load exactly the known number of rows."""
    task = GSM8K(subset=subset, split=split)
    assert task.num_examples() == expected
def test_gsm8k_conversation_structure():
    """The first example must be a user/assistant pair with the expected content types."""
    messages = GSM8K(subset="main", split="test").get_example(0)["messages"]
    user_msg, assistant_msg = messages[0], messages[1]
    assert user_msg["role"] == "user"
    assert assistant_msg["role"] == "assistant"
    assert isinstance(user_msg["content"], str)
    assert isinstance(assistant_msg["content"], list)
def test_gsm8k_invalid_split_guard():
    """Requesting a split a subset does not provide must raise an AssertionError."""
    for subset, config in DATASET_CONFIGS.items():
        # Any standard split the config does not list is disallowed.
        for bad_split in {"train", "test"} - config["splits"]:
            with pytest.raises(AssertionError):
                GSM8K(subset=subset, split=bad_split)