diff --git a/tasks/common.py b/tasks/common.py index dcd2e91..f2228eb 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -6,6 +6,9 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. """ import random +import logging + +logger = logging.getLogger(__name__) class Task: """ @@ -34,7 +37,15 @@ class Task: def __len__(self): start = self.start - stop = self.num_examples() if self.stop is None else self.stop + num_ex = self.num_examples() + if self.stop is not None and self.stop > num_ex: + # Warn once, then cap stop + logger.warning( + f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). " + f"Using {num_ex} examples instead." + ) + self.stop = num_ex + stop = num_ex if self.stop is None else self.stop step = self.step span = stop - start num = (span + step - 1) // step # ceil_div(span, step)