"""
|
|
RAG Task wrapper for retrieval-augmented training.
|
|
|
|
This module wraps existing tasks with retrieval capabilities, allowing
|
|
fine-tuning on conversations augmented with retrieved documents.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
from tasks.common import Task
|
|
from nanochat.retrieval import RetrievalManager, Document
|
|
|
|
|
|
class RAGTask(Task):
    """
    Wraps an existing task with retrieval-augmented conversations.

    Example usage:
        base_task = SmolTalk(split="train")
        rag_task = RAGTask(
            base_task=base_task,
            knowledge_base_path="data/kb",
            retriever_type="dense",
            top_k=5
        )
    """

    def __init__(
        self,
        base_task: Task,
        knowledge_base_path: str,
        retriever_type: str = "simple",
        top_k: int = 5,
        insert_position: str = "before_user",
        **retriever_kwargs
    ):
        """
        Initialize RAG task.

        Args:
            base_task: Underlying task to augment
            knowledge_base_path: Path to knowledge base
            retriever_type: "simple" or "dense"
            top_k: Number of documents to retrieve
            insert_position: Where to insert retrieval ("before_user", "after_system")
            **retriever_kwargs: Additional retriever arguments
        """
        # Don't call super().__init__() yet, we'll do it after setting up retrieval
        self.base_task = base_task
        self.top_k = top_k
        self.insert_position = insert_position

        # Initialize retrieval manager
        self.retrieval_manager = RetrievalManager(
            retriever_type=retriever_type,
            knowledge_base_path=knowledge_base_path,
            **retriever_kwargs
        )

        # Now call parent init with base_task's slice parameters
        super().__init__(
            start=base_task.start,
            stop=base_task.stop,
            step=base_task.step
        )

    @property
    def eval_type(self):
        """Inherit eval type from base task."""
        return self.base_task.eval_type

    def num_examples(self):
        """Return number of examples from base task."""
        return self.base_task.num_examples()

    def get_example(self, index: int):
        """Get conversation augmented with retrieved documents."""
        # Get base conversation
        conversation = self.base_task.get_example(index)

        # Augment with retrieval
        augmented = self.retrieval_manager.augment_conversation(
            conversation,
            top_k=self.top_k,
            insert_position=self.insert_position
        )

        return augmented
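    # Note: the exact shape of the augmented conversation is determined by
    # RetrievalManager.augment_conversation. The assumption here (not enforced
    # by this module) is that it inserts a {"role": "retrieval", "documents": [...]}
    # message at `insert_position`, matching the layout documented on
    # StaticRAGTask below.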
    def evaluate(self, problem, completion):
        """Delegate evaluation to base task."""
        return self.base_task.evaluate(problem, completion)


class StaticRAGTask(Task):
    """
    RAG task with pre-retrieved documents (static dataset).

    Use this when you have a pre-built dataset of conversations with
    retrieval already included.

    Example format:
        {
            "messages": [
                {"role": "system", "content": "..."},
                {
                    "role": "retrieval",
                    "documents": [
                        {"id": "doc1", "title": "...", "content": "...", "score": 0.9},
                        ...
                    ]
                },
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    """

    def __init__(
        self,
        conversations_path: str,
        split: str = "train",
        **kwargs
    ):
        """
        Initialize static RAG task.

        Args:
            conversations_path: Path to JSONL file with RAG conversations
            split: Dataset split ("train", "val", "test")
            **kwargs: Additional Task arguments (start, stop, step)
        """
        super().__init__(**kwargs)
        self.conversations_path = conversations_path
        self.split = split
        self.conversations = self._load_conversations()

    def _load_conversations(self) -> List[Dict[str, Any]]:
        """Load conversations from JSONL file."""
        import json
        conversations = []

        # Try to load with split suffix first
        paths_to_try = [
            f"{self.conversations_path}_{self.split}.jsonl",
            f"{self.conversations_path}/{self.split}.jsonl",
            self.conversations_path  # Fallback to path as-is
        ]

        for path in paths_to_try:
            try:
                with open(path, 'r') as f:
                    for line in f:
                        conversations.append(json.loads(line))
                return conversations
            except FileNotFoundError:
                continue

        raise FileNotFoundError(
            f"Could not find conversations file. Tried: {paths_to_try}"
        )

    @property
    def eval_type(self):
        """RAG tasks are generative."""
        return "generative"

    def num_examples(self):
        """Return number of conversations."""
        return len(self.conversations)

    def get_example(self, index: int):
        """Get conversation by index."""
        return self.conversations[index]

    def evaluate(self, problem, completion):
        """Basic evaluation (can be overridden)."""
        # Simple exact match for now
        return completion.strip() == problem.get("expected_answer", "").strip()
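
# Usage sketch for StaticRAGTask (the file path below is hypothetical; the JSONL
# layout mirrors the class docstring and the path resolution in _load_conversations):
#
#   import json
#   conv = {"messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "retrieval", "documents": [
#           {"id": "doc1", "title": "t", "content": "c", "score": 0.9}]},
#       {"role": "user", "content": "question"},
#       {"role": "assistant", "content": "answer"},
#   ]}
#   with open("data/rag_convos_train.jsonl", "w") as f:
#       f.write(json.dumps(conv) + "\n")
#   task = StaticRAGTask(conversations_path="data/rag_convos", split="train")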

class MultiHopRAGTask(Task):
    """
    Multi-hop RAG task with recursive retrieval.

    This task performs multiple rounds of retrieval, where each round's
    results inform the next query.
    """

    def __init__(
        self,
        base_task: Task,
        knowledge_base_path: str,
        retriever_type: str = "dense",
        max_hops: int = 3,
        top_k_per_hop: int = 3,
        **retriever_kwargs
    ):
        """
        Initialize multi-hop RAG task.

        Args:
            base_task: Underlying task
            knowledge_base_path: Path to knowledge base
            retriever_type: Type of retriever
            max_hops: Maximum number of retrieval hops
            top_k_per_hop: Documents to retrieve per hop
            **retriever_kwargs: Additional retriever arguments
        """
        self.base_task = base_task
        self.max_hops = max_hops
        self.top_k_per_hop = top_k_per_hop

        # Initialize retrieval manager
        self.retrieval_manager = RetrievalManager(
            retriever_type=retriever_type,
            knowledge_base_path=knowledge_base_path,
            **retriever_kwargs
        )

        super().__init__(
            start=base_task.start,
            stop=base_task.stop,
            step=base_task.step
        )

    @property
    def eval_type(self):
        return self.base_task.eval_type

    def num_examples(self):
        return self.base_task.num_examples()
    def _extract_followup_query(self, documents: List[Document], original_query: str) -> Optional[str]:
        """
        Generate a follow-up query based on the retrieved documents.

        Currently a stub: it always returns None, which stops retrieval after a
        single hop. In full REFRAG, the model itself would generate the next
        query from the retrieved context.
        """
        if not documents:
            return None

        # Placeholder: grab a snippet of the top document as potential follow-up
        # context (not used yet while this remains a stub)
        top_doc = documents[0]
        content = top_doc.content[:50]

        return None
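    # A minimal heuristic one could substitute here (an illustration only, not
    # part of the current implementation): issue the next hop's query as the
    # original question plus a snippet of the top document, e.g.
    #
    #   return f"{original_query} {documents[0].content[:50]}"
    #
    # which biases the next retrieval toward entities mentioned in the first hit.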
    def get_example(self, index: int):
        """Get conversation with multi-hop retrieval."""
        # Get base conversation
        conversation = self.base_task.get_example(index)

        # Extract initial query
        query = self._extract_query(conversation)
        if not query:
            return conversation

        # Perform multi-hop retrieval
        all_documents = []
        current_query = query

        for hop in range(self.max_hops):
            # Retrieve for current query
            documents = self.retrieval_manager.retrieve(current_query, self.top_k_per_hop)

            if not documents:
                break

            all_documents.append({
                "hop": hop + 1,
                "query": current_query,
                "documents": [doc.to_dict() for doc in documents]
            })

            # Generate follow-up query
            if hop < self.max_hops - 1:
                current_query = self._extract_followup_query(documents, query)
                if not current_query:
                    break

        # Insert multi-hop retrieval into conversation
        if all_documents:
            retrieval_msg = {
                "role": "retrieval",
                "multi_hop": True,
                "hops": all_documents
            }

            messages = conversation.get("messages", []).copy()
            # Insert before the last user message
            for i in range(len(messages) - 1, -1, -1):
                if messages[i].get("role") == "user":
                    messages.insert(i, retrieval_msg)
                    break

            conversation = {"messages": messages}

        return conversation
    def _extract_query(self, conversation: Dict[str, Any]) -> str:
        """Extract query from conversation."""
        messages = conversation.get("messages", [])
        for msg in reversed(messages):
            if msg.get("role") == "user":
                return msg.get("content", "")
        return ""

    def evaluate(self, problem, completion):
        """Delegate to base task."""
        return self.base_task.evaluate(problem, completion)


# Utility function for creating RAG tasks
def create_rag_task(
    task_name: str,
    split: str,
    knowledge_base_path: str,
    retriever_type: str = "simple",
    top_k: int = 5,
    multi_hop: bool = False,
    **kwargs
) -> Task:
    """
    Factory function to create RAG-augmented tasks.

    Args:
        task_name: Name of base task ("SmolTalk", "MMLU", etc.)
        split: Dataset split
        knowledge_base_path: Path to knowledge base
        retriever_type: Type of retriever
        top_k: Documents to retrieve
        multi_hop: Whether to use multi-hop retrieval
        **kwargs: Additional task arguments

    Returns:
        RAG task instance
    """
    # Import base task
    if task_name == "SmolTalk":
        from tasks.smoltalk import SmolTalk
        base_task = SmolTalk(split=split, **kwargs)
    elif task_name == "MMLU":
        from tasks.mmlu import MMLU
        base_task = MMLU(split=split, **kwargs)
    elif task_name == "ARC-Easy":
        from tasks.arc import ARC
        base_task = ARC(subset="ARC-Easy", split=split, **kwargs)
    elif task_name == "ARC-Challenge":
        from tasks.arc import ARC
        base_task = ARC(subset="ARC-Challenge", split=split, **kwargs)
    elif task_name == "GSM8K":
        from tasks.gsm8k import GSM8K
        base_task = GSM8K(split=split, **kwargs)
    else:
        raise ValueError(f"Unknown task: {task_name}")

    # Wrap with RAG
    if multi_hop:
        return MultiHopRAGTask(
            base_task=base_task,
            knowledge_base_path=knowledge_base_path,
            retriever_type=retriever_type,
            top_k_per_hop=top_k,
        )
    else:
        return RAGTask(
            base_task=base_task,
            knowledge_base_path=knowledge_base_path,
            retriever_type=retriever_type,
            top_k=top_k,
        )
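
# Example of the factory in action (a sketch; "data/kb" is a hypothetical
# knowledge base path that must already exist on disk):
#
#   task = create_rag_task(
#       task_name="SmolTalk",
#       split="train",
#       knowledge_base_path="data/kb",
#       retriever_type="simple",
#       top_k=5,
#   )
#   conversation = task[0]  # indexed access, as in the smoke test below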

if __name__ == "__main__":
    # Test RAG task
    import sys
    sys.path.append(".")

    from tasks.smoltalk import SmolTalk

    print("Testing RAG task wrapper...")

    # Create base task
    base_task = SmolTalk(split="train", stop=5)
    print(f"Base task has {len(base_task)} examples")

    # Note: This will fail without a knowledge base, but shows the API
    try:
        rag_task = RAGTask(
            base_task=base_task,
            knowledge_base_path="data/test_kb",
            retriever_type="simple",
            top_k=3
        )
        print(f"RAG task has {len(rag_task)} examples")

        # Get an example
        example = rag_task[0]
        print("\nExample conversation:")
        for msg in example.get("messages", []):
            print(f"  {msg.get('role')}: {str(msg)[:100]}...")

    except FileNotFoundError as e:
        print(f"Knowledge base not found (expected): {e}")
        print("This is normal for testing without a KB.")