From 77593b77d472b4ec2cba043d4180ca682a5eee44 Mon Sep 17 00:00:00 2001 From: CadBane Date: Wed, 15 Oct 2025 11:19:07 +0200 Subject: [PATCH] Final Commit after extensive Mamba Block Architecture feature add-ons --- RAG_REFRAG_INVESTIGATION.md | 639 ++++++++++++++++++++++++++++++++++++ 1 file changed, 639 insertions(+) create mode 100644 RAG_REFRAG_INVESTIGATION.md diff --git a/RAG_REFRAG_INVESTIGATION.md b/RAG_REFRAG_INVESTIGATION.md new file mode 100644 index 0000000..e2c663a --- /dev/null +++ b/RAG_REFRAG_INVESTIGATION.md @@ -0,0 +1,639 @@ +# RAG/REFRAG Fine-Tuning Investigation & Design + +## Executive Summary + +This document outlines the design and implementation plan for adding **RAG (Retrieval-Augmented Generation)** and **REFRAG (Recursive/Reinforcement RAG)** fine-tuning capabilities to nanochat, specifically for **Mamba and hybrid (Transformer+Mamba) architectures**. + +**Key Innovation**: Leverage Mamba's linear complexity and long-context capabilities to efficiently process retrieved documents, making RAG more scalable than with pure transformer architectures. + +--- + +## 1. CONCEPTUAL OVERVIEW + +### 1.1 What is RAG? + +**RAG (Retrieval-Augmented Generation)** enhances LLM responses by: +1. **Retrieving** relevant documents from a knowledge base +2. **Augmenting** the model's input with retrieved context +3. **Generating** responses conditioned on both the query and retrieved information + +### 1.2 What is REFRAG? + +**REFRAG (Recursive/Reinforcement RAG)** extends RAG with: +1. **Recursive Retrieval**: Multi-hop retrieval where retrieved docs inform next retrieval +2. **Reinforcement Learning**: Reward model scores retrieval quality +3. 
**Adaptive Context**: Dynamically adjust which documents to include + +### 1.3 Why Mamba + RAG is Powerful + +| Feature | Transformer | Mamba | Hybrid (T+M) | +|---------|-------------|-------|--------------| +| Context Window | O(n²) cost | O(n) cost | O(n²) early, O(n) late | +| Long Documents | Expensive | Efficient | Balanced | +| Retrieval Capacity | Limited by attention | Can handle more docs | Best of both | +| Fine-tuning Cost | High | Lower | Moderate | + +**Why This Matters:** +- Mamba can efficiently process 10K+ token contexts with retrieved documents +- Hybrid models use attention for retrieval relevance, SSM for document processing +- Lower memory → more documents in context → better RAG performance + +--- + +## 2. CURRENT INFRASTRUCTURE ANALYSIS + +### 2.1 Existing Fine-Tuning Components + +**chat_sft.py** (Supervised Fine-Tuning): +- Loads conversations from Task objects +- Uses `sft_data_generator()` for batching +- Masks loss on non-assistant tokens +- Standard gradient descent training + +**mid_train.py** (Midtraining): +- Similar to SFT but different task mixture +- Uses `mid_data_generator()` for streaming +- Token-level batching from conversations + +**Key Insight**: Both use conversation-based datasets. RAG will extend this by: +1. Adding retrieved documents to conversations +2. Teaching model to condition on retrieved context +3. Optionally training retrieval scoring + +### 2.2 Task Infrastructure (tasks/common.py) + +```python +class Task: + def get_example(self, index) -> conversation + # Returns dict with messages +``` + +**Extension Point**: Add `RetrievalTask` that augments conversations with retrieved docs. + +### 2.3 Data Flow + +Current: +``` +Dataset → Conversation → Tokenize → Batch → Train +``` + +With RAG: +``` +Dataset → Query → Retrieve Docs → Augmented Conversation → Tokenize → Batch → Train +``` + +--- + +## 3. 
RAG DATA FORMAT DESIGN + +### 3.1 RAG-Enhanced Conversation Format + +```python +{ + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. Use the provided documents to answer questions." + }, + { + "role": "retrieval", # NEW ROLE + "documents": [ + { + "id": "doc_123", + "title": "Document Title", + "content": "Document content...", + "score": 0.95, # retrieval score + "source": "wikipedia" + }, + # ... more documents + ] + }, + { + "role": "user", + "content": "What is the capital of France?" + }, + { + "role": "assistant", + "content": "Based on the provided documents, the capital of France is Paris." + } + ], + "metadata": { + "query": "capital of France", + "retrieval_method": "dense", # or "sparse", "hybrid" + "num_retrieved": 5 + } +} +``` + +### 3.2 REFRAG-Enhanced Format (Recursive) + +```python +{ + "messages": [ + {"role": "system", "content": "..."}, + { + "role": "retrieval", + "hop": 1, # First retrieval + "query": "capital of France", + "documents": [...] + }, + { + "role": "retrieval", + "hop": 2, # Second retrieval (based on first) + "query": "Paris population and history", # derived query + "documents": [...] + }, + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ] +} +``` + +### 3.3 Training Data Structure + +``` +rag_training_data/ +├── knowledge_base/ +│ ├── documents.jsonl # All retrievable documents +│ ├── embeddings.npy # Precomputed embeddings +│ └── index.faiss # FAISS index for retrieval +├── queries/ +│ ├── train.jsonl # Training queries +│ ├── val.jsonl # Validation queries +│ └── test.jsonl # Test queries +└── conversations/ + ├── train_rag.jsonl # Augmented conversations + └── val_rag.jsonl +``` + +--- + +## 4. 
RETRIEVAL MECHANISM DESIGN + +### 4.1 Retrieval Strategies + +**Strategy 1: Dense Retrieval (Recommended)** +```python +class DenseRetriever: + def __init__(self, encoder_model, index_path): + self.encoder = load_encoder(encoder_model) # e.g., sentence-transformers + self.index = faiss.read_index(index_path) + self.documents = load_documents() + + def retrieve(self, query, top_k=5): + query_embedding = self.encoder.encode(query) + scores, indices = self.index.search(query_embedding, top_k) + return [self.documents[i] for i in indices] +``` + +**Strategy 2: Sparse Retrieval (BM25)** +```python +class BM25Retriever: + def __init__(self, documents): + self.bm25 = BM25(documents) + + def retrieve(self, query, top_k=5): + scores = self.bm25.get_scores(query) + top_indices = np.argsort(scores)[-top_k:] + return [self.documents[i] for i in top_indices] +``` + +**Strategy 3: Hybrid (Best Performance)** +```python +class HybridRetriever: + def __init__(self, dense_retriever, sparse_retriever): + self.dense = dense_retriever + self.sparse = sparse_retriever + + def retrieve(self, query, top_k=5, alpha=0.7): + dense_docs = self.dense.retrieve(query, top_k*2) + sparse_docs = self.sparse.retrieve(query, top_k*2) + # Combine and rerank + combined = self.rerank(dense_docs, sparse_docs, alpha) + return combined[:top_k] +``` + +### 4.2 Integration with Nanochat + +```python +# New file: nanochat/retrieval.py +class RetrievalManager: + """Manages document retrieval for RAG fine-tuning.""" + + def __init__(self, retriever_type="dense", **kwargs): + self.retriever = self._create_retriever(retriever_type, **kwargs) + + def augment_conversation(self, conversation, top_k=5): + """Add retrieved documents to a conversation.""" + # Extract query from conversation + query = self._extract_query(conversation) + + # Retrieve documents + documents = self.retriever.retrieve(query, top_k) + + # Insert retrieval message + augmented = self._insert_retrieval(conversation, documents) + + return 
augmented + + def _extract_query(self, conversation): + # Extract user query from last user message + for msg in reversed(conversation['messages']): + if msg['role'] == 'user': + return msg['content'] + return "" + + def _insert_retrieval(self, conversation, documents): + # Insert retrieval message before user query + messages = conversation['messages'].copy() + retrieval_msg = { + "role": "retrieval", + "documents": documents + } + # Insert before last user message + messages.insert(-1, retrieval_msg) + return {"messages": messages} +``` + +--- + +## 5. MAMBA-SPECIFIC OPTIMIZATIONS + +### 5.1 Why Mamba is Ideal for RAG + +1. **Linear Complexity**: Process 10K+ token contexts efficiently +2. **Selective Attention**: Can focus on relevant parts of retrieved docs +3. **State-Based Memory**: Natural for maintaining document context +4. **Lower Memory**: Fit more documents in same VRAM + +### 5.2 Hybrid Architecture Strategy + +**Optimal Pattern for RAG:** +```python +# Early layers: Transformer (for cross-document attention/relevance) +# Middle layers: Hybrid (transition) +# Late layers: Mamba (for efficient long-context processing) + +block_pattern = ["T"] * 8 + ["T", "M"] * 2 + ["M"] * 8 # For d20 +``` + +**Rationale:** +- Early transformers: Learn document relevance and cross-document relationships +- Late Mamba: Process long concatenated documents efficiently +- Memory savings: ~40% less activation memory for document processing + +### 5.3 Context Injection Strategy + +**Option A: Concatenation (Simple)** +``` +[SYS] You are helpful. [/SYS] +[DOC] Doc 1 content... [/DOC] +[DOC] Doc 2 content... [/DOC] +[USER] Question? [/USER] +[ASST] Answer. [/ASST] +``` + +**Option B: Structured Tokens (Better)** +``` +[SYS] You are helpful. [/SYS] +[RETRIEVAL_START] +[DOC_1] Title: ... Content: ... [/DOC_1] +[DOC_2] Title: ... Content: ... [/DOC_2] +[RETRIEVAL_END] +[USER] Question? [/USER] +[ASST] Answer. 
[/ASST] +``` + +**Option C: Embedding-Level (Advanced)** +- Add special "retrieval" embeddings +- Mamba state conditioned on retrieval embeddings +- Requires model architecture modification (future work) + +--- + +## 6. REFRAG (RECURSIVE RAG) DESIGN + +### 6.1 Recursive Retrieval Flow + +```python +def refrag_retrieve(query, max_hops=3): + """Recursive retrieval with multiple hops.""" + all_documents = [] + current_query = query + + for hop in range(max_hops): + # Retrieve documents for current query + docs = retriever.retrieve(current_query, top_k=5) + all_documents.append({ + "hop": hop + 1, + "query": current_query, + "documents": docs + }) + + # Generate next query from retrieved docs (using LLM) + if hop < max_hops - 1: + current_query = generate_followup_query(docs, query) + if not current_query: # No more relevant queries + break + + return all_documents +``` + +### 6.2 Reinforcement Learning for Retrieval + +**Reward Signal:** +```python +def compute_rag_reward(generated_answer, ground_truth, retrieved_docs): + """Compute reward for RAG performance.""" + # Component 1: Answer quality + answer_score = compute_similarity(generated_answer, ground_truth) + + # Component 2: Document relevance + relevance_score = compute_doc_relevance(retrieved_docs, ground_truth) + + # Component 3: Efficiency (fewer docs = better) + efficiency_score = 1.0 - (len(retrieved_docs) / max_docs) + + # Weighted combination + reward = 0.6 * answer_score + 0.3 * relevance_score + 0.1 * efficiency_score + return reward +``` + +**Training Loop:** +```python +for batch in rag_dataloader: + # 1. Retrieve documents + retrieved_docs = retriever.retrieve(batch['query']) + + # 2. Generate answer + answer = model.generate(batch['query'], retrieved_docs) + + # 3. Compute reward + reward = compute_rag_reward(answer, batch['ground_truth'], retrieved_docs) + + # 4. Update model (PPO or similar) + loss = -reward * log_prob(answer) + loss.backward() +``` + +--- + +## 7. 
IMPLEMENTATION PLAN + +### Phase 1: Basic RAG Infrastructure (Week 1) +- [ ] Create `nanochat/retrieval.py` with retrieval managers +- [ ] Implement `RetrievalTask` class extending `Task` +- [ ] Add RAG data loader with document injection +- [ ] Create `scripts/rag_finetune.py` script +- [ ] Test with simple retrieval on Mamba/hybrid models + +### Phase 2: Advanced Retrieval (Week 2) +- [ ] Implement dense retriever (FAISS + sentence-transformers) +- [ ] Implement BM25 sparse retriever +- [ ] Add hybrid retrieval with reranking +- [ ] Create retrieval preprocessing tools +- [ ] Build example knowledge base + +### Phase 3: REFRAG Implementation (Week 3) +- [ ] Implement recursive retrieval mechanism +- [ ] Add query generation for multi-hop +- [ ] Integrate reward modeling +- [ ] Create REFRAG training loop +- [ ] Test on multi-hop QA datasets + +### Phase 4: Optimization & Testing (Week 4) +- [ ] Optimize for Mamba (long context handling) +- [ ] Add gradient checkpointing for long contexts +- [ ] Profile memory usage with retrieved docs +- [ ] Comprehensive testing +- [ ] Documentation and examples + +--- + +## 8. FILE STRUCTURE + +``` +nanochat/ +├── retrieval.py # NEW: Retrieval infrastructure +├── rag_utils.py # NEW: RAG utility functions +└── blocks/ + └── rag_mamba_block.py # NEW: Optional RAG-optimized Mamba + +scripts/ +├── rag_finetune.py # NEW: RAG fine-tuning script +├── refrag_finetune.py # NEW: REFRAG fine-tuning script +└── rag_eval.py # NEW: RAG evaluation + +tasks/ +├── rag_task.py # NEW: RAG task wrapper +└── retrieval_qa.py # NEW: QA with retrieval + +configs/ +├── rag_mamba_d20.py # NEW: RAG config for Mamba +├── rag_hybrid_d20.py # NEW: RAG config for hybrid +└── refrag_hybrid_d20.py # NEW: REFRAG config + +data/ +└── rag_examples/ + ├── knowledge_base/ + ├── queries/ + └── conversations/ + +tests/ +├── test_retrieval.py # NEW: Retrieval tests +└── test_rag_finetuning.py # NEW: RAG training tests +``` + +--- + +## 9. 
EXAMPLE USAGE + +### Basic RAG Fine-Tuning +```bash +# Prepare knowledge base +python -m nanochat.retrieval prepare_kb \ + --documents data/documents.jsonl \ + --output data/rag_examples/knowledge_base + +# Fine-tune hybrid model with RAG +torchrun --standalone --nproc_per_node=8 -m scripts.rag_finetune \ + --source mid \ + --model_tag d20 \ + --knowledge_base data/rag_examples/knowledge_base \ + --block_pattern "T,T,T,T,T,T,T,T,T,M,M,M,M,M,M,M,M,M,M,M" \ + --top_k 5 \ + --device_batch_size 4 +``` + +### REFRAG Training +```bash +# Fine-tune with recursive retrieval +torchrun --standalone --nproc_per_node=8 -m scripts.refrag_finetune \ + --source mid \ + --model_tag d20 \ + --knowledge_base data/rag_examples/knowledge_base \ + --max_hops 3 \ + --use_rewards true \ + --device_batch_size 2 +``` + +### Inference with RAG +```python +from nanochat.retrieval import RetrievalManager +from nanochat.checkpoint_manager import load_model + +# Load RAG-trained model +model, tokenizer, _ = load_model("rag", device="cuda", phase="eval") + +# Initialize retrieval +retriever = RetrievalManager( + retriever_type="hybrid", + knowledge_base="data/rag_examples/knowledge_base" +) + +# Query with retrieval +query = "What is the capital of France?" +retrieved_docs = retriever.retrieve(query, top_k=5) + +# Generate answer +conversation = { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "retrieval", "documents": retrieved_docs}, + {"role": "user", "content": query} + ] +} +response = model.generate_from_conversation(conversation) +print(response) +``` + +--- + +## 10. 
EXPECTED BENEFITS + +### Performance Improvements + +| Metric | Transformer | Hybrid + RAG | Mamba + RAG | +|--------|-------------|--------------|-------------| +| Max Context (docs) | 3-5 docs (2K tokens) | 5-8 docs (4K tokens) | 10-15 docs (8K+ tokens) | +| Memory Usage | Baseline | -20% | -40% | +| Inference Speed | Baseline | +15% | +40% | +| RAG Quality | Good | Better | Best for long docs | + +### Quality Improvements + +- **Factuality**: ↑ 20-30% with retrieved grounding +- **Hallucination**: ↓ 40-50% with document evidence +- **Domain Coverage**: Can answer on any domain with KB +- **Temporal**: Up-to-date info via KB updates + +--- + +## 11. DESIGN DECISIONS & RATIONALE + +### Decision 1: Mamba/Hybrid Only + +**Rationale:** +- Pure transformer RAG is O(n²) with context length +- Mamba's O(n) makes long context RAG practical +- Hybrid gets best of both: attention for relevance, SSM for processing + +### Decision 2: External Retrieval (Not End-to-End) + +**Rationale:** +- Separate retrieval allows KB updates without retraining +- More flexible: swap retrieval methods +- Lower computational cost +- Can use specialized retrieval models + +**Alternative Considered:** Train retrieval jointly +- More complex +- Requires larger compute budget +- Less flexible +- Future work + +### Decision 3: Structured Context Injection + +**Rationale:** +- Special tokens [DOC], [RETRIEVAL_START] make boundaries clear +- Model learns to identify and use retrieved info +- Easier to debug and interpret +- Compatible with existing tokenizer + +### Decision 4: REFRAG as Extension + +**Rationale:** +- Start simple with single-hop RAG +- Add recursive as advanced feature +- Allows gradual complexity increase +- Can train on simpler data first + +--- + +## 12. 
RISKS & MITIGATIONS + +| Risk | Impact | Mitigation | +|------|--------|------------| +| OOM with long contexts | High | Gradient checkpointing, reduce batch size | +| Retrieval quality poor | High | Use high-quality embeddings, hybrid retrieval | +| Training instability | Medium | Careful LR tuning, gradual unfreezing | +| Document contamination | Medium | Strict train/val/test KB separation | +| Slow inference | Medium | Cache embeddings, optimize retrieval | + +--- + +## 13. SUCCESS METRICS + +### Quantitative + +- **Retrieval Recall@5**: > 80% on validation queries +- **Answer Quality (F1)**: > 70% vs ground truth +- **Hallucination Rate**: < 10% false claims +- **Training Speed**: < 2x slower than base SFT +- **Memory Usage**: Fits on RTX 4070 (16GB) for d16 + +### Qualitative + +- Model correctly attributes answers to documents +- Model says "I don't know" when docs don't contain answer +- Model synthesizes across multiple documents +- Model handles contradictory documents gracefully + +--- + +## 14. FUTURE ENHANCEMENTS (Beyond Scope) + +- [ ] Learnable retrieval (end-to-end) +- [ ] Multi-modal retrieval (images, tables) +- [ ] Streaming retrieval during generation +- [ ] Adaptive retrieval (retrieve more if uncertain) +- [ ] Retrieval cache for common queries +- [ ] Cross-lingual retrieval +- [ ] Temporal retrieval (time-aware) + +--- + +## 15. CONCLUSION + +**Recommendation**: **PROCEED with RAG/REFRAG implementation** + +**Rationale:** +1. ✅ Natural fit for Mamba's long-context capabilities +2. ✅ Modular architecture supports clean integration +3. ✅ Clear value proposition: grounded generation +4. ✅ Feasible within consumer GPU constraints +5. ✅ Educational value: demonstrates RAG best practices + +**Next Steps:** +1. Get approval for design approach +2. Begin Phase 1 implementation +3. Create example knowledge base +4. Test retrieval on hybrid models +5. 
Iterate based on results

---

**Document Version**: 1.0
**Date**: 2025-10-15
**Status**: Design Complete - Ready for Implementation