mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
This commit introduces extensive documentation across the entire nanochat codebase. The goal is to make the project more accessible, educational, and easier for new contributors to understand. Key additions include:
- A new "Codebase Overview and Data Flow" section in the main README.md, providing a high-level guide to the project structure and training pipeline.
- Detailed, educational docstrings and inline comments in all Python modules within the `nanochat/`, `scripts/`, and `tasks/` directories.
- Explanations of the rationale and implementation details for key components, including Python equivalents for non-Python code where applicable.
- A new `README.md` in the `rustbpe/` directory explaining the BPE algorithm and the decision to use Rust.
- Comprehensive comments in shell scripts and development scripts in the `dev/` directory, clarifying their purpose and usage.
116 lines
4.6 KiB
Python
116 lines
4.6 KiB
Python
"""
This module implements the GSM8K (Grade School Math 8K) task. This dataset consists
of grade school math word problems that require multi-step reasoning.

A distinctive feature of this dataset is that its reference solutions contain
calculator annotations ("tool calls") of the form `<<expression=result>>`. This
module parses these annotations into a structured conversational format used to
fine-tune the model's tool-use capabilities.

**Reference:**
- The GSM8K dataset: https://huggingface.co/datasets/openai/gsm8k
"""
|
|
|
|
import re
|
|
from datasets import load_dataset
|
|
from .common import Task
|
|
|
|
|
|
# Matches the "#### <number>" final-answer marker used by GSM8K solutions.
# The capture admits "," and "." so that "1,234" and "2.5" are captured whole.
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")


def extract_answer(completion):
    """Extract the final numerical answer from a GSM8K-style completion.

    Finds the first ``#### <number>`` marker in ``completion`` and normalizes the
    captured number so it can be compared as a string with another extracted
    answer.

    Args:
        completion (str): Text that may contain a ``#### <number>`` marker.

    Returns:
        str | None: The captured number with thousands separators (",") and any
        trailing sentence period removed, e.g. ``"#### 1,234."`` -> ``"1234"``.
        ``None`` if no marker is found or the capture is empty after
        normalization.
    """
    match = GSM_RE.search(completion)
    if match is None:
        return None
    # Commas are thousands separators and carry no value. A trailing "." is
    # sentence punctuation (e.g. "#### 18."), not part of the number; keeping it
    # would make "18." fail to match a reference answer of "18".
    answer = match.group(1).replace(",", "").rstrip(".")
    return answer or None
|
|
|
|
|
|
class GSM8K(Task):
    """
    The GSM8K (Grade School Math 8K) task.

    Loads one split of the `openai/gsm8k` dataset via `load_dataset` and shuffles
    it with a fixed seed (42), so the example order is reproducible across runs.

    Args:
        subset (str): The subset of the dataset, either "main" or "socratic".
        split (str): The data split, either "train" or "test".
        **kwargs: Forwarded unchanged to the `Task` base-class constructor.
    """

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
        assert split in ["train", "test"], "GSM8K split must be train|test"
        # Fixed seed: every instance sees the examples in the same shuffled order.
        self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)

    @property
    def eval_type(self):
        """Specifies that this is a generative evaluation task."""
        return 'generative'

    def num_examples(self):
        """Returns the total number of examples in the dataset."""
        return len(self.ds)

    def get_example(self, index):
        """
        Formats a single example as a two-message (user, assistant) conversation.

        The reference solution's `<<expression=result>>` calculator annotations are
        split out, so the assistant content becomes a list of dicts, each with a
        "type" and a "text" key:
          - "text": ordinary solution text between annotations,
          - "python": the annotation's expression (everything before its last "="),
          - "python_output": the annotation's result (everything after its last
            "=", or "" if the annotation contains no "=").

        Args:
            index (int): Position in the (shuffled) dataset.

        Returns:
            dict: {"messages": [user_message, assistant_message]}. The user content
            is the question string; the assistant content is the list of parts
            described above.
        """
        row = self.ds[index]
        question = row['question']  # string of the question prompt
        answer = row['answer']  # string of the full solution and the answer after #### marker
        # The capturing group makes re.split keep each annotation in the result,
        # interleaved with the text around it. Empty text pieces appear where an
        # annotation begins/ends the answer or two annotations are adjacent.
        assistant_message_parts = []
        for part in re.split(r'(<<[^>]+>>)', answer):
            if part.startswith('<<') and part.endswith('>>'):
                # Calculator tool call: drop the << >> delimiters.
                inner = part[2:-2]
                # Split on the LAST "=" so any "=" inside the expression stays with it.
                if '=' in inner:
                    expr, result = inner.rsplit('=', 1)
                else:
                    expr, result = inner, ""
                assistant_message_parts.append({"type": "python", "text": expr})
                assistant_message_parts.append({"type": "python_output", "text": result})
            else:
                # Regular text in between tool calls
                assistant_message_parts.append({"type": "text", "text": part})
        # Now put it all together
        messages = [
            {"role": "user", "content": question},  # note: simple string
            {"role": "assistant", "content": assistant_message_parts},  # note: list of parts (as dicts)
        ]
        conversation = {
            "messages": messages,
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        """
        Evaluates the model's response by comparing the extracted numerical answer
        to the ground truth.

        Args:
            conversation (dict): A conversation as returned by `get_example`; its
                last message must be the assistant message with a list of parts.
            assistant_response (str): The model's generated completion.

        Returns:
            int: 1 if the "####" answer extracted from the response equals the one
            extracted from the reference solution, else 0. A response with no
            extractable answer always scores 0.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First locate the ground-truth solution (the last assistant message).
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        # GSM8K solutions end with the "#### <answer>" line, so it is in the final part.
        last_text_part = assistant_message['content'][-1]['text']
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Require a parsed prediction: otherwise a reference that also failed to
        # parse would give None == None and award credit for a missing answer.
        is_correct = int(pred_num is not None and pred_num == ref_num)
        return is_correct

    def reward(self, conversation, assistant_response):
        """
        Provides a reward for reinforcement learning: 1.0 if `evaluate` judges the
        answer correct, otherwise 0.0.
        """
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float
|