mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-21 02:50:25 +00:00
add gsm8k-platinum
This commit is contained in:
parent
74d930de49
commit
f696e9ce4c
|
|
@ -164,6 +164,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
|||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'GSM8K-Platinum': partial(GSM8K, subset="platinum", split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
|
|
@ -201,12 +202,13 @@ if __name__ == "__main__":
|
|||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval']
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'GSM8K-Platinum', 'HumanEval']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'GSM8K-Platinum': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
GSM8K evaluation.
|
||||
https://huggingface.co/datasets/openai/gsm8k
|
||||
https://huggingface.co/datasets/madrylab/gsm8k-platinum
|
||||
|
||||
Example problem instance:
|
||||
|
||||
|
|
@ -20,6 +21,24 @@ from tasks.common import Task
|
|||
|
||||
|
||||
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
DATASET_CONFIGS = {
|
||||
"main": {
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"splits": {"train", "test"},
|
||||
},
|
||||
"socratic": {
|
||||
"path": "openai/gsm8k",
|
||||
"name": "socratic",
|
||||
"splits": {"train", "test"},
|
||||
},
|
||||
"platinum": {
|
||||
"path": "madrylab/gsm8k-platinum",
|
||||
"name": "main",
|
||||
"splits": {"test"},
|
||||
},
|
||||
}
|
||||
|
||||
def extract_answer(completion):
|
||||
"""
|
||||
Extract the numerical answer after #### marker.
|
||||
|
|
@ -38,9 +57,10 @@ class GSM8K(Task):
|
|||
|
||||
def __init__(self, subset, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
|
||||
assert split in ["train", "test"], "GSM8K split must be train|test"
|
||||
self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)
|
||||
assert subset in DATASET_CONFIGS, f"GSM8K subset must be one of {sorted(DATASET_CONFIGS)}"
|
||||
config = DATASET_CONFIGS[subset]
|
||||
assert split in config["splits"], f"GSM8K subset '{subset}' does not support split '{split}'"
|
||||
self.ds = load_dataset(config["path"], config["name"], split=split).shuffle(seed=42)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
|
|
|||
42
tests/test_gsm8k.py
Normal file
42
tests/test_gsm8k.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import pytest
|
||||
|
||||
from tasks.gsm8k import DATASET_CONFIGS, GSM8K
|
||||
|
||||
# Simple test to check we are getting the correct rows from the gsm8k datasets.
|
||||
# It does not verify the actual content of the dataset itself.
|
||||
EXPECTED_COUNTS = {
|
||||
("main", "train"): 7473,
|
||||
("main", "test"): 1319,
|
||||
("socratic", "train"): 7473,
|
||||
("socratic", "test"): 1319,
|
||||
("platinum", "test"): 1209,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"subset, split, expected",
|
||||
[
|
||||
(subset, split, count)
|
||||
for (subset, split), count in sorted(EXPECTED_COUNTS.items())
|
||||
],
|
||||
)
|
||||
def test_gsm8k_real_dataset_counts(subset, split, expected):
|
||||
task = GSM8K(subset=subset, split=split)
|
||||
assert task.num_examples() == expected
|
||||
|
||||
|
||||
def test_gsm8k_conversation_structure():
|
||||
task = GSM8K(subset="main", split="test")
|
||||
conversation = task.get_example(0)
|
||||
assert conversation["messages"][0]["role"] == "user"
|
||||
assert conversation["messages"][1]["role"] == "assistant"
|
||||
assert isinstance(conversation["messages"][0]["content"], str)
|
||||
assert isinstance(conversation["messages"][1]["content"], list)
|
||||
|
||||
|
||||
def test_gsm8k_invalid_split_guard():
|
||||
for subset, config in DATASET_CONFIGS.items():
|
||||
disallowed = {"train", "test"} - config["splits"]
|
||||
for split in disallowed:
|
||||
with pytest.raises(AssertionError):
|
||||
GSM8K(subset=subset, split=split)
|
||||
Loading…
Reference in New Issue
Block a user