diff --git a/tasks/common.py b/tasks/common.py index a63cb7a..540ff6b 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -6,6 +6,9 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. """ import random +import logging + +logger = logging.getLogger(__name__) class Task: """ @@ -36,14 +39,14 @@ class Task: start = self.start if self.stop is not None: num_ex = self.num_examples() - stop = min(self.stop, num_ex) # Gracefully cap at dataset size if self.stop > num_ex: - import warnings - warnings.warn( + # Warn once, then cap stop + logger.warning( f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). " - f"Using {num_ex} examples instead.", - UserWarning + f"Using {num_ex} examples instead." ) + self.stop = num_ex + stop = self.stop else: stop = self.num_examples() step = self.step