Cap stop parameter and warn once when it exceeds dataset size

This commit is contained in:
Pyry Takala 2025-11-21 20:51:46 +00:00
parent 85e49943ed
commit a33d04dca1

View File

@ -6,6 +6,9 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
""" """
import random import random
import logging
logger = logging.getLogger(__name__)
class Task: class Task:
""" """
@ -36,14 +39,14 @@ class Task:
start = self.start start = self.start
if self.stop is not None: if self.stop is not None:
num_ex = self.num_examples() num_ex = self.num_examples()
stop = min(self.stop, num_ex) # Gracefully cap at dataset size
if self.stop > num_ex: if self.stop > num_ex:
import warnings # Warn once, then cap stop
warnings.warn( logger.warning(
f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). " f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). "
f"Using {num_ex} examples instead.", f"Using {num_ex} examples instead."
UserWarning
) )
self.stop = num_ex
stop = self.stop
else: else:
stop = self.num_examples() stop = self.num_examples()
step = self.step step = self.step