mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
This commit introduces extensive documentation across the entire nanochat codebase. The goal is to make the project more accessible, educational, and easier for new contributors to understand. Key additions include: - A new "Codebase Overview and Data Flow" section in the main README.md, providing a high-level guide to the project structure and training pipeline. - Detailed, educational docstrings and inline comments in all Python modules within the `nanochat/`, `scripts/`, and `tasks/` directories. - Explanations of the rationale and implementation details for key components, including Python equivalents for non-Python code where applicable. - A new `README.md` in the `rustbpe/` directory explaining the BPE algorithm and the decision to use Rust. - Comprehensive comments in shell scripts and development scripts in the `dev/` directory, clarifying their purpose and usage.
158 lines
5.9 KiB
Python
"""
Base classes and common utilities for defining evaluation and training
tasks in nanochat.

The core components are:

- `Task`: An abstract base class that represents a dataset of conversations.
- `TaskMixture`: A class for combining multiple tasks into a single, shuffled dataset.
- `TaskSequence`: A class for combining multiple tasks in a sequential manner.
- `render_mc`: A helper function for formatting multiple-choice questions.
"""

import random
class Task:
    """
    Abstract base class for a task, which is essentially a dataset of conversations.

    Subclasses implement `num_examples()`, `get_example()`, `evaluate()`, and
    the `eval_type` property. This base class layers a lightweight slice view
    on top of the underlying dataset via `start`, `stop`, and `step`, with the
    same semantics as Python list slicing but without copying any data.
    """

    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        # Validate the slice parameters up front so misuse fails loudly.
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop # could be None here; resolved lazily in __len__
        self.step = step

    @property
    def eval_type(self):
        """The type of evaluation for this task, either 'generative' or 'categorical'."""
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def num_examples(self):
        """Returns the total number of examples in the underlying dataset."""
        raise NotImplementedError

    def get_example(self, index):
        """Retrieves a single example from the underlying dataset by its physical index."""
        raise NotImplementedError

    def __len__(self):
        """Returns the number of examples in the (potentially sliced) view of the dataset."""
        start = self.start
        stop = self.num_examples() if self.stop is None else self.stop
        step = self.step
        span = stop - start
        num = (span + step - 1) // step # ceil_div(span, step)
        assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
        return num

    def __getitem__(self, index: int):
        """Retrieves an example by its logical index in the (sliced) view."""
        assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
        # Bounds-check the logical index: without this, a negative or
        # past-the-end index would silently map to the wrong physical example.
        # IndexError (not assert) keeps the sequence iteration protocol working.
        if not 0 <= index < len(self):
            raise IndexError(f"Index {index} out of range for view of length {len(self)}")
        physical_index = self.start + index * self.step
        conversation = self.get_example(physical_index)
        return conversation

    def evaluate(self, problem, completion):
        """Evaluates a model's completion for a given problem."""
        raise NotImplementedError
class TaskMixture(Task):
    """
    A Task that interleaves several child tasks into one shuffled dataset.

    All examples of all child tasks are pooled and shuffled with a fixed seed,
    so the mixture is deterministic across runs — useful for building a diverse
    SFT training mixture. To oversample a task, simply list it more than once
    in `tasks`.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks # list of Task objects
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        # Enumerate every example as a (task index, index within that task) pair...
        self.index_map = [
            (task_idx, local_idx)
            for task_idx, task_length in enumerate(self.lengths)
            for local_idx in range(task_length)
        ]
        # ...then shuffle deterministically so tasks stay mixed throughout training.
        random.Random(42).shuffle(self.index_map)
        # Note: this is not the most elegant or best solution, but it's ok for now

    def num_examples(self):
        return self.num_conversations

    def get_example(self, index):
        """
        Access conversations according to a deterministic shuffle of all examples.
        This ensures tasks are mixed throughout training, regardless of dataset size.
        """
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]
class TaskSequence(Task):
    """
    A Task that concatenates several child tasks back to back, e.g. to build
    a simple training curriculum (all of task 0, then all of task 1, ...).
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks # list of Task objects, served in order
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)

    def num_examples(self):
        return self.num_conversations

    def get_example(self, index):
        """Map a global index onto (task, local index) by walking tasks in order."""
        assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
        remaining = index
        for task, task_length in zip(self.tasks, self.lengths):
            if remaining < task_length:
                return task[remaining]
            remaining -= task_length
def render_mc(question, letters, choices):
    """
    Formats a multiple-choice question into a standardized prompt.

    Args:
        question (str): The question text.
        letters (list[str]): The letters for the choices (e.g., ['A', 'B', 'C']).
        choices (list[str]): The text of the choices.

    Returns:
        str: The formatted prompt.
    """
    # Assemble the prompt as a list of pieces and join once at the end.
    parts = [f"Multiple Choice question: {question}\n"]
    parts.extend(f"- {choice}={letter}\n" for letter, choice in zip(letters, choices))
    parts.append("\nRespond only with the letter of the correct answer.")
    return "".join(parts)
if __name__ == "__main__":
    # Very lightweight smoke test of the slicing logic, using MMLU as a
    # concrete Task: the 0th element of the [5:10] slice must equal the
    # 5th element of the full dataset.
    from tasks.mmlu import MMLU

    full = MMLU(subset="auxiliary_train", split="train")
    print("Length of MMLU: ", len(full))
    example = full[5]
    print("5th example: ", example)

    sliced = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
    print("Length of sliced MMLU[5:10]: ", len(sliced))
    print("0th example of sliced MMLU: ", sliced[0])

    print("They match: ", example == sliced[0])