nanochat/tasks/gsm8k.py
"""
This module implements the GSM8K (Grade School Math 8K) task. This dataset consists
of grade school math word problems that require multi-step reasoning.
A unique feature of this dataset is its use of "tool calls" in the answers,
denoted by `<<expression=result>>`. This module parses these tool calls into a
structured conversational format for fine-tuning the model's tool-use capabilities.
**Reference:**
- The GSM8K dataset: https://huggingface.co/datasets/openai/gsm8k
"""
import re
from datasets import load_dataset
from .common import Task

GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")


def extract_answer(completion):
    """Extracts the numerical answer from a GSM8K completion string."""
    match = GSM_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    return None
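
# A quick illustration of the expected behavior (added examples, not from the original file):
#   extract_answer("Natalia sold 72 clips in total.\n#### 72")  -> "72"
#   extract_answer("The total cost is $1,224.\n#### 1,224")     -> "1224"  (commas stripped)
#   extract_answer("a response with no final answer marker")    -> None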


class GSM8K(Task):
    """
    The GSM8K (Grade School Math 8K) task.

    Args:
        subset (str): The subset of the dataset, either "main" or "socratic".
        split (str): The data split, either "train" or "test".
    """

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
        assert split in ["train", "test"], "GSM8K split must be train|test"
        self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)

    @property
    def eval_type(self):
        """Specifies that this is a generative evaluation task."""
        return 'generative'

    def num_examples(self):
        """Returns the total number of examples in the dataset."""
        return len(self.ds)

    def get_example(self, index):
        """
        Formats a single example, parsing tool calls into a structured conversation.
        """
        row = self.ds[index]
        question = row['question'] # string of the question prompt
        answer = row['answer'] # string of the full solution and the answer after the #### marker
        # Create and return the Conversation object.
        # This is tricky because GSM8K uses tool calls, which we need to parse here.
        assistant_message_parts = []
        parts = re.split(r'(<<[^>]+>>)', answer)
        for part in parts:
            if part.startswith('<<') and part.endswith('>>'):
                # This is a calculator tool call
                inner = part[2:-2] # Remove << >>
                # Split on = to get expression and result
                if '=' in inner:
                    expr, result = inner.rsplit('=', 1)
                else:
                    expr, result = inner, ""
                # Add the tool call as a part
                assistant_message_parts.append({"type": "python", "text": expr})
                # Add the result as a part
                assistant_message_parts.append({"type": "python_output", "text": result})
            else:
                # Regular text in between tool calls
                assistant_message_parts.append({"type": "text", "text": part})
        # Now put it all together
        messages = [
            {"role": "user", "content": question}, # note: simple string
            {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
        ]
        conversation = {
            "messages": messages,
        }
        return conversation
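
    # Illustrative walk-through of the parsing above (made-up answer text, not a dataset row):
    # for an answer "She sold <<48/2=24>>24 clips.\n#### 24", re.split keeps the <<...>>
    # delimiters, so the assistant content parts come out roughly as:
    #   [{"type": "text", "text": "She sold "},
    #    {"type": "python", "text": "48/2"},
    #    {"type": "python_output", "text": "24"},
    #    {"type": "text", "text": "24 clips.\n#### 24"}]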

    def evaluate(self, conversation, assistant_response):
        """
        Evaluates the model's response by comparing the extracted numerical answer
        to the ground truth.
        """
        assert isinstance(assistant_response, str), "Assuming simple string response for now"
        # First extract the ground truth answer
        assistant_message = conversation['messages'][-1]
        assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
        assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
        last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
        # Extract both the ground truth answer and the predicted answer
        ref_num = extract_answer(last_text_part)
        pred_num = extract_answer(assistant_response)
        # Compare and return the success as int
        is_correct = int(pred_num == ref_num)
        return is_correct
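
    # Note: if the model response never emits a "#### <number>" marker, extract_answer
    # returns None for the prediction, the equality check fails, and the example scores 0.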

    def reward(self, conversation, assistant_response):
        """
        Provides a reward for reinforcement learning, which is simply whether the
        answer was correct or not.
        """
        is_correct = self.evaluate(conversation, assistant_response)
        is_correct_float = float(is_correct)
        return is_correct_float
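

# Minimal usage sketch (assumes the surrounding nanochat harness; subset and split are
# the only arguments this file itself requires; added for illustration):
#   task = GSM8K(subset="main", split="test")
#   conversation = task.get_example(0)
#   score = task.evaluate(conversation, "<model response ending in '#### <number>'>")  # 1 or 0
#   reward = task.reward(conversation, "<model response>")                             # 1.0 or 0.0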