From a641b6ca966fdabe81d8c30f25b287f3de9039a3 Mon Sep 17 00:00:00 2001 From: Mathieu Lacage Date: Fri, 13 Mar 2026 13:19:10 +0100 Subject: [PATCH] MMLU main split is named auxiliary_train, not train --- scripts/chat_sft.py | 2 +- tasks/common.py | 4 ++-- tasks/mmlu.py | 9 ++------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..ab886a7 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -166,7 +166,7 @@ train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these - *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch + *[MMLU(subset="all", split="auxiliary_train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) diff --git a/tasks/common.py b/tasks/common.py index 2d6ddd8..211ff3f 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -135,12 +135,12 @@ if __name__ == "__main__": # very lightweight test of slicing from tasks.mmlu import MMLU - ds = MMLU(subset="auxiliary_train", split="train") + ds = MMLU(subset="all", split="auxiliary_train") print("Length of MMLU: ", len(ds)) ex = ds[5] print("5th example: ", ex) - ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10) + ds = MMLU(subset="all", split="auxiliary_train", start=5, stop=10) print("Length of sliced MMLU[5:10]: ", len(ds)) print("0th example of sliced MMLU: ", ds[0]) diff --git a/tasks/mmlu.py b/tasks/mmlu.py index 3ba2254..4721f9f 100644 --- a/tasks/mmlu.py +++ b/tasks/mmlu.py @@ -13,16 +13,11 @@ class MMLU(Task): def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) - assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" - assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" - if subset == "auxiliary_train": - assert split == "train", "auxiliary_train must be split into train" + assert subset in ["all"], f"subset {subset} must be all" + assert split in ["auxiliary_train", "validation", "dev", "test"], f"split {split} must be auxiliary_train|validation|dev|test" self.subset = subset self.split = split self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) - if subset == "auxiliary_train": - # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper - self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) @property def eval_type(self):