diff --git a/tasks/common.py b/tasks/common.py index dcd2e91..afa47cc 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -34,7 +34,16 @@ class Task: def __len__(self): start = self.start - stop = self.num_examples() if self.stop is None else self.stop + if self.stop is not None: + num_ex = self.num_examples() + if self.stop > num_ex: + raise ValueError( + f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). " + f"Please use stop <= {num_ex} or remove the stop parameter to use the full dataset." + ) + stop = self.stop + else: + stop = self.num_examples() step = self.step span = stop - start num = (span + step - 1) // step # ceil_div(span, step)