Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00).
Commit: Add validation in Task.__len__() to ensure the stop parameter does not exceed the actual dataset size. This prevents IndexError crashes during training when invalid stop values are provided. The validation is centralized in the base Task class and preserves the original lazy-evaluation behavior — num_examples() is only called when needed (for validation when stop is provided, or for the default value when stop is None). Fixes an issue where training would crash with IndexError when iterating over Task instances with stop > dataset_size.
File: 157 lines, 5.8 KiB, Python.
"""
Base class for all Tasks.

A Task is basically a dataset of conversations, together with some
metadata and often also evaluation criteria.

Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
"""

import random

class Task:
    """
    Base class of a Task. Allows for lightweight slicing of the underlying dataset.

    A Task is a logical (start, stop, step) view over a physical dataset,
    similar to Python slicing but without copying any data. Subclasses
    implement num_examples() / get_example() for the physical dataset.
    """

    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop # could be None here
        self.step = step

    @property
    def eval_type(self):
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def num_examples(self):
        # size of the underlying (physical) dataset; implemented by subclasses
        raise NotImplementedError

    def get_example(self, index):
        # fetch one conversation by physical index; implemented by subclasses
        raise NotImplementedError

    def __len__(self):
        """
        Number of examples in this logical view.

        Lazily resolves `stop`: num_examples() is only called when needed —
        to validate an explicit stop, or as the default when stop is None.
        Raises ValueError if stop exceeds the physical dataset size, which
        would otherwise surface as an IndexError mid-training.
        """
        start = self.start
        if self.stop is not None:
            num_ex = self.num_examples()
            if self.stop > num_ex:
                raise ValueError(
                    f"Stop parameter ({self.stop}) exceeds dataset size ({num_ex}). "
                    f"Please use stop <= {num_ex} or remove the stop parameter to use the full dataset."
                )
            stop = self.stop
        else:
            stop = self.num_examples()
        step = self.step
        span = stop - start
        num = (span + step - 1) // step # ceil_div(span, step)
        assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
        return num

    def __getitem__(self, index: int):
        """
        Return the conversation at logical position `index` of this view.

        The logical index is mapped through (start, step) to a physical
        dataset index. Raises IndexError when index is outside [0, len(self)),
        which prevents silently reading past `stop` and makes the implicit
        __getitem__-based iterator protocol terminate at the view boundary.
        """
        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
        if not 0 <= index < len(self):
            raise IndexError(f"Index {index} out of range for task of length {len(self)}")
        physical_index = self.start + index * self.step
        conversation = self.get_example(physical_index)
        return conversation

    def evaluate(self, problem, completion):
        # score a model completion against a problem; implemented by subclasses
        raise NotImplementedError
|
class TaskMixture(Task):
    """
    Blends several Tasks into one for SFT training.

    To oversample a task, simply include it multiple times in `tasks`.
    Examples are served in a fixed pseudo-random order so the constituent
    tasks stay interleaved throughout training.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks  # list of Task objects
        self.lengths = [len(t) for t in tasks]
        self.num_conversations = sum(self.lengths)
        # enumerate every example as a (task_idx, local_idx) pair
        self.index_map = [
            (ti, li)
            for ti, task_len in enumerate(self.lengths)
            for li in range(task_len)
        ]
        # deterministic shuffle so tasks are mixed across the whole run
        random.Random(42).shuffle(self.index_map)
        # Note: not the most elegant or best solution, but it's ok for now

    def num_examples(self):
        # total number of conversations across all constituent tasks
        return self.num_conversations

    def get_example(self, index):
        """
        Fetch a conversation according to the deterministic shuffle of all
        examples, so tasks are interleaved regardless of dataset size.
        """
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
        ti, li = self.index_map[index]
        return self.tasks[ti][li]
|
class TaskSequence(Task):
    """
    Concatenates Tasks back to back for SFT training.

    Useful for cases that require a training curriculum, where tasks
    should be seen strictly one after another.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks
        self.lengths = [len(t) for t in tasks]
        self.num_conversations = sum(self.lengths)

    def num_examples(self):
        # total number of conversations across all tasks
        return self.num_conversations

    def get_example(self, index):
        # walk the tasks in order, shifting the index into the right one
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
        offset = index
        for task, task_len in zip(self.tasks, self.lengths):
            if offset < task_len:
                return task[offset]
            offset -= task_len
|
def render_mc(question, letters, choices):
    """
    The common multiple choice rendering format we will use.

    Two deliberate design decisions:
    1) The letter comes *after* the choice. Bigger models don't care as much,
       but smaller models bind choices to letters better this way.
    2) No whitespace between the delimiter (=) and the letter. This is
       critical because the tokenizer has different token ids for " A" vs.
       "A". Assistant responses are the bare letter, e.g. "A", so the prompt
       must contain that exact same token — "A" with nothing before it.
       Again, bigger models care less, but smaller models do care about
       details like this.
    """
    parts = [f"Multiple Choice question: {question}\n"]
    for letter, choice in zip(letters, choices):
        parts.append(f"- {choice}={letter}\n")
    parts.append("\nRespond only with the letter of the correct answer.")
    return "".join(parts)
|
if __name__ == "__main__":
    # very lightweight smoke test of the slicing behavior
    from tasks.mmlu import MMLU

    full = MMLU(subset="auxiliary_train", split="train")
    print("Length of MMLU: ", len(full))
    reference = full[5]
    print("5th example: ", reference)

    # a [5:10] view: its 0th element should be the full dataset's 5th
    view = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
    print("Length of sliced MMLU[5:10]: ", len(view))
    print("0th example of sliced MMLU: ", view[0])

    print("They match: ", reference == view[0])