"""
RAG Fine-tuning Script for Mamba and Hybrid Models
Fine-tune a pretrained model with retrieval-augmented generation.
Optimized for Mamba and hybrid (Transformer+Mamba) architectures.
Usage:

    # Single GPU
    python -m scripts.rag_finetune --knowledge_base data/kb

    # Multi-GPU
    torchrun --standalone --nproc_per_node=8 -m scripts.rag_finetune \
        --knowledge_base data/kb \
        --source mid \
        --retriever_type dense

Only works with Mamba or hybrid models (block_pattern must contain "M").
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.engine import Engine
from nanochat.retrieval import RetrievalManager
from nanochat.rag_utils import render_rag_conversation_for_tokenizer
from tasks.rag_task import RAGTask, create_rag_task
from tasks.smoltalk import SmolTalk
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
# -----------------------------------------------------------------------------
# RAG Fine-tuning Hyperparameters
run = "dummy" # wandb run name
# Model options
source = "mid" # base|mid - which checkpoint to load
model_tag = None # model tag to load
step = None # step to load
# RAG options
knowledge_base = None # REQUIRED: path to knowledge base
retriever_type = "simple" # simple|dense
top_k = 5 # number of documents to retrieve
max_doc_length = 500 # max characters per document in prompt
insert_position = "before_user" # where to insert retrieval
# Task options
base_tasks = "SmolTalk" # comma-separated: SmolTalk,MMLU,ARC-Easy,GSM8K
task_samples = 10000 # samples per task (-1 = all)
# Training options
dtype = "bfloat16"
device_batch_size = 4 # smaller due to longer contexts with RAG
num_epochs = 1
max_iterations = -1
target_examples_per_step = 32
# Optimization
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02 # start with lower LR for stability
# Evaluation
eval_every = 100
eval_steps = 50
# Allow CLI overrides
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read())
user_config = {k: globals()[k] for k in config_keys}
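# Every module-level int/float/bool/str above is collected into config_keys and can be
# overridden from the command line via configurator.py; user_config is the resulting
# snapshot that gets logged to wandb and the final report.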
# -----------------------------------------------------------------------------
# Validate
if knowledge_base is None:
    raise ValueError("--knowledge_base is required for RAG fine-tuning")
if not os.path.exists(knowledge_base):
    raise FileNotFoundError(f"Knowledge base not found: {knowledge_base}")
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype_torch)
# WandB logging
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(
    project="nanochat-rag",
    name=run,
    config=user_config,
    save_code=True,
)
# Load model and tokenizer
print0(f"Loading model from {source} checkpoint...")
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
# Validate model has Mamba blocks
block_pattern = model.config.block_pattern
if block_pattern is None or "M" not in "".join(block_pattern):
    raise ValueError(
        "RAG fine-tuning requires Mamba or hybrid models. "
        f"Current block_pattern: {block_pattern}. "
        "Please use a model with Mamba blocks (contains 'M')."
    )
print0(f"✓ Model has block pattern: {block_pattern}")
print0(f" Transformer blocks: {block_pattern.count('T')}")
print0(f" Mamba blocks: {block_pattern.count('M')}")
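# block_pattern is expected to be a per-layer sequence of "T" (Transformer/attention block)
# and "M" (Mamba block), e.g. "TMTM..."; the counts above summarize the mix.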
orig_model = model
# Don't compile for RAG (variable-length contexts)
# model = torch.compile(model, dynamic=True)
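# (RAG prompts vary a lot in length; even with dynamic=True, compilation would likely be
# retriggered on new shapes, so it is assumed not to be worth it here.)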
# Initialize retrieval manager
print0(f"Loading knowledge base from {knowledge_base}...")
print0(f"Using retriever type: {retriever_type}")
retrieval_manager = RetrievalManager(
    retriever_type=retriever_type,
    knowledge_base_path=knowledge_base,
)
print0("✓ Knowledge base loaded")
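# Note: retrieval_manager is not used again below; the RAG tasks receive the same
# knowledge_base_path/retriever_type and presumably build their own retriever, so this
# mainly validates the knowledge base and retriever backend up front.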
# -----------------------------------------------------------------------------
# Create RAG tasks
print0(f"Creating RAG tasks from base tasks: {base_tasks}")
task_list = base_tasks.split(",")
train_rag_tasks = []
val_rag_tasks = []
for task_name in task_list:
    task_name = task_name.strip()
    print0(f" Creating RAG wrapper for {task_name}...")
    # Create training task
    try:
        train_task = create_rag_task(
            task_name=task_name,
            split="train",
            knowledge_base_path=knowledge_base,
            retriever_type=retriever_type,
            top_k=top_k,
            stop=task_samples if task_samples > 0 else None,
        )
        train_rag_tasks.append(train_task)
        print0(f" Train: {len(train_task)} examples")
    except Exception as e:
        print0(f" Warning: Could not create train task for {task_name}: {e}")
    # Create validation task
    try:
        val_task = create_rag_task(
            task_name=task_name,
            split="test" if task_name == "SmolTalk" else "val",
            knowledge_base_path=knowledge_base,
            retriever_type=retriever_type,
            top_k=top_k,
            stop=1000,  # Limit validation size
        )
        val_rag_tasks.append(val_task)
        print0(f" Val: {len(val_task)} examples")
    except Exception as e:
        print0(f" Warning: Could not create val task for {task_name}: {e}")
# Combine tasks
from tasks.common import TaskMixture
assert len(train_rag_tasks) > 0, "No RAG training tasks could be created; check base_tasks and the knowledge base"
train_ds = TaskMixture(train_rag_tasks) if len(train_rag_tasks) > 1 else train_rag_tasks[0]
val_ds = TaskMixture(val_rag_tasks) if len(val_rag_tasks) > 1 else (val_rag_tasks[0] if val_rag_tasks else train_rag_tasks[0])
print0(f"\n✓ Total training examples: {len(train_ds)}")
print0(f"✓ Total validation examples: {len(val_ds)}")
# -----------------------------------------------------------------------------
# DataLoader for RAG
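# Each dataset item is a RAG-augmented conversation. It is rendered to token ids plus a
# per-token mask, then collated into (inputs, targets) for next-token prediction: inputs
# are ids[:-1], targets are ids[1:], rows are right-padded, and positions where the mask
# is 0 (tokens not trained on) as well as padding get target -1, which the model's loss
# is assumed to treat as ignore_index.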
def rag_data_generator(dataset, batch_size):
    """Data generator for RAG training with retrieved documents."""
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")

    def collate_and_yield(batch):
        """Collate RAG conversations into batch."""
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)
        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]
            row_targets = ids_tensor[1:]
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1
            targets[i, :n-1] = row_targets
        inputs = inputs.to(device)
        targets = targets.to(device)
        return inputs, targets

    batch = []
    while True:
        for i in range(ddp_rank, len(dataset), ddp_world_size):
            # Get RAG-augmented conversation
            conversation = dataset[i]
            # Render to tokens
            ids, mask = tokenizer.render_conversation(conversation)
            # Truncate if too long (RAG contexts can be long)
            max_len = 4096  # Allow longer contexts for Mamba
            if len(ids) > max_len:
                ids = ids[:max_len]
                mask = mask[:max_len]
            batch.append((ids, mask))
            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []
# Calculate gradient accumulation
examples_per_step = device_batch_size * ddp_world_size
print0(f"\nTraining configuration:")
print0(f" Device batch size: {device_batch_size}")
print0(f" Examples per step: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f" Gradient accumulation steps: {grad_accum_steps}")
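# Example: with the defaults (device_batch_size=4, target_examples_per_step=32), 8 GPUs give
# examples_per_step=32 and grad_accum_steps=1, while a single GPU gives examples_per_step=4
# and grad_accum_steps=8; either way the optimizer update sees 32 examples per step.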
# Calculate iterations
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
if max_iterations >= 0 and num_iterations > max_iterations:
num_iterations = max_iterations
print0(f" Number of iterations: {num_iterations}")
train_loader = rag_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: rag_data_generator(val_ds, batch_size=device_batch_size)
# -----------------------------------------------------------------------------
# Initialize optimizer
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
# Set initial LR
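# Scale every group's LR down by init_lr_frac (0.02 by default) and stash it as "initial_lr";
# get_lr_multiplier below decays from that value, so the effective LR at step t is
# base_lr * init_lr_frac * (1 - t / num_iterations).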
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["initial_lr"] = group["lr"]
# -----------------------------------------------------------------------------
# Training loop
print0("\n" + "="*80)
print0("Starting RAG Fine-Tuning")
print0("="*80 + "\n")
def get_lr_multiplier(it):
    """Linear decay to 0."""
    return 1.0 - it / num_iterations
# Training loop
step = 0
train_iter = iter(train_loader)
best_val_loss = float('inf')
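# Main loop: evaluate every eval_every steps (and on the final step), then run grad_accum_steps
# micro-batches, apply the linearly decayed LR to every optimizer group, and step all optimizers.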
for step in range(num_iterations):
    last_step = step == num_iterations - 1
    # Validation
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []
        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
        val_loss = torch.stack(losses).mean()
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss = val_loss.item()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
        print0(f"Step {step:05d} | Val loss: {val_loss:.6f} | Best: {best_val_loss:.6f}")
        wandb_run.log({"step": step, "val_loss": val_loss, "best_val_loss": best_val_loss})
        model.train()
    if last_step:
        break
    # Training step
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets = next(train_iter)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
        loss.backward()
    # Update
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    # Logging
    if step % 10 == 0:
        train_loss_item = train_loss.item()
        print0(f"Step {step:05d}/{num_iterations:05d} | Train loss: {train_loss_item:.6f} | LR mult: {lrm:.4f}")
        wandb_run.log({"step": step, "train_loss": train_loss_item, "lrm": lrm})
    step += 1
# Save final model
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag_out = f"d{depth}_rag"
    checkpoint_dir = os.path.join(base_dir, "rag_checkpoints", model_tag_out)
    model_config_kwargs = {
        k: v for k, v in model.config.__dict__.items()
        if not k.startswith('_')
    }
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),
        None,
        {
            "step": step,
            "val_loss": val_loss,
            "best_val_loss": best_val_loss,
            "model_config": model_config_kwargs,
            "rag_config": {
                "knowledge_base": knowledge_base,
                "retriever_type": retriever_type,
                "top_k": top_k,
                "base_tasks": base_tasks,
            },
        },
    )
    print0(f"\n✅ Saved RAG model to {checkpoint_dir}")
# Log to report
from nanochat.report import get_report
get_report().log(section="RAG Fine-Tuning", data=[
    user_config,
    {
        "Training examples": len(train_ds),
        "Number of iterations": num_iterations,
        "Final val loss": val_loss,
        "Best val loss": best_val_loss,
        "Knowledge base": knowledge_base,
        "Retriever type": retriever_type,
        "Top-k documents": top_k,
    },
])
print0("\n" + "="*80)
print0("RAG Fine-Tuning Complete!")
print0("="*80)
# Cleanup
wandb_run.finish()
compute_cleanup()