nanochat/scripts/refrag_finetune.py
"""
REFRAG (Recursive Retrieval-Augmented Generation) Fine-tuning
Train models with multi-hop retrieval and reinforcement learning.
Optimized for Mamba and hybrid architectures.
Usage:
    torchrun --standalone --nproc_per_node=8 -m scripts.refrag_finetune \
        --knowledge_base data/kb \
        --max_hops 3 \
        --use_rewards true
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
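# Set before torch initializes CUDA so the allocator picks it up; expandable_segments reduces
# allocator fragmentation, which helps with the long, variable-length multi-hop contexts used here.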
import torch
import torch.distributed as dist
import wandb
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.retrieval import RetrievalManager
from nanochat.rag_utils import compute_rag_reward
from tasks.rag_task import MultiHopRAGTask
from tasks.smoltalk import SmolTalk
# -----------------------------------------------------------------------------
# REFRAG Hyperparameters
run = "dummy"
# Model
source = "mid"
model_tag = None
step = None
# RAG
knowledge_base = None # REQUIRED
retriever_type = "dense"
max_hops = 3 # number of retrieval hops
top_k_per_hop = 3 # docs per hop
# RL options
use_rewards = True # use RL-style rewards
reward_weight_answer = 0.6
reward_weight_relevance = 0.3
reward_weight_efficiency = 0.1
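# NOTE: these three weights are recorded in the run config but are not applied by
# compute_refrag_reward below, which uses its own fixed constants.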
# Training
dtype = "bfloat16"
device_batch_size = 2 # smaller for multi-hop (longer contexts)
num_epochs = 1
max_iterations = 500 # REFRAG is expensive, limit iterations
target_examples_per_step = 16
# Optimization
unembedding_lr = 0.002 # lower LR for stability
embedding_lr = 0.1
matrix_lr = 0.01
weight_decay = 0.0
init_lr_frac = 0.01 # very conservative start
# Eval
eval_every = 50
eval_steps = 20
# CLI overrides
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read())
user_config = {k: globals()[k] for k in config_keys}
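# Poor man's CLI: exec-ing nanochat/configurator.py lets command-line overrides replace the
# module-level defaults collected in config_keys; user_config is what gets logged to wandb.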
# -----------------------------------------------------------------------------
if knowledge_base is None:
    raise ValueError("--knowledge_base required")
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype_torch)
# WandB
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(
    project="nanochat-refrag",
    name=run,
    config=user_config,
)
# Load model
print0("Loading model...")
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
# Validate Mamba/hybrid
block_pattern = model.config.block_pattern
if block_pattern is None or "M" not in "".join(block_pattern):
    raise ValueError("REFRAG requires Mamba or hybrid models")
print0(f"✓ Model: {block_pattern.count('T')} transformer, {block_pattern.count('M')} Mamba blocks")
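# Rationale (per the module docstring): multi-hop contexts get long, which is where Mamba's
# linear-time sequence mixing is expected to help, so pure-transformer checkpoints are rejected.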
orig_model = model
# Load retrieval
print0(f"Loading knowledge base...")
retrieval_manager = RetrievalManager(
    retriever_type=retriever_type,
    knowledge_base_path=knowledge_base,
)
print0("✓ Knowledge base loaded")
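# Note: every rank constructs its own RetrievalManager, i.e. each process loads the knowledge
# base independently; nothing is shared across ranks.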
# Create multi-hop RAG task
print0(f"Creating multi-hop RAG task (max_hops={max_hops})...")
base_task = SmolTalk(split="train", stop=5000) # Limit for REFRAG
train_task = MultiHopRAGTask(
    base_task=base_task,
    knowledge_base_path=knowledge_base,
    retriever_type=retriever_type,
    max_hops=max_hops,
    top_k_per_hop=top_k_per_hop,
)
val_base = SmolTalk(split="test", stop=500)
val_task = MultiHopRAGTask(
    base_task=val_base,
    knowledge_base_path=knowledge_base,
    retriever_type=retriever_type,
    max_hops=max_hops,
    top_k_per_hop=top_k_per_hop,
)
print0(f"✓ Train: {len(train_task)} examples")
print0(f"✓ Val: {len(val_task)} examples")
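# Each example is assumed to be a conversation dict shaped roughly like this (inferred from how
# render_conversation and compute_refrag_reward consume it below):
#   {"messages": [
#       {"role": "user", "content": ...},
#       {"role": "retrieval", "multi_hop": True, "hops": [...]},  # injected by MultiHopRAGTask
#       {"role": "assistant", "content": ...},
#   ]}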
# DataLoader
def refrag_data_generator(dataset, batch_size):
    """Data generator for REFRAG (handles multi-hop retrieval)."""
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")
    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask, _ in batch) - 1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)
        rewards_list = []
        for i, (ids, mask, reward) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]
            row_targets = ids_tensor[1:]
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1
            targets[i, :n-1] = row_targets
            rewards_list.append(reward)
        inputs = inputs.to(device)
        targets = targets.to(device)
        rewards = torch.tensor(rewards_list, device=device, dtype=dtype_torch)
        return inputs, targets, rewards
    batch = []
    while True:
        for i in range(ddp_rank, len(dataset), ddp_world_size):
            conversation = dataset[i]
            ids, mask = tokenizer.render_conversation(conversation)
            # Truncate if needed
            max_len = 6144  # Allow longer for multi-hop
            if len(ids) > max_len:
                ids = ids[:max_len]
                mask = mask[:max_len]
            # Compute reward if using RL
            reward = 1.0  # default
            if use_rewards:
                # Simple reward: based on conversation structure
                # In full RL, would compare generated vs ground truth
                reward = compute_refrag_reward(conversation)
            batch.append((ids, mask, reward))
            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []
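# Note: the generator loops over the dataset forever (while True), strides examples across ranks
# via ddp_rank/ddp_world_size, and only yields full batches; a trailing partial batch is carried
# over into the next pass rather than dropped.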
def compute_refrag_reward(conversation):
    """Compute reward for REFRAG training."""
    messages = conversation.get("messages", [])
    # Check if retrieval was successful
    has_retrieval = any(msg.get("role") == "retrieval" for msg in messages)
    if not has_retrieval:
        return 0.5  # penalty for no retrieval
    # Check if multi-hop
    retrieval_msg = next((m for m in messages if m.get("role") == "retrieval"), None)
    if retrieval_msg and retrieval_msg.get("multi_hop"):
        hops = retrieval_msg.get("hops", [])
        num_hops = len(hops)
        # Reward more hops (up to max_hops)
        hop_reward = min(num_hops / max_hops, 1.0)
    else:
        hop_reward = 0.3  # penalty for single-hop
    # Combine rewards
    return 0.5 + 0.5 * hop_reward
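# Worked examples with max_hops=3: no retrieval -> 0.5; single-hop -> 0.5 + 0.5*0.3 = 0.65;
# 2 hops -> 0.5 + 0.5*(2/3) ≈ 0.83; 3 or more hops -> 1.0.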
# Training setup
examples_per_step = device_batch_size * ddp_world_size
assert target_examples_per_step % examples_per_step == 0, \
    "target_examples_per_step must be divisible by device_batch_size * world_size"
grad_accum_steps = target_examples_per_step // examples_per_step
num_iterations = min(max_iterations, (len(train_task) // target_examples_per_step) * num_epochs)
print0(f"\nTraining configuration:")
print0(f" Device batch size: {device_batch_size}")
print0(f" Gradient accumulation: {grad_accum_steps}")
print0(f" Iterations: {num_iterations}")
train_loader = refrag_data_generator(train_task, device_batch_size)
build_val_loader = lambda: refrag_data_generator(val_task, device_batch_size)
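# build_val_loader returns a fresh generator for each eval, so validation always covers the same
# (per-rank) leading slice of the val split and is comparable across evals.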
# Optimizer
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["initial_lr"] = group["lr"]
# Training loop
print0("\n" + "="*80)
print0("Starting REFRAG Training (Multi-hop RAG with RL)")
print0("="*80 + "\n")
def get_lr_multiplier(it):
    return 1.0 - it / num_iterations
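# Note: the schedule multiplies group["initial_lr"], which was already scaled by init_lr_frac
# above, so the run starts at init_lr_frac of the nominal LRs and decays linearly toward zero.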
step = 0
train_iter = iter(train_loader)
best_val_loss = float('inf')
for step in range(num_iterations):
    last_step = step == num_iterations - 1
    # Validation
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []
        rewards_list = []
        for _ in range(eval_steps):
            val_inputs, val_targets, val_rewards = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
            rewards_list.append(val_rewards.mean())
        val_loss = torch.stack(losses).mean()
        avg_reward = torch.stack(rewards_list).mean()
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
            dist.all_reduce(avg_reward, op=dist.ReduceOp.AVG)
        val_loss = val_loss.item()
        avg_reward = avg_reward.item()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
        print0(f"Step {step:05d} | Val loss: {val_loss:.6f} | Reward: {avg_reward:.4f} | Best: {best_val_loss:.6f}")
        wandb_run.log({"step": step, "val_loss": val_loss, "avg_reward": avg_reward})
        model.train()
    if last_step:
        break
    # Training step with reward weighting
    total_loss = 0
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets, train_rewards = next(train_iter)
        with autocast_ctx:
            # loss_reduction='none' is assumed to return unreduced per-token losses
            # (positions with target -1 are ignored and contribute zero loss)
            loss = model(train_inputs, train_targets, loss_reduction='none')
        if use_rewards:
            # Reduce to one loss per example over its valid tokens, then weight by reward (RL-style)
            per_token = loss.view(train_inputs.size(0), -1).float()
            valid = (train_targets >= 0).float()
            per_example = (per_token * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
            weighted_loss = (per_example * train_rewards).mean()
        else:
            # Masked mean over valid tokens only
            weighted_loss = loss.sum() / (train_targets >= 0).sum().clamp(min=1)
        train_loss = weighted_loss.detach()
        total_loss += train_loss
        weighted_loss = weighted_loss / grad_accum_steps
        weighted_loss.backward()
    # Update
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    # Logging
    if step % 10 == 0:
        avg_loss = (total_loss / grad_accum_steps).item()
        print0(f"Step {step:05d}/{num_iterations:05d} | Train loss: {avg_loss:.6f} | LR: {lrm:.4f}")
        wandb_run.log({"step": step, "train_loss": avg_loss, "lrm": lrm})
# Save
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag_out = f"d{depth}_refrag"
    checkpoint_dir = os.path.join(base_dir, "refrag_checkpoints", model_tag_out)
    model_config_kwargs = {k: v for k, v in model.config.__dict__.items() if not k.startswith('_')}
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),
        None,
        {
            "step": step,
            "val_loss": val_loss,
            "model_config": model_config_kwargs,
            "refrag_config": {
                "knowledge_base": knowledge_base,
                "max_hops": max_hops,
                "use_rewards": use_rewards,
            },
        },
    )
    print0(f"\n✅ Saved REFRAG model to {checkpoint_dir}")
print0("\n" + "="*80)
print0("REFRAG Training Complete!")
print0("="*80)
# Cleanup
wandb_run.finish()
compute_cleanup()