nanochat/tasks/common.py
google-labs-jules[bot] 51927a9e60 feat: Add comprehensive end-to-end documentation
This commit introduces extensive documentation across the entire nanochat codebase. The goal is to make the project more accessible, educational, and easier for new contributors to understand.

Key additions include:
- A new "Codebase Overview and Data Flow" section in the main README.md, providing a high-level guide to the project structure and training pipeline.
- Detailed, educational docstrings and inline comments in all Python modules within the `nanochat/`, `scripts/`, and `tasks/` directories.
- Explanations of the rationale and implementation details for key components, including Python equivalents for non-Python code where applicable.
- A new `README.md` in the `rustbpe/` directory explaining the BPE algorithm and the decision to use Rust.
- Comprehensive comments in shell scripts and development scripts in the `dev/` directory, clarifying their purpose and usage.
2025-11-24 12:57:49 +00:00

158 lines
5.9 KiB
Python

"""
This module provides the base classes and common utilities for defining evaluation
and training tasks in nanochat.
The core components are:
- `Task`: An abstract base class that represents a dataset of conversations.
- `TaskMixture`: A class for combining multiple tasks into a single, shuffled dataset.
- `TaskSequence`: A class for combining multiple tasks in a sequential manner.
- `render_mc`: A helper function for formatting multiple-choice questions.
"""
import random
class Task:
    """
    Abstract base class for a task, which is essentially a dataset of conversations.
    This class supports lightweight slicing of the underlying dataset using `start`,
    `stop`, and `step` parameters, similar to Python's list slicing.
    Subclasses must implement `num_examples` and `get_example`, and (for evals)
    `eval_type` and `evaluate`.
    """
    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop # could be None here
        self.step = step
    @property
    def eval_type(self):
        """The type of evaluation for this task, one of 'generative' | 'categorical'."""
        raise NotImplementedError
    def num_examples(self):
        """Returns the total number of examples in the underlying dataset."""
        raise NotImplementedError
    def get_example(self, index):
        """Retrieves a single example from the underlying dataset by its physical index."""
        raise NotImplementedError
    def __len__(self):
        """Returns the number of examples in the (potentially sliced) view of the dataset."""
        start = self.start
        stop = self.num_examples() if self.stop is None else self.stop
        step = self.step
        span = stop - start
        num = (span + step - 1) // step # ceil_div(span, step)
        assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
        return num
    def __getitem__(self, index: int):
        """
        Retrieves an example by its logical index in the (sliced) view.

        Raises IndexError for out-of-range indices. This prevents silently
        reading past the `stop` boundary of a sliced view (the physical index
        would otherwise land beyond the slice), and it makes instances iterable
        via `for ex in task`, since the legacy sequence iteration protocol
        relies on IndexError to terminate.
        """
        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
        if not 0 <= index < len(self):
            raise IndexError(f"Index {index} out of range for task of length {len(self)}")
        physical_index = self.start + index * self.step
        conversation = self.get_example(physical_index)
        return conversation
    def evaluate(self, problem, completion):
        """Evaluates a model's completion for a given problem."""
        raise NotImplementedError
class TaskMixture(Task):
    """
    Combines multiple tasks into a single, deterministically shuffled dataset.
    Useful for building a diverse training mixture for SFT. To oversample a
    task, simply include it multiple times in the `tasks` list.
    """
    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks  # list of Task objects
        self.lengths = [len(t) for t in tasks]
        self.num_conversations = sum(self.lengths)
        # Flatten into (task_idx, local_idx) pairs covering every example,
        # then shuffle with a fixed seed so the mixture is deterministic
        # and tasks are interleaved throughout training.
        self.index_map = [
            (task_idx, local_idx)
            for task_idx, task_length in enumerate(self.lengths)
            for local_idx in range(task_length)
        ]
        random.Random(42).shuffle(self.index_map)
        # Note: this is not the most elegant or best solution, but it's ok for now
    def num_examples(self):
        """Total number of conversations across all constituent tasks."""
        return self.num_conversations
    def get_example(self, index):
        """
        Access conversations according to a deterministic shuffle of all examples.
        This ensures tasks are mixed throughout training, regardless of dataset size.
        """
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]
class TaskSequence(Task):
    """
    Combines multiple tasks sequentially (all of task 0, then all of task 1,
    and so on), which is useful for creating a training curriculum.
    """
    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks  # list of Task objects, consumed in order
        self.lengths = [len(t) for t in tasks]
        self.num_conversations = sum(self.lengths)
    def num_examples(self):
        """Total number of conversations across all constituent tasks."""
        return self.num_conversations
    def get_example(self, index):
        """Map a global index to (task, local index) by walking the tasks in order."""
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
        remaining = index
        for task, task_length in zip(self.tasks, self.lengths):
            if remaining < task_length:
                return task[remaining]
            remaining -= task_length
def render_mc(question, letters, choices):
    """
    Formats a multiple-choice question into a standardized prompt.
    Args:
        question (str): The question text.
        letters (list[str]): The letters for the choices (e.g., ['A', 'B', 'C']).
        choices (list[str]): The text of the choices.
    Returns:
        str: The formatted prompt, ending with an instruction to respond
        with a single letter (no trailing newline).
    """
    lines = [f"Multiple Choice question: {question}"]
    for letter, choice in zip(letters, choices):
        lines.append(f"- {choice}={letter}")
    lines.append("")  # blank separator line before the final instruction
    lines.append("Respond only with the letter of the correct answer.")
    return "\n".join(lines)
if __name__ == "__main__":
    # Lightweight smoke test of the slicing behavior, using MMLU as a sample dataset.
    from tasks.mmlu import MMLU
    full = MMLU(subset="auxiliary_train", split="train")
    print("Length of MMLU: ", len(full))
    reference = full[5]
    print("5th example: ", reference)
    # A view over physical indices [5, 10) should put the same example at logical index 0.
    sliced = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
    print("Length of sliced MMLU[5:10]: ", len(sliced))
    print("0th example of sliced MMLU: ", sliced[0])
    print("They match: ", reference == sliced[0])