mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 13:45:21 +00:00
# RAG/REFRAG Fine-Tuning Investigation & Design

## Executive Summary

This document outlines the design and implementation plan for adding **RAG (Retrieval-Augmented Generation)** and **REFRAG (Recursive/Reinforcement RAG)** fine-tuning capabilities to nanochat, specifically for **Mamba and hybrid (Transformer+Mamba) architectures**.

**Key Innovation**: Leverage Mamba's linear complexity and long-context capabilities to process retrieved documents efficiently, making RAG more scalable than with pure transformer architectures.

---
## 1. CONCEPTUAL OVERVIEW

### 1.1 What is RAG?

**RAG (Retrieval-Augmented Generation)** enhances LLM responses by:

1. **Retrieving** relevant documents from a knowledge base
2. **Augmenting** the model's input with the retrieved context
3. **Generating** responses conditioned on both the query and the retrieved information
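
The three steps above can be made concrete with a toy, dependency-free sketch; the keyword-overlap retriever and the bracketed prompt format are illustrative stand-ins for the real retriever and chat format described later in this document.

```python
import re

# Toy knowledge base; the real one is documents.jsonl plus a FAISS index
KNOWLEDGE_BASE = [
    {"id": "doc_1", "content": "Paris is the capital of France."},
    {"id": "doc_2", "content": "Berlin is the capital of Germany."},
]

def _tokens(text):
    return set(re.findall(r"\w+", text.lower()))

def retrieve(query, top_k=1):
    # Step 1: score documents by naive word overlap with the query
    q = _tokens(query)
    ranked = sorted(KNOWLEDGE_BASE,
                    key=lambda d: len(q & _tokens(d["content"])),
                    reverse=True)
    return ranked[:top_k]

def augment(query, docs):
    # Step 2: prepend the retrieved documents to the model input
    context = "\n".join(f"[DOC] {d['content']} [/DOC]" for d in docs)
    return f"{context}\n[USER] {query} [/USER]"

# Step 3 (generation) would feed this prompt to the model
query = "What is the capital of France?"
prompt = augment(query, retrieve(query))
```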

### 1.2 What is REFRAG?

**REFRAG (Recursive/Reinforcement RAG)** extends RAG with:

1. **Recursive Retrieval**: Multi-hop retrieval where retrieved docs inform the next retrieval
2. **Reinforcement Learning**: A reward model scores retrieval quality
3. **Adaptive Context**: Dynamically adjust which documents to include

### 1.3 Why Mamba + RAG is Powerful

| Feature | Transformer | Mamba | Hybrid (T+M) |
|---------|-------------|-------|--------------|
| Context Window | O(n²) cost | O(n) cost | O(n²) early, O(n) late |
| Long Documents | Expensive | Efficient | Balanced |
| Retrieval Capacity | Limited by attention | Can handle more docs | Best of both |
| Fine-tuning Cost | High | Lower | Moderate |

**Why This Matters:**

- Mamba can efficiently process 10K+ token contexts containing retrieved documents
- Hybrid models use attention for retrieval relevance and SSM layers for document processing
- Lower memory → more documents in context → better RAG performance

---

## 2. CURRENT INFRASTRUCTURE ANALYSIS

### 2.1 Existing Fine-Tuning Components

**chat_sft.py** (Supervised Fine-Tuning):

- Loads conversations from Task objects
- Uses `sft_data_generator()` for batching
- Masks the loss on non-assistant tokens
- Standard gradient-descent training

**mid_train.py** (Midtraining):

- Similar to SFT but with a different task mixture
- Uses `mid_data_generator()` for streaming
- Token-level batching from conversations

**Key Insight**: Both use conversation-based datasets. RAG will extend this by:

1. Adding retrieved documents to conversations
2. Teaching the model to condition on retrieved context
3. Optionally training retrieval scoring

### 2.2 Task Infrastructure (tasks/common.py)

```python
class Task:
    def get_example(self, index) -> dict:
        """Returns a conversation: a dict with a 'messages' list."""
```

**Extension Point**: Add a `RetrievalTask` that augments conversations with retrieved docs.
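
As a sketch of that extension point, a hypothetical `RetrievalTask` might wrap a base task and inject a retrieval message. The minimal `Task` stand-in below only mimics the `get_example` interface so the example is self-contained; it is not the real class from tasks/common.py.

```python
class Task:
    """Minimal stand-in for the real Task in tasks/common.py."""
    def __init__(self, examples):
        self.examples = examples

    def get_example(self, index):
        return self.examples[index]  # {"messages": [...]}

class RetrievalTask(Task):
    """Wraps a base task and injects retrieved documents into each conversation."""

    def __init__(self, base_task, retriever, top_k=5):
        self.base_task = base_task
        self.retriever = retriever  # callable: (query, top_k) -> list of docs
        self.top_k = top_k

    def get_example(self, index):
        conversation = self.base_task.get_example(index)
        messages = list(conversation["messages"])
        # Use the last user message as the retrieval query
        query = next(
            (m["content"] for m in reversed(messages) if m["role"] == "user"), ""
        )
        docs = self.retriever(query, self.top_k)
        # Insert a retrieval message immediately before the last user turn
        for i in range(len(messages) - 1, -1, -1):
            if messages[i]["role"] == "user":
                messages.insert(i, {"role": "retrieval", "documents": docs})
                break
        return {"messages": messages}

# Usage with a dummy retriever:
base = Task([{"messages": [{"role": "user", "content": "capital of France?"}]}])
task = RetrievalTask(base, retriever=lambda q, k: [{"id": "doc_123", "content": "Paris..."}])
example = task.get_example(0)
```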

### 2.3 Data Flow

Current:

```
Dataset → Conversation → Tokenize → Batch → Train
```

With RAG:

```
Dataset → Query → Retrieve Docs → Augmented Conversation → Tokenize → Batch → Train
```

---

## 3. RAG DATA FORMAT DESIGN

### 3.1 RAG-Enhanced Conversation Format

```python
{
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant. Use the provided documents to answer questions."
        },
        {
            "role": "retrieval",  # NEW ROLE
            "documents": [
                {
                    "id": "doc_123",
                    "title": "Document Title",
                    "content": "Document content...",
                    "score": 0.95,  # retrieval score
                    "source": "wikipedia"
                },
                # ... more documents
            ]
        },
        {
            "role": "user",
            "content": "What is the capital of France?"
        },
        {
            "role": "assistant",
            "content": "Based on the provided documents, the capital of France is Paris."
        }
    ],
    "metadata": {
        "query": "capital of France",
        "retrieval_method": "dense",  # or "sparse", "hybrid"
        "num_retrieved": 5
    }
}
```

### 3.2 REFRAG-Enhanced Format (Recursive)

```python
{
    "messages": [
        {"role": "system", "content": "..."},
        {
            "role": "retrieval",
            "hop": 1,  # first retrieval
            "query": "capital of France",
            "documents": [...]
        },
        {
            "role": "retrieval",
            "hop": 2,  # second retrieval (based on the first)
            "query": "Paris population and history",  # derived query
            "documents": [...]
        },
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

### 3.3 Training Data Structure

```
rag_training_data/
├── knowledge_base/
│   ├── documents.jsonl     # All retrievable documents
│   ├── embeddings.npy      # Precomputed embeddings
│   └── index.faiss         # FAISS index for retrieval
├── queries/
│   ├── train.jsonl         # Training queries
│   ├── val.jsonl           # Validation queries
│   └── test.jsonl          # Test queries
└── conversations/
    ├── train_rag.jsonl     # Augmented conversations
    └── val_rag.jsonl
```
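
The knowledge-base half of this tree can be produced by a small preparation step. The sketch below is self-contained: the hash-based `embed` stub stands in for a real sentence-transformers encoder, and brute-force inner-product search stands in for the FAISS index (`index.faiss`) that a real pipeline would write alongside.

```python
import json
import os
import tempfile

import numpy as np

def embed(texts, dim=64):
    # Stub encoder: hash words into a bag-of-words vector and L2-normalize.
    # A real pipeline would call a sentence-transformers model here.
    vecs = np.zeros((len(texts), dim), dtype=np.float32)
    for row, text in enumerate(texts):
        for word in text.lower().split():
            vecs[row, hash(word) % dim] += 1.0
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs / np.maximum(norms, 1e-8)

def prepare_kb(documents, out_dir):
    """Write documents.jsonl and embeddings.npy as in the tree above."""
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "documents.jsonl"), "w") as f:
        for doc in documents:
            f.write(json.dumps(doc) + "\n")
    embeddings = embed([d["content"] for d in documents])
    np.save(os.path.join(out_dir, "embeddings.npy"), embeddings)
    return embeddings

def search(query, embeddings, top_k=1):
    # Brute-force inner-product search (FAISS IndexFlatIP does this at scale)
    scores = embeddings @ embed([query])[0]
    return np.argsort(scores)[::-1][:top_k]

docs = [
    {"id": "doc_1", "content": "Paris is the capital of France"},
    {"id": "doc_2", "content": "The Eiffel Tower is in Paris"},
]
kb_dir = tempfile.mkdtemp()
emb = prepare_kb(docs, kb_dir)
best = search("capital of France", emb)[0]
```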

---

## 4. RETRIEVAL MECHANISM DESIGN

### 4.1 Retrieval Strategies

**Strategy 1: Dense Retrieval (Recommended)**

```python
class DenseRetriever:
    def __init__(self, encoder_model, index_path):
        self.encoder = load_encoder(encoder_model)  # e.g., sentence-transformers
        self.index = faiss.read_index(index_path)
        self.documents = load_documents()

    def retrieve(self, query, top_k=5):
        # FAISS expects a 2D (batch, dim) array and returns 2D score/index arrays
        query_embedding = self.encoder.encode([query])
        scores, indices = self.index.search(query_embedding, top_k)
        return [self.documents[i] for i in indices[0]]
```

**Strategy 2: Sparse Retrieval (BM25)**

```python
class BM25Retriever:
    def __init__(self, documents):
        self.documents = documents
        self.bm25 = BM25(documents)

    def retrieve(self, query, top_k=5):
        scores = self.bm25.get_scores(query)
        top_indices = np.argsort(scores)[-top_k:][::-1]  # best first
        return [self.documents[i] for i in top_indices]
```

**Strategy 3: Hybrid (Best Performance)**

```python
class HybridRetriever:
    def __init__(self, dense_retriever, sparse_retriever):
        self.dense = dense_retriever
        self.sparse = sparse_retriever

    def retrieve(self, query, top_k=5, alpha=0.7):
        dense_docs = self.dense.retrieve(query, top_k * 2)
        sparse_docs = self.sparse.retrieve(query, top_k * 2)
        # Combine and rerank
        combined = self.rerank(dense_docs, sparse_docs, alpha)
        return combined[:top_k]
```
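
The `rerank` step above is not spelled out; one common, simple choice is reciprocal rank fusion (RRF), sketched here under the assumption that documents carry a stable `id` field. Weighting the two lists by `alpha` mirrors this document's dense/sparse trade-off and is a variation on plain RRF.

```python
def rrf_rerank(dense_docs, sparse_docs, alpha=0.7, k=60):
    """Fuse two ranked lists with weighted reciprocal rank fusion."""
    scores = {}
    docs_by_id = {}
    for weight, ranked in ((alpha, dense_docs), (1.0 - alpha, sparse_docs)):
        for rank, doc in enumerate(ranked):
            docs_by_id[doc["id"]] = doc
            # RRF: each list contributes weight / (k + rank + 1)
            scores[doc["id"]] = scores.get(doc["id"], 0.0) + weight / (k + rank + 1)
    ordered = sorted(scores, key=scores.get, reverse=True)
    return [docs_by_id[i] for i in ordered]

# "b" appears in both lists, so it wins despite never ranking first
dense = [{"id": "a"}, {"id": "b"}]
sparse = [{"id": "b"}, {"id": "c"}]
fused = rrf_rerank(dense, sparse)
```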

### 4.2 Integration with Nanochat

```python
# New file: nanochat/retrieval.py
class RetrievalManager:
    """Manages document retrieval for RAG fine-tuning."""

    def __init__(self, retriever_type="dense", **kwargs):
        self.retriever = self._create_retriever(retriever_type, **kwargs)

    def augment_conversation(self, conversation, top_k=5):
        """Add retrieved documents to a conversation."""
        # Extract the query from the conversation
        query = self._extract_query(conversation)

        # Retrieve documents
        documents = self.retriever.retrieve(query, top_k)

        # Insert the retrieval message
        augmented = self._insert_retrieval(conversation, documents)

        return augmented

    def _extract_query(self, conversation):
        # Use the last user message as the query
        for msg in reversed(conversation['messages']):
            if msg['role'] == 'user':
                return msg['content']
        return ""

    def _insert_retrieval(self, conversation, documents):
        # Insert the retrieval message immediately before the last user message
        messages = conversation['messages'].copy()
        retrieval_msg = {
            "role": "retrieval",
            "documents": documents
        }
        for i in range(len(messages) - 1, -1, -1):
            if messages[i]['role'] == 'user':
                messages.insert(i, retrieval_msg)
                break
        return {"messages": messages}
```

---

## 5. MAMBA-SPECIFIC OPTIMIZATIONS

### 5.1 Why Mamba is Ideal for RAG

1. **Linear Complexity**: Process 10K+ token contexts efficiently
2. **Selective State Updates**: The selection mechanism can focus on relevant parts of retrieved docs
3. **State-Based Memory**: Natural for maintaining document context
4. **Lower Memory**: Fit more documents in the same VRAM

### 5.2 Hybrid Architecture Strategy

**Optimal Pattern for RAG:**

```python
# Early layers: Transformer (for cross-document attention/relevance)
# Middle layers: Hybrid (transition)
# Late layers: Mamba (for efficient long-context processing)

block_pattern = ["T"] * 8 + ["T", "M"] * 2 + ["M"] * 8  # 20 blocks, for d20
```

**Rationale:**

- Early transformers: Learn document relevance and cross-document relationships
- Late Mamba: Process long concatenated documents efficiently
- Memory savings: ~40% less activation memory for document processing
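
A helper along these lines (hypothetical, not existing nanochat code) could generate the same T-early/M-late pattern for any depth, which keeps configs for other model sizes consistent with the hand-written d20 case:

```python
def make_rag_block_pattern(depth, transition=2):
    """Transformer blocks early, Mamba blocks late, with a short T/M transition."""
    n_early = (depth - 2 * transition) // 2
    n_late = depth - n_early - 2 * transition
    pattern = ["T"] * n_early + ["T", "M"] * transition + ["M"] * n_late
    assert len(pattern) == depth
    return pattern

pattern = make_rag_block_pattern(20)
```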

### 5.3 Context Injection Strategy

**Option A: Concatenation (Simple)**

```
[SYS] You are helpful. [/SYS]
[DOC] Doc 1 content... [/DOC]
[DOC] Doc 2 content... [/DOC]
[USER] Question? [/USER]
[ASST] Answer. [/ASST]
```

**Option B: Structured Tokens (Better)**

```
[SYS] You are helpful. [/SYS]
[RETRIEVAL_START]
[DOC_1] Title: ... Content: ... [/DOC_1]
[DOC_2] Title: ... Content: ... [/DOC_2]
[RETRIEVAL_END]
[USER] Question? [/USER]
[ASST] Answer. [/ASST]
```

**Option C: Embedding-Level (Advanced)**

- Add special "retrieval" embeddings
- Condition the Mamba state on retrieval embeddings
- Requires model architecture modification (future work)
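
A small helper, sketched here, shows how a retrieval message could be rendered into the Option B layout; in a real tokenizer each bracketed tag would map to a dedicated special token rather than literal text.

```python
def render_retrieval(documents):
    """Render a list of documents in the structured-token layout of Option B."""
    lines = ["[RETRIEVAL_START]"]
    for i, doc in enumerate(documents, start=1):
        lines.append(
            f"[DOC_{i}] Title: {doc['title']} Content: {doc['content']} [/DOC_{i}]"
        )
    lines.append("[RETRIEVAL_END]")
    return "\n".join(lines)

text = render_retrieval([
    {"title": "France", "content": "Paris is the capital."},
])
```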

---

## 6. REFRAG (RECURSIVE RAG) DESIGN

### 6.1 Recursive Retrieval Flow

```python
def refrag_retrieve(query, max_hops=3):
    """Recursive retrieval with multiple hops."""
    all_hops = []
    current_query = query

    for hop in range(max_hops):
        # Retrieve documents for the current query
        docs = retriever.retrieve(current_query, top_k=5)
        all_hops.append({
            "hop": hop + 1,
            "query": current_query,
            "documents": docs
        })

        # Generate the next query from the retrieved docs (using the LLM)
        if hop < max_hops - 1:
            current_query = generate_followup_query(docs, query)
            if not current_query:  # no more relevant queries
                break

    return all_hops
```
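
`generate_followup_query` is produced by the LLM itself in this design. Purely as a placeholder for exercising the loop before the model-driven version exists, a crude heuristic stand-in could extract capitalized terms from the retrieved documents that the original query never mentioned:

```python
import re

# Crude stopword list; a real implementation would use the LLM instead
_STOPWORDS = {"the", "a", "an", "is", "in", "of"}

def generate_followup_query(docs, original_query, max_terms=3):
    """Return a follow-up query, or "" when nothing new was retrieved."""
    seen = set(re.findall(r"\w+", original_query.lower())) | _STOPWORDS
    candidates = []
    for doc in docs:
        # Capitalized words serve as a crude proxy for new entities
        for term in re.findall(r"\b[A-Z][a-z]+\b", doc["content"]):
            if term.lower() not in seen and term not in candidates:
                candidates.append(term)
    return " ".join(candidates[:max_terms])

docs = [{"content": "The capital of France is Paris."}]
followup = generate_followup_query(docs, "capital of France")
```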

### 6.2 Reinforcement Learning for Retrieval

**Reward Signal:**

```python
def compute_rag_reward(generated_answer, ground_truth, retrieved_docs, max_docs=10):
    """Compute a scalar reward for RAG performance."""
    # Component 1: Answer quality
    answer_score = compute_similarity(generated_answer, ground_truth)

    # Component 2: Document relevance
    relevance_score = compute_doc_relevance(retrieved_docs, ground_truth)

    # Component 3: Efficiency (fewer docs = better)
    efficiency_score = 1.0 - (len(retrieved_docs) / max_docs)

    # Weighted combination
    reward = 0.6 * answer_score + 0.3 * relevance_score + 0.1 * efficiency_score
    return reward
```

**Training Loop (schematic):**

```python
for batch in rag_dataloader:
    # 1. Retrieve documents
    retrieved_docs = retriever.retrieve(batch['query'])

    # 2. Generate an answer
    answer = model.generate(batch['query'], retrieved_docs)

    # 3. Compute the reward
    reward = compute_rag_reward(answer, batch['ground_truth'], retrieved_docs)

    # 4. Update the model (REINFORCE-style; PPO or similar in practice)
    loss = -reward * log_prob(answer)  # log_prob: sum of token log-probs of the sampled answer
    loss.backward()
```

---

## 7. IMPLEMENTATION PLAN

### Phase 1: Basic RAG Infrastructure (Week 1)

- [ ] Create `nanochat/retrieval.py` with retrieval managers
- [ ] Implement `RetrievalTask` class extending `Task`
- [ ] Add RAG data loader with document injection
- [ ] Create `scripts/rag_finetune.py` script
- [ ] Test with simple retrieval on Mamba/hybrid models

### Phase 2: Advanced Retrieval (Week 2)

- [ ] Implement dense retriever (FAISS + sentence-transformers)
- [ ] Implement BM25 sparse retriever
- [ ] Add hybrid retrieval with reranking
- [ ] Create retrieval preprocessing tools
- [ ] Build example knowledge base

### Phase 3: REFRAG Implementation (Week 3)

- [ ] Implement recursive retrieval mechanism
- [ ] Add query generation for multi-hop retrieval
- [ ] Integrate reward modeling
- [ ] Create REFRAG training loop
- [ ] Test on multi-hop QA datasets

### Phase 4: Optimization & Testing (Week 4)

- [ ] Optimize for Mamba (long-context handling)
- [ ] Add gradient checkpointing for long contexts
- [ ] Profile memory usage with retrieved docs
- [ ] Comprehensive testing
- [ ] Documentation and examples

---

## 8. FILE STRUCTURE

```
nanochat/
├── retrieval.py              # NEW: Retrieval infrastructure
├── rag_utils.py              # NEW: RAG utility functions
└── blocks/
    └── rag_mamba_block.py    # NEW: Optional RAG-optimized Mamba block

scripts/
├── rag_finetune.py           # NEW: RAG fine-tuning script
├── refrag_finetune.py        # NEW: REFRAG fine-tuning script
└── rag_eval.py               # NEW: RAG evaluation

tasks/
├── rag_task.py               # NEW: RAG task wrapper
└── retrieval_qa.py           # NEW: QA with retrieval

configs/
├── rag_mamba_d20.py          # NEW: RAG config for Mamba
├── rag_hybrid_d20.py         # NEW: RAG config for hybrid
└── refrag_hybrid_d20.py      # NEW: REFRAG config

data/
└── rag_examples/
    ├── knowledge_base/
    ├── queries/
    └── conversations/

tests/
├── test_retrieval.py         # NEW: Retrieval tests
└── test_rag_finetuning.py    # NEW: RAG training tests
```

---

## 9. EXAMPLE USAGE

### Basic RAG Fine-Tuning

```bash
# Prepare the knowledge base
python -m nanochat.retrieval prepare_kb \
    --documents data/documents.jsonl \
    --output data/rag_examples/knowledge_base

# Fine-tune a hybrid model with RAG
torchrun --standalone --nproc_per_node=8 -m scripts.rag_finetune \
    --source mid \
    --model_tag d20 \
    --knowledge_base data/rag_examples/knowledge_base \
    --block_pattern "T,T,T,T,T,T,T,T,T,M,M,M,M,M,M,M,M,M,M,M" \
    --top_k 5 \
    --device_batch_size 4
```

### REFRAG Training

```bash
# Fine-tune with recursive retrieval
torchrun --standalone --nproc_per_node=8 -m scripts.refrag_finetune \
    --source mid \
    --model_tag d20 \
    --knowledge_base data/rag_examples/knowledge_base \
    --max_hops 3 \
    --use_rewards true \
    --device_batch_size 2
```

### Inference with RAG

```python
from nanochat.retrieval import RetrievalManager
from nanochat.checkpoint_manager import load_model

# Load the RAG-trained model
model, tokenizer, _ = load_model("rag", device="cuda", phase="eval")

# Initialize retrieval
retriever = RetrievalManager(
    retriever_type="hybrid",
    knowledge_base="data/rag_examples/knowledge_base"
)

# Query with retrieval
query = "What is the capital of France?"
retrieved_docs = retriever.retrieve(query, top_k=5)

# Generate an answer
conversation = {
    "messages": [
        {"role": "system", "content": "You are helpful."},
        {"role": "retrieval", "documents": retrieved_docs},
        {"role": "user", "content": query}
    ]
}
response = model.generate_from_conversation(conversation)
print(response)
```

---

## 10. EXPECTED BENEFITS

### Performance Improvements

| Metric | Transformer | Hybrid + RAG | Mamba + RAG |
|--------|-------------|--------------|-------------|
| Max Context (docs) | 3-5 docs (2K tokens) | 5-8 docs (4K tokens) | 10-15 docs (8K+ tokens) |
| Memory Usage | Baseline | -20% | -40% |
| Inference Speed | Baseline | +15% | +40% |
| RAG Quality | Good | Better | Best for long docs |

### Quality Improvements

- **Factuality**: ↑ 20-30% with retrieved grounding
- **Hallucination**: ↓ 40-50% with document evidence
- **Domain Coverage**: Can answer on any domain covered by the KB
- **Temporal**: Up-to-date info via KB updates

---

## 11. DESIGN DECISIONS & RATIONALE

### Decision 1: Mamba/Hybrid Only

**Rationale:**

- Pure transformer RAG is O(n²) in context length
- Mamba's O(n) scaling makes long-context RAG practical
- Hybrid gets the best of both: attention for relevance, SSM for processing

### Decision 2: External Retrieval (Not End-to-End)

**Rationale:**

- Separate retrieval allows KB updates without retraining
- More flexible: retrieval methods can be swapped
- Lower computational cost
- Can use specialized retrieval models

**Alternative Considered:** Training retrieval jointly

- More complex
- Requires a larger compute budget
- Less flexible
- Deferred to future work

### Decision 3: Structured Context Injection

**Rationale:**

- Special tokens like [DOC] and [RETRIEVAL_START] make boundaries clear
- The model learns to identify and use retrieved info
- Easier to debug and interpret
- Compatible with the existing tokenizer

### Decision 4: REFRAG as an Extension

**Rationale:**

- Start simple with single-hop RAG
- Add recursion as an advanced feature
- Allows a gradual increase in complexity
- Can train on simpler data first

---

## 12. RISKS & MITIGATIONS

| Risk | Impact | Mitigation |
|------|--------|------------|
| OOM with long contexts | High | Gradient checkpointing, reduce batch size |
| Poor retrieval quality | High | Use high-quality embeddings, hybrid retrieval |
| Training instability | Medium | Careful LR tuning, gradual unfreezing |
| Document contamination | Medium | Strict train/val/test KB separation |
| Slow inference | Medium | Cache embeddings, optimize retrieval |

---

## 13. SUCCESS METRICS

### Quantitative

- **Retrieval Recall@5**: > 80% on validation queries
- **Answer Quality (F1)**: > 70% vs. ground truth
- **Hallucination Rate**: < 10% false claims
- **Training Speed**: < 2x slower than base SFT
- **Memory Usage**: Fits in 16 GB of VRAM for d16
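
The first two metrics can be pinned down precisely. The sketch below assumes gold document ids for recall and SQuAD-style whitespace-token F1 for answer quality; both are illustrative choices, not fixed by this design.

```python
def recall_at_k(retrieved_ids, gold_ids, k=5):
    """Fraction of gold documents found in the top-k retrieved ids."""
    top = set(retrieved_ids[:k])
    return sum(1 for g in gold_ids if g in top) / len(gold_ids)

def token_f1(prediction, reference):
    """Token-level F1 between a predicted and a reference answer."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    # Count overlapping tokens, respecting multiplicity
    ref_counts = {}
    for t in ref:
        ref_counts[t] = ref_counts.get(t, 0) + 1
    common = 0
    for t in pred:
        if ref_counts.get(t, 0) > 0:
            common += 1
            ref_counts[t] -= 1
    if common == 0:
        return 0.0
    precision = common / len(pred)
    recall = common / len(ref)
    return 2 * precision * recall / (precision + recall)

r = recall_at_k(["d1", "d3", "d9"], ["d1", "d2"], k=5)
f = token_f1("the capital is paris", "paris is the capital")
```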

### Qualitative

- Model correctly attributes answers to documents
- Model says "I don't know" when the docs don't contain the answer
- Model synthesizes across multiple documents
- Model handles contradictory documents gracefully

---

## 14. FUTURE ENHANCEMENTS (Beyond Scope)

- [ ] Learnable retrieval (end-to-end)
- [ ] Multi-modal retrieval (images, tables)
- [ ] Streaming retrieval during generation
- [ ] Adaptive retrieval (retrieve more if uncertain)
- [ ] Retrieval cache for common queries
- [ ] Cross-lingual retrieval
- [ ] Temporal (time-aware) retrieval

---

## 15. CONCLUSION

**Recommendation**: **PROCEED with the RAG/REFRAG implementation**

**Rationale:**

1. ✅ Natural fit for Mamba's long-context capabilities
2. ✅ Modular architecture supports clean integration
3. ✅ Clear value proposition: grounded generation
4. ✅ Feasible within consumer-GPU constraints
5. ✅ Educational value: demonstrates RAG best practices

**Next Steps:**

1. Get approval for the design approach
2. Begin Phase 1 implementation
3. Create an example knowledge base
4. Test retrieval on hybrid models
5. Iterate based on results

---

**Document Version**: 1.0
**Date**: 2025-01-15
**Status**: Design Complete - Ready for Implementation