"""
|
|
REFRAG (Recursive Retrieval-Augmented Generation) Fine-tuning
|
|
|
|
Train models with multi-hop retrieval and reinforcement learning.
|
|
Optimized for Mamba and hybrid architectures.
|
|
|
|
Usage:
|
|
torchrun --standalone --nproc_per_node=8 -m scripts.refrag_finetune \
|
|
--knowledge_base data/kb \
|
|
--max_hops 3 \
|
|
--use_rewards true
|
|
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.distributed as dist
import wandb

from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.retrieval import RetrievalManager
from nanochat.rag_utils import compute_rag_reward

from tasks.rag_task import MultiHopRAGTask
from tasks.smoltalk import SmolTalk

# -----------------------------------------------------------------------------
# REFRAG Hyperparameters
run = "dummy"
# Model
source = "mid"
model_tag = None
step = None
# RAG
knowledge_base = None # REQUIRED
retriever_type = "dense"
max_hops = 3 # number of retrieval hops
top_k_per_hop = 3 # docs per hop
# RL options
use_rewards = True # use RL-style rewards
reward_weight_answer = 0.6
reward_weight_relevance = 0.3
reward_weight_efficiency = 0.1
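# NOTE: the reward_weight_* knobs are not consumed by the simple structural reward below
# (compute_refrag_reward); they are presumably meant for compute_rag_reward-style reward shaping.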
# Training
dtype = "bfloat16"
device_batch_size = 2 # smaller for multi-hop (longer contexts)
num_epochs = 1
max_iterations = 500 # REFRAG is expensive, limit iterations
target_examples_per_step = 16
# Optimization
unembedding_lr = 0.002 # lower LR for stability
embedding_lr = 0.1
matrix_lr = 0.01
weight_decay = 0.0
init_lr_frac = 0.01 # very conservative start
# Eval
eval_every = 50
eval_steps = 20
# CLI overrides
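# configurator.py applies --key=value command-line overrides to the simple globals above;
# user_config captures the final values so they can be logged to wandb.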
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read())
user_config = {k: globals()[k] for k in config_keys}
# -----------------------------------------------------------------------------

if knowledge_base is None:
    raise ValueError("--knowledge_base required")

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype_torch)

# WandB
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(
    project="nanochat-refrag",
    name=run,
    config=user_config
)
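# With the default run="dummy" (and on all non-master ranks) logging goes to a no-op
# DummyWandb, so only rank 0 of a named run actually writes to wandb.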

# Load model
print0("Loading model...")
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)

# Validate Mamba/hybrid
block_pattern = model.config.block_pattern
if block_pattern is None or "M" not in "".join(block_pattern):
    raise ValueError("REFRAG requires Mamba or hybrid models")

print0(f"✓ Model: {block_pattern.count('T')} transformer, {block_pattern.count('M')} Mamba blocks")

orig_model = model
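# keep a direct handle to the model; its state_dict() is what gets checkpointed at the end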

# Load retrieval
print0(f"Loading knowledge base...")
retrieval_manager = RetrievalManager(
    retriever_type=retriever_type,
    knowledge_base_path=knowledge_base
)
print0("✓ Knowledge base loaded")

# Create multi-hop RAG task
print0(f"Creating multi-hop RAG task (max_hops={max_hops})...")
base_task = SmolTalk(split="train", stop=5000) # Limit for REFRAG
train_task = MultiHopRAGTask(
    base_task=base_task,
    knowledge_base_path=knowledge_base,
    retriever_type=retriever_type,
    max_hops=max_hops,
    top_k_per_hop=top_k_per_hop
)

val_base = SmolTalk(split="test", stop=500)
val_task = MultiHopRAGTask(
    base_task=val_base,
    knowledge_base_path=knowledge_base,
    retriever_type=retriever_type,
    max_hops=max_hops,
    top_k_per_hop=top_k_per_hop
)

print0(f"✓ Train: {len(train_task)} examples")
print0(f"✓ Val: {len(val_task)} examples")

# DataLoader
def refrag_data_generator(dataset, batch_size):
    """Data generator for REFRAG (handles multi-hop retrieval)."""
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")
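    # Each yielded batch is right-padded to the longest sequence with the assistant_end
    # token; padded and non-trainable positions get target -1, which the loss ignores.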

    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask, _ in batch) - 1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)
        rewards_list = []

        for i, (ids, mask, reward) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]

            row_targets = ids_tensor[1:]
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1
            targets[i, :n-1] = row_targets

            rewards_list.append(reward)

        inputs = inputs.to(device)
        targets = targets.to(device)
        rewards = torch.tensor(rewards_list, device=device, dtype=dtype_torch)

        return inputs, targets, rewards

    batch = []
    while True:
        for i in range(ddp_rank, len(dataset), ddp_world_size):
            conversation = dataset[i]
            ids, mask = tokenizer.render_conversation(conversation)

            # Truncate if needed
            max_len = 6144 # Allow longer for multi-hop
            if len(ids) > max_len:
                ids = ids[:max_len]
                mask = mask[:max_len]

            # Compute reward if using RL
            reward = 1.0 # default
            if use_rewards:
                # Simple reward: based on conversation structure
                # In full RL, would compare generated vs ground truth
                reward = compute_refrag_reward(conversation)

            batch.append((ids, mask, reward))

            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []

def compute_refrag_reward(conversation):
    """Compute reward for REFRAG training."""
    messages = conversation.get("messages", [])

    # Check if retrieval was successful
    has_retrieval = any(msg.get("role") == "retrieval" for msg in messages)
    if not has_retrieval:
        return 0.5 # penalty for no retrieval

    # Check if multi-hop
    retrieval_msg = next((m for m in messages if m.get("role") == "retrieval"), None)
    if retrieval_msg and retrieval_msg.get("multi_hop"):
        hops = retrieval_msg.get("hops", [])
        num_hops = len(hops)
        # Reward more hops (up to max_hops)
        hop_reward = min(num_hops / max_hops, 1.0)
    else:
        hop_reward = 0.3 # penalty for single-hop

    # Combine rewards
    return 0.5 + 0.5 * hop_reward
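
# Worked example with max_hops=3: no retrieval -> 0.5, single-hop -> 0.5 + 0.5*0.3 = 0.65,
# a 2-hop trace -> 0.5 + 0.5*(2/3) ≈ 0.83, and 3 or more hops saturate at 1.0.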

# Training setup
examples_per_step = device_batch_size * ddp_world_size
grad_accum_steps = target_examples_per_step // examples_per_step
num_iterations = min(max_iterations, (len(train_task) // target_examples_per_step) * num_epochs)
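# e.g. with the defaults on 8 GPUs: examples_per_step = 2 * 8 = 16 and grad_accum_steps = 16 // 16 = 1;
# target_examples_per_step should be >= examples_per_step, otherwise grad_accum_steps rounds down to 0.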

print0(f"\nTraining configuration:")
print0(f" Device batch size: {device_batch_size}")
print0(f" Gradient accumulation: {grad_accum_steps}")
print0(f" Iterations: {num_iterations}")

train_loader = refrag_data_generator(train_task, device_batch_size)
build_val_loader = lambda: refrag_data_generator(val_task, device_batch_size)

# Optimizer
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay
)

for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["initial_lr"] = group["lr"]

# Training loop
print0("\n" + "="*80)
print0("Starting REFRAG Training (Multi-hop RAG with RL)")
print0("="*80 + "\n")

def get_lr_multiplier(it):
    return 1.0 - it / num_iterations
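
# Linear decay: the effective LR starts at init_lr_frac * the configured LRs (set above)
# and falls toward zero by the end of training.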

step = 0
train_iter = iter(train_loader)
best_val_loss = float('inf')

for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # Validation
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []
        rewards_list = []

        for _ in range(eval_steps):
            val_inputs, val_targets, val_rewards = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
            rewards_list.append(val_rewards.mean())

        val_loss = torch.stack(losses).mean()
        avg_reward = torch.stack(rewards_list).mean()

        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
            dist.all_reduce(avg_reward, op=dist.ReduceOp.AVG)

        val_loss = val_loss.item()
        avg_reward = avg_reward.item()

        if val_loss < best_val_loss:
            best_val_loss = val_loss

        print0(f"Step {step:05d} | Val loss: {val_loss:.6f} | Reward: {avg_reward:.4f} | Best: {best_val_loss:.6f}")
        wandb_run.log({"step": step, "val_loss": val_loss, "avg_reward": avg_reward})
        model.train()

    if last_step:
        break

    # Training step with reward weighting
    total_loss = 0
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets, train_rewards = next(train_iter)

        with autocast_ctx:
            # loss_reduction='none' is assumed to return per-token losses (positions with
            # target -1 contribute zero); reduce to a per-example mean before weighting
            loss = model(train_inputs, train_targets, loss_reduction='none')
            loss = loss.view(train_inputs.size(0), -1)
            valid = (train_targets != -1).float()
            per_example_loss = (loss * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)

            if use_rewards:
                # Weight each example's loss by its reward (RL-style)
                weighted_loss = (per_example_loss * train_rewards).mean()
            else:
                weighted_loss = per_example_loss.mean()

        train_loss = weighted_loss.detach()
        total_loss += train_loss
        weighted_loss = weighted_loss / grad_accum_steps
        weighted_loss.backward()

    # Update
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm

    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # Logging
    if step % 10 == 0:
        avg_loss = (total_loss / grad_accum_steps).item()
        print0(f"Step {step:05d}/{num_iterations:05d} | Train loss: {avg_loss:.6f} | LR: {lrm:.4f}")
        wandb_run.log({"step": step, "train_loss": avg_loss, "lrm": lrm})

# Save
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag_out = f"d{depth}_refrag"
    checkpoint_dir = os.path.join(base_dir, "refrag_checkpoints", model_tag_out)

    model_config_kwargs = {k: v for k, v in model.config.__dict__.items() if not k.startswith('_')}

    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),
        None,
        {
            "step": step,
            "val_loss": val_loss,
            "model_config": model_config_kwargs,
            "refrag_config": {
                "knowledge_base": knowledge_base,
                "max_hops": max_hops,
                "use_rewards": use_rewards
            }
        }
    )
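    # NOTE: the fourth argument (presumably the optimizer state) is None, so only the
    # model weights and the metadata dict above are written to the checkpoint.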
    print0(f"\n✅ Saved REFRAG model to {checkpoint_dir}")
print0("\n" + "="*80)
|
|
print0("REFRAG Training Complete!")
|
|
print0("="*80)
|
|
|
|
# Cleanup
|
|
wandb_run.finish()
|
|
compute_cleanup()
|
|
|