
RAG/REFRAG Fine-Tuning Investigation & Design

Executive Summary

This document outlines the design and implementation plan for adding RAG (Retrieval-Augmented Generation) and REFRAG (Recursive/Reinforcement RAG) fine-tuning capabilities to nanochat, specifically for Mamba and hybrid (Transformer+Mamba) architectures.

Key Innovation: Leverage Mamba's linear complexity and long-context capabilities to efficiently process retrieved documents, making RAG more scalable than with pure transformer architectures.


1. CONCEPTUAL OVERVIEW

1.1 What is RAG?

RAG (Retrieval-Augmented Generation) enhances LLM responses by:

  1. Retrieving relevant documents from a knowledge base
  2. Augmenting the model's input with retrieved context
  3. Generating responses conditioned on both the query and retrieved information
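
In code terms, the loop is as follows (a conceptual sketch; `retriever` and `model` stand in for the components designed in Sections 4 and 9):

def rag_answer(query, retriever, model, top_k=5):
    docs = retriever.retrieve(query, top_k)                 # 1. retrieve
    context = "\n\n".join(doc["content"] for doc in docs)
    prompt = f"{context}\n\nQuestion: {query}"              # 2. augment
    return model.generate(prompt)                           # 3. generate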

1.2 What is REFRAG?

REFRAG (Recursive/Reinforcement RAG) extends RAG with:

  1. Recursive Retrieval: Multi-hop retrieval where retrieved docs inform next retrieval
  2. Reinforcement Learning: Reward model scores retrieval quality
  3. Adaptive Context: Dynamically adjust which documents to include

1.3 Why Mamba + RAG is Powerful

| Feature | Transformer | Mamba | Hybrid (T+M) |
|---|---|---|---|
| Context window | O(n²) cost | O(n) cost | O(n²) early, O(n) late |
| Long documents | Expensive | Efficient | Balanced |
| Retrieval capacity | Limited by attention | Can handle more docs | Best of both |
| Fine-tuning cost | High | Lower | Moderate |

Why This Matters:

  • Mamba can efficiently process 10K+ token contexts with retrieved documents
  • Hybrid models use attention for retrieval relevance, SSM for document processing
  • Lower memory → more documents in context → better RAG performance

2. CURRENT INFRASTRUCTURE ANALYSIS

2.1 Existing Fine-Tuning Components

chat_sft.py (Supervised Fine-Tuning):

  • Loads conversations from Task objects
  • Uses sft_data_generator() for batching
  • Masks loss on non-assistant tokens
  • Standard gradient descent training

mid_train.py (Midtraining):

  • Similar to SFT but different task mixture
  • Uses mid_data_generator() for streaming
  • Token-level batching from conversations

Key Insight: Both use conversation-based datasets. RAG will extend this by:

  1. Adding retrieved documents to conversations
  2. Teaching model to condition on retrieved context
  3. Optionally training retrieval scoring

2.2 Task Infrastructure (tasks/common.py)

class Task:
    def get_example(self, index):
        # Returns a conversation: a dict with a "messages" list
        ...

Extension Point: Add RetrievalTask that augments conversations with retrieved docs.
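
A minimal sketch of that extension point, assuming the Task interface above and the RetrievalManager designed in Section 4.2:

class RetrievalTask(Task):
    """Wraps a base task and injects retrieved documents into each conversation."""

    def __init__(self, base_task, retrieval_manager, top_k=5):
        self.base_task = base_task
        self.retrieval_manager = retrieval_manager
        self.top_k = top_k

    def get_example(self, index):
        conversation = self.base_task.get_example(index)
        return self.retrieval_manager.augment_conversation(conversation, top_k=self.top_k)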

2.3 Data Flow

Current:

Dataset → Conversation → Tokenize → Batch → Train

With RAG:

Dataset → Query → Retrieve Docs → Augmented Conversation → Tokenize → Batch → Train

3. RAG DATA FORMAT DESIGN

3.1 RAG-Enhanced Conversation Format

{
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant. Use the provided documents to answer questions."
        },
        {
            "role": "retrieval",  # NEW ROLE
            "documents": [
                {
                    "id": "doc_123",
                    "title": "Document Title",
                    "content": "Document content...",
                    "score": 0.95,  # retrieval score
                    "source": "wikipedia"
                },
                # ... more documents
            ]
        },
        {
            "role": "user",
            "content": "What is the capital of France?"
        },
        {
            "role": "assistant",
            "content": "Based on the provided documents, the capital of France is Paris."
        }
    ],
    "metadata": {
        "query": "capital of France",
        "retrieval_method": "dense",  # or "sparse", "hybrid"
        "num_retrieved": 5
    }
}
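
At training time this format plugs into the existing SFT convention (Section 2.1): loss is computed only on assistant tokens, so the model is conditioned on the retrieval tokens but never asked to predict them. A sketch, assuming per-token role labels are available from the rendering step:

def build_loss_mask(token_roles):
    # 1 where loss is computed (assistant tokens); 0 for system, retrieval, and user tokens
    return [1 if role == "assistant" else 0 for role in token_roles]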

3.2 REFRAG-Enhanced Format (Recursive)

{
    "messages": [
        {"role": "system", "content": "..."},
        {
            "role": "retrieval",
            "hop": 1,  # First retrieval
            "query": "capital of France",
            "documents": [...]
        },
        {
            "role": "retrieval",
            "hop": 2,  # Second retrieval (based on first)
            "query": "Paris population and history",  # derived query
            "documents": [...]
        },
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

3.3 Training Data Structure

rag_training_data/
├── knowledge_base/
│   ├── documents.jsonl      # All retrievable documents
│   ├── embeddings.npy       # Precomputed embeddings
│   └── index.faiss          # FAISS index for retrieval
├── queries/
│   ├── train.jsonl          # Training queries
│   ├── val.jsonl            # Validation queries
│   └── test.jsonl           # Test queries
└── conversations/
    ├── train_rag.jsonl      # Augmented conversations
    └── val_rag.jsonl
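
To make this layout concrete, here is a minimal sketch of how embeddings.npy and index.faiss could be built from documents.jsonl, assuming the sentence-transformers and faiss packages (the encoder name is illustrative):

import json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# Load all retrievable documents (one JSON object per line)
with open("rag_training_data/knowledge_base/documents.jsonl") as f:
    documents = [json.loads(line) for line in f]

# Embed document contents and save the embeddings
encoder = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative choice
embeddings = encoder.encode([doc["content"] for doc in documents])
np.save("rag_training_data/knowledge_base/embeddings.npy", embeddings)

# Build a flat inner-product index; normalize so inner product = cosine similarity
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, "rag_training_data/knowledge_base/index.faiss")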

4. RETRIEVAL MECHANISM DESIGN

4.1 Retrieval Strategies

Strategy 1: Dense Retrieval (Recommended)

class DenseRetriever:
    def __init__(self, encoder_model, index_path):
        self.encoder = load_encoder(encoder_model)  # e.g., sentence-transformers
        self.index = faiss.read_index(index_path)
        self.documents = load_documents()  # documents.jsonl from the knowledge base
    
    def retrieve(self, query, top_k=5):
        query_embedding = self.encoder.encode([query])  # FAISS expects a 2D (n, dim) array
        scores, indices = self.index.search(query_embedding, top_k)
        return [self.documents[i] for i in indices[0]]

Strategy 2: Sparse Retrieval (BM25)

class BM25Retriever:
    def __init__(self, documents):
        self.documents = documents
        # e.g., rank_bm25.BM25Okapi over whitespace-tokenized contents
        self.bm25 = BM25Okapi([doc["content"].split() for doc in documents])
    
    def retrieve(self, query, top_k=5):
        scores = self.bm25.get_scores(query.split())
        top_indices = np.argsort(scores)[::-1][:top_k]  # best scores first
        return [self.documents[i] for i in top_indices]

Strategy 3: Hybrid (Best Performance)

class HybridRetriever:
    def __init__(self, dense_retriever, sparse_retriever):
        self.dense = dense_retriever
        self.sparse = sparse_retriever
    
    def retrieve(self, query, top_k=5, alpha=0.7):
        # Over-retrieve from both, then fuse and rerank
        dense_docs = self.dense.retrieve(query, top_k * 2)
        sparse_docs = self.sparse.retrieve(query, top_k * 2)
        combined = self.rerank(dense_docs, sparse_docs, alpha)
        return combined[:top_k]
    
    def rerank(self, dense_docs, sparse_docs, alpha, k=60):
        # Weighted reciprocal-rank fusion: rank-based, so dense and sparse
        # scores on different scales never need to be compared directly
        fused = {}
        for rank, doc in enumerate(dense_docs):
            fused[doc["id"]] = [alpha / (k + rank + 1), doc]
        for rank, doc in enumerate(sparse_docs):
            entry = fused.setdefault(doc["id"], [0.0, doc])
            entry[0] += (1 - alpha) / (k + rank + 1)
        ranked = sorted(fused.values(), key=lambda e: e[0], reverse=True)
        return [doc for _, doc in ranked]

4.2 Integration with Nanochat

# New file: nanochat/retrieval.py
class RetrievalManager:
    """Manages document retrieval for RAG fine-tuning."""
    
    def __init__(self, retriever_type="dense", **kwargs):
        self.retriever = self._create_retriever(retriever_type, **kwargs)
    
    def augment_conversation(self, conversation, top_k=5):
        """Add retrieved documents to a conversation."""
        # Extract query from conversation
        query = self._extract_query(conversation)
        
        # Retrieve documents
        documents = self.retriever.retrieve(query, top_k)
        
        # Insert retrieval message
        augmented = self._insert_retrieval(conversation, documents)
        
        return augmented
    
    def _extract_query(self, conversation):
        # Extract user query from last user message
        for msg in reversed(conversation['messages']):
            if msg['role'] == 'user':
                return msg['content']
        return ""
    
    def _insert_retrieval(self, conversation, documents):
        # Insert the retrieval message immediately before the last user message,
        # so the model sees the documents before the question it must answer
        messages = conversation['messages'].copy()
        retrieval_msg = {
            "role": "retrieval",
            "documents": documents
        }
        last_user = max(i for i, m in enumerate(messages) if m['role'] == 'user')
        messages.insert(last_user, retrieval_msg)
        return {"messages": messages}
    
    def retrieve(self, query, top_k=5):
        # Direct passthrough, used by the inference example in Section 9
        return self.retriever.retrieve(query, top_k)

5. MAMBA-SPECIFIC OPTIMIZATIONS

5.1 Why Mamba is Ideal for RAG

  1. Linear Complexity: Process 10K+ token contexts efficiently
  2. Selective State Updates: Mamba's selection mechanism focuses on relevant parts of retrieved docs
  3. State-Based Memory: Natural for maintaining document context
  4. Lower Memory: Fit more documents in the same VRAM

5.2 Hybrid Architecture Strategy

Optimal Pattern for RAG:

# Early layers: Transformer (for cross-document attention/relevance)
# Middle layers: Hybrid (transition)
# Late layers: Mamba (for efficient long-context processing)

block_pattern = ["T"] * 8 + ["T", "M"] * 2 + ["M"] * 8  # For d20

Rationale:

  • Early transformers: Learn document relevance and cross-document relationships
  • Late Mamba: Process long concatenated documents efficiently
  • Memory savings: ~40% less activation memory for document processing

5.3 Context Injection Strategy

Option A: Concatenation (Simple)

[SYS] You are helpful. [/SYS]
[DOC] Doc 1 content... [/DOC]
[DOC] Doc 2 content... [/DOC]
[USER] Question? [/USER]
[ASST] Answer. [/ASST]

Option B: Structured Tokens (Better)

[SYS] You are helpful. [/SYS]
[RETRIEVAL_START]
[DOC_1] Title: ... Content: ... [/DOC_1]
[DOC_2] Title: ... Content: ... [/DOC_2]
[RETRIEVAL_END]
[USER] Question? [/USER]
[ASST] Answer. [/ASST]
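
A small helper can render a retrieval message into the Option B layout (the bracketed tags are placeholders for whatever special tokens the tokenizer actually reserves):

def render_retrieval_block(documents):
    # Serialize retrieved documents using the structured-token layout above
    lines = ["[RETRIEVAL_START]"]
    for i, doc in enumerate(documents, start=1):
        lines.append(f"[DOC_{i}] Title: {doc['title']} Content: {doc['content']} [/DOC_{i}]")
    lines.append("[RETRIEVAL_END]")
    return "\n".join(lines)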

Option C: Embedding-Level (Advanced)

  • Add special "retrieval" embeddings
  • Mamba state conditioned on retrieval embeddings
  • Requires model architecture modification (future work)

6. REFRAG (RECURSIVE RAG) DESIGN

6.1 Recursive Retrieval Flow

def refrag_retrieve(query, retriever, max_hops=3):
    """Recursive retrieval with multiple hops."""
    all_documents = []
    current_query = query
    
    for hop in range(max_hops):
        # Retrieve documents for the current query
        docs = retriever.retrieve(current_query, top_k=5)
        all_documents.append({
            "hop": hop + 1,
            "query": current_query,
            "documents": docs
        })
        
        # Generate the next query from retrieved docs (using the LLM; see sketch below)
        if hop < max_hops - 1:
            current_query = generate_followup_query(docs, query)
            if not current_query:  # no further query needed
                break
    
    return all_documents
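
generate_followup_query is left abstract above. One hedged sketch uses the chat model itself to propose the next hop; `model` stands in for the loaded nanochat engine, and its text-in/text-out API is an assumption:

def generate_followup_query(docs, original_query):
    # Ask the model for one follow-up search query, or "" to stop hopping
    context = "\n".join(doc["content"][:500] for doc in docs)  # truncate for brevity
    prompt = (
        f"Question: {original_query}\n"
        f"Retrieved context:\n{context}\n"
        "If more information is needed to answer, write one follow-up search query; "
        "otherwise reply DONE.\nQuery:"
    )
    out = model.generate(prompt)  # hypothetical text-in/text-out API
    return "" if "DONE" in out else out.strip()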

6.2 Reinforcement Learning for Retrieval

Reward Signal:

def compute_rag_reward(generated_answer, ground_truth, retrieved_docs, max_docs=10):
    """Compute reward for RAG performance."""
    # Component 1: Answer quality
    answer_score = compute_similarity(generated_answer, ground_truth)
    
    # Component 2: Document relevance
    relevance_score = compute_doc_relevance(retrieved_docs, ground_truth)
    
    # Component 3: Efficiency (fewer docs = better), relative to the retrieval budget
    efficiency_score = 1.0 - (len(retrieved_docs) / max_docs)
    
    # Weighted combination
    reward = 0.6 * answer_score + 0.3 * relevance_score + 0.1 * efficiency_score
    return reward

Training Loop:

for batch in rag_dataloader:
    # 1. Retrieve documents
    retrieved_docs = retriever.retrieve(batch['query'])
    
    # 2. Generate an answer conditioned on the retrieved documents
    answer = model.generate(batch['query'], retrieved_docs)
    
    # 3. Compute reward
    reward = compute_rag_reward(answer, batch['ground_truth'], retrieved_docs)
    
    # 4. Update model (REINFORCE-style sketch; PPO would add clipping and a value baseline)
    loss = -reward * log_prob(answer)  # log_prob: sum of token log-probs of the answer
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

7. IMPLEMENTATION PLAN

Phase 1: Basic RAG Infrastructure (Week 1)

  • Create nanochat/retrieval.py with retrieval managers
  • Implement RetrievalTask class extending Task
  • Add RAG data loader with document injection
  • Create scripts/rag_finetune.py script
  • Test with simple retrieval on Mamba/hybrid models

Phase 2: Advanced Retrieval (Week 2)

  • Implement dense retriever (FAISS + sentence-transformers)
  • Implement BM25 sparse retriever
  • Add hybrid retrieval with reranking
  • Create retrieval preprocessing tools
  • Build example knowledge base

Phase 3: REFRAG Implementation (Week 3)

  • Implement recursive retrieval mechanism
  • Add query generation for multi-hop
  • Integrate reward modeling
  • Create REFRAG training loop
  • Test on multi-hop QA datasets

Phase 4: Optimization & Testing (Week 4)

  • Optimize for Mamba (long context handling)
  • Add gradient checkpointing for long contexts
  • Profile memory usage with retrieved docs
  • Comprehensive testing
  • Documentation and examples

8. FILE STRUCTURE

nanochat/
├── retrieval.py                    # NEW: Retrieval infrastructure
├── rag_utils.py                    # NEW: RAG utility functions
└── blocks/
    └── rag_mamba_block.py         # NEW: Optional RAG-optimized Mamba

scripts/
├── rag_finetune.py                # NEW: RAG fine-tuning script
├── refrag_finetune.py             # NEW: REFRAG fine-tuning script
└── rag_eval.py                    # NEW: RAG evaluation

tasks/
├── rag_task.py                    # NEW: RAG task wrapper
└── retrieval_qa.py                # NEW: QA with retrieval

configs/
├── rag_mamba_d20.py               # NEW: RAG config for Mamba
├── rag_hybrid_d20.py              # NEW: RAG config for hybrid
└── refrag_hybrid_d20.py           # NEW: REFRAG config

data/
└── rag_examples/
    ├── knowledge_base/
    ├── queries/
    └── conversations/

tests/
├── test_retrieval.py              # NEW: Retrieval tests
└── test_rag_finetuning.py         # NEW: RAG training tests

9. EXAMPLE USAGE

Basic RAG Fine-Tuning

# Prepare knowledge base
python -m nanochat.retrieval prepare_kb \
  --documents data/documents.jsonl \
  --output data/rag_examples/knowledge_base

# Fine-tune hybrid model with RAG
torchrun --standalone --nproc_per_node=8 -m scripts.rag_finetune \
  --source mid \
  --model_tag d20 \
  --knowledge_base data/rag_examples/knowledge_base \
  --block_pattern "T,T,T,T,T,T,T,T,T,M,M,M,M,M,M,M,M,M,M,M" \
  --top_k 5 \
  --device_batch_size 4

REFRAG Training

# Fine-tune with recursive retrieval
torchrun --standalone --nproc_per_node=8 -m scripts.refrag_finetune \
  --source mid \
  --model_tag d20 \
  --knowledge_base data/rag_examples/knowledge_base \
  --max_hops 3 \
  --use_rewards true \
  --device_batch_size 2

Inference with RAG

from nanochat.retrieval import RetrievalManager
from nanochat.checkpoint_manager import load_model

# Load RAG-trained model
model, tokenizer, _ = load_model("rag", device="cuda", phase="eval")

# Initialize retrieval
retriever = RetrievalManager(
    retriever_type="hybrid",
    knowledge_base="data/rag_examples/knowledge_base"
)

# Query with retrieval
query = "What is the capital of France?"
retrieved_docs = retriever.retrieve(query, top_k=5)

# Generate answer
conversation = {
    "messages": [
        {"role": "system", "content": "You are helpful."},
        {"role": "retrieval", "documents": retrieved_docs},
        {"role": "user", "content": query}
    ]
}
response = model.generate_from_conversation(conversation)
print(response)

10. EXPECTED BENEFITS

Performance Improvements

| Metric | Transformer | Hybrid + RAG | Mamba + RAG |
|---|---|---|---|
| Max context (docs) | 3-5 docs (2K tokens) | 5-8 docs (4K tokens) | 10-15 docs (8K+ tokens) |
| Memory usage | Baseline | -20% | -40% |
| Inference speed | Baseline | +15% | +40% |
| RAG quality | Good | Better | Best for long docs |

Quality Improvements

  • Factuality: projected ↑ 20-30% with retrieved grounding
  • Hallucination: projected ↓ 40-50% with document evidence
  • Domain Coverage: can answer questions in any domain the KB covers
  • Freshness: stays up to date via KB updates, without retraining

11. DESIGN DECISIONS & RATIONALE

Decision 1: Mamba/Hybrid Only

Rationale:

  • Pure transformer RAG is O(n²) with context length
  • Mamba's O(n) makes long context RAG practical
  • Hybrid gets best of both: attention for relevance, SSM for processing

Decision 2: External Retrieval (Not End-to-End)

Rationale:

  • Separate retrieval allows KB updates without retraining
  • More flexible: swap retrieval methods
  • Lower computational cost
  • Can use specialized retrieval models

Alternative Considered: Train retrieval jointly

  • More complex
  • Requires larger compute budget
  • Less flexible
  • Future work

Decision 3: Structured Context Injection

Rationale:

  • Special tokens [DOC], [RETRIEVAL_START] make boundaries clear
  • Model learns to identify and use retrieved info
  • Easier to debug and interpret
  • Compatible with existing tokenizer

Decision 4: REFRAG as Extension

Rationale:

  • Start simple with single-hop RAG
  • Add recursive as advanced feature
  • Allows gradual complexity increase
  • Can train on simpler data first

12. RISKS & MITIGATIONS

| Risk | Impact | Mitigation |
|---|---|---|
| OOM with long contexts | High | Gradient checkpointing, reduce batch size |
| Poor retrieval quality | High | High-quality embeddings, hybrid retrieval |
| Training instability | Medium | Careful LR tuning, gradual unfreezing |
| Document contamination | Medium | Strict train/val/test KB separation |
| Slow inference | Medium | Cache embeddings, optimize retrieval |

13. SUCCESS METRICS

Quantitative

  • Retrieval Recall@5: > 80% on validation queries (see the sketch after this list)
  • Answer Quality (F1): > 70% vs. ground truth
  • Hallucination Rate: < 10% false claims
  • Training Speed: no more than 2x the training time of base SFT
  • Memory Usage: fits on a 16 GB consumer GPU for d16
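
Recall@5 would be measured as the fraction of validation queries whose gold document appears among the top-5 retrieved results. A simple sketch, assuming each query record lists its gold document ids under a hypothetical gold_doc_ids field:

def recall_at_k(queries, retriever, k=5):
    hits = 0
    for q in queries:
        retrieved_ids = {doc["id"] for doc in retriever.retrieve(q["query"], top_k=k)}
        if retrieved_ids & set(q["gold_doc_ids"]):
            hits += 1
    return hits / len(queries)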

Qualitative

  • Model correctly attributes answers to documents
  • Model says "I don't know" when docs don't contain answer
  • Model synthesizes across multiple documents
  • Model handles contradictory documents gracefully

14. FUTURE ENHANCEMENTS (Beyond Scope)

  • Learnable retrieval (end-to-end)
  • Multi-modal retrieval (images, tables)
  • Streaming retrieval during generation
  • Adaptive retrieval (retrieve more if uncertain)
  • Retrieval cache for common queries
  • Cross-lingual retrieval
  • Temporal retrieval (time-aware)

15. CONCLUSION

Recommendation: PROCEED with RAG/REFRAG implementation

Rationale:

  1. Natural fit for Mamba's long-context capabilities
  2. Modular architecture supports clean integration
  3. Clear value proposition: grounded generation
  4. Feasible within consumer GPU constraints
  5. Educational value: demonstrates RAG best practices

Next Steps:

  1. Get approval for design approach
  2. Begin Phase 1 implementation
  3. Create example knowledge base
  4. Test retrieval on hybrid models
  5. Iterate based on results

Document Version: 1.0
Date: 2025-01-15
Status: Design Complete - Ready for Implementation