"""
|
|
RAG Fine-tuning Script for Mamba and Hybrid Models
|
|
|
|
Fine-tune a pretrained model with retrieval-augmented generation.
|
|
Optimized for Mamba and hybrid (Transformer+Mamba) architectures.
|
|
|
|
Usage:
|
|
# Single GPU
|
|
python -m scripts.rag_finetune --knowledge_base data/kb
|
|
|
|
# Multi-GPU
|
|
torchrun --standalone --nproc_per_node=8 -m scripts.rag_finetune \
|
|
--knowledge_base data/kb \
|
|
--source mid \
|
|
--retriever_type dense
|
|
|
|
Only works with Mamba or hybrid models (block_pattern must contain "M").
|
|
"""
|
|
|
|
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.engine import Engine
from nanochat.retrieval import RetrievalManager
from nanochat.rag_utils import render_rag_conversation_for_tokenizer

from tasks.rag_task import RAGTask, create_rag_task
from tasks.smoltalk import SmolTalk
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K

# -----------------------------------------------------------------------------
# RAG Fine-tuning Hyperparameters
run = "dummy" # wandb run name
# Model options
source = "mid" # base|mid - which checkpoint to load
model_tag = None # model tag to load
step = None # step to load
# RAG options
knowledge_base = None # REQUIRED: path to knowledge base
retriever_type = "simple" # simple|dense
top_k = 5 # number of documents to retrieve
max_doc_length = 500 # max characters per document in prompt
insert_position = "before_user" # where to insert retrieval
# Task options
base_tasks = "SmolTalk" # comma-separated: SmolTalk,MMLU,ARC-Easy,GSM8K
task_samples = 10000 # samples per task (-1 = all)
# Training options
dtype = "bfloat16"
device_batch_size = 4 # smaller due to longer contexts with RAG
num_epochs = 1
max_iterations = -1 # -1 = no cap on the number of iterations
target_examples_per_step = 32
# Optimization
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02 # start with lower LR for stability
# Evaluation
eval_every = 100
eval_steps = 50
# Allow CLI overrides
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read())
user_config = {k: globals()[k] for k in config_keys}
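# A hedged note on the mechanism above: configurator.py (nanochat's small CLI
# override helper) assigns command-line overrides onto the module-level
# globals listed in config_keys, so a flag like --knowledge_base data/kb (as
# in the usage examples in the docstring) replaces the None default before
# user_config snapshots the effective values for wandb logging.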

# -----------------------------------------------------------------------------

# Validate
if knowledge_base is None:
    raise ValueError("--knowledge_base is required for RAG fine-tuning")
if not os.path.exists(knowledge_base):
    raise FileNotFoundError(f"Knowledge base not found: {knowledge_base}")

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype_torch = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype_torch)

# WandB logging
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(
    project="nanochat-rag",
    name=run,
    config=user_config,
    save_code=True,
)

# Load model and tokenizer
print0(f"Loading model from {source} checkpoint...")
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)

# Validate that the model has Mamba blocks
block_pattern = model.config.block_pattern
if block_pattern is None or "M" not in "".join(block_pattern):
    raise ValueError(
        "RAG fine-tuning requires Mamba or hybrid models. "
        f"Current block_pattern: {block_pattern}. "
        "Please use a model with Mamba blocks (contains 'M')."
    )

print0(f"✓ Model has block pattern: {block_pattern}")
print0(f"  Transformer blocks: {block_pattern.count('T')}")
print0(f"  Mamba blocks: {block_pattern.count('M')}")

orig_model = model
# Don't compile for RAG: retrieved contexts give highly variable sequence
# lengths, which would trigger frequent torch.compile recompilations.
# model = torch.compile(model, dynamic=True)

# Initialize retrieval manager
print0(f"Loading knowledge base from {knowledge_base}...")
print0(f"Using retriever type: {retriever_type}")
retrieval_manager = RetrievalManager(
    retriever_type=retriever_type,
    knowledge_base_path=knowledge_base,
)
print0("✓ Knowledge base loaded")

# -----------------------------------------------------------------------------
# Create RAG tasks

print0(f"Creating RAG tasks from base tasks: {base_tasks}")
task_list = base_tasks.split(",")
train_rag_tasks = []
val_rag_tasks = []

for task_name in task_list:
    task_name = task_name.strip()
    print0(f"  Creating RAG wrapper for {task_name}...")

    # Create training task
    try:
        train_task = create_rag_task(
            task_name=task_name,
            split="train",
            knowledge_base_path=knowledge_base,
            retriever_type=retriever_type,
            top_k=top_k,
            stop=task_samples if task_samples > 0 else None,
        )
        train_rag_tasks.append(train_task)
        print0(f"    Train: {len(train_task)} examples")
    except Exception as e:
        print0(f"    Warning: Could not create train task for {task_name}: {e}")

    # Create validation task (SmolTalk names its held-out split "test", the other tasks "val")
    try:
        val_task = create_rag_task(
            task_name=task_name,
            split="test" if task_name == "SmolTalk" else "val",
            knowledge_base_path=knowledge_base,
            retriever_type=retriever_type,
            top_k=top_k,
            stop=1000, # limit validation size
        )
        val_rag_tasks.append(val_task)
        print0(f"    Val: {len(val_task)} examples")
    except Exception as e:
        print0(f"    Warning: Could not create val task for {task_name}: {e}")

# Combine tasks (fall back to the training data if no val task could be built)
from tasks.common import TaskMixture
train_ds = TaskMixture(train_rag_tasks) if len(train_rag_tasks) > 1 else train_rag_tasks[0]
val_ds = TaskMixture(val_rag_tasks) if len(val_rag_tasks) > 1 else (val_rag_tasks[0] if val_rag_tasks else train_rag_tasks[0])

print0(f"\n✓ Total training examples: {len(train_ds)}")
print0(f"✓ Total validation examples: {len(val_ds)}")

# -----------------------------------------------------------------------------
# DataLoader for RAG

def rag_data_generator(dataset, batch_size):
    """Data generator for RAG training with retrieved documents."""
    # The padding id is arbitrary because padded positions get target -1
    # (ignored by the loss); <|assistant_end|> is just a convenient special token.
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")

    def collate_and_yield(batch):
        """Collate RAG conversations into a padded (inputs, targets) batch."""
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1 # -1 for the shift below
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)

        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            # Shift by one: predict ids[1:] from ids[:-1]
            inputs[i, :n-1] = ids_tensor[:-1]

            row_targets = ids_tensor[1:]
            # Positions with mask == 0 (non-assistant tokens) don't contribute to the loss
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1
            targets[i, :n-1] = row_targets

        inputs = inputs.to(device)
        targets = targets.to(device)
        return inputs, targets

    batch = []
    while True: # loop over the dataset for multiple epochs
        for i in range(ddp_rank, len(dataset), ddp_world_size):
            # Get RAG-augmented conversation
            conversation = dataset[i]

            # Render to tokens
            ids, mask = tokenizer.render_conversation(conversation)

            # Truncate if too long (RAG contexts can be long)
            max_len = 4096 # allow longer contexts for Mamba
            if len(ids) > max_len:
                ids = ids[:max_len]
                mask = mask[:max_len]

            batch.append((ids, mask))

            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []
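
# Shape sketch with made-up numbers: for a batch whose longest example has
# ids = [7, 8, 9], ncols = 2 and that row contributes inputs [7, 8] and
# targets [8, 9]; shorter rows are right-padded with pad_token_id in inputs
# and -1 in targets, so the padding is ignored by the loss.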

# Calculate gradient accumulation
examples_per_step = device_batch_size * ddp_world_size
print0(f"\nTraining configuration:")
print0(f"  Device batch size: {device_batch_size}")
print0(f"  Examples per step: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "target_examples_per_step must be divisible by device_batch_size * world_size"
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"  Gradient accumulation steps: {grad_accum_steps}")

# Calculate iterations
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
if max_iterations >= 0 and num_iterations > max_iterations:
    num_iterations = max_iterations
print0(f"  Number of iterations: {num_iterations}")

train_loader = rag_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: rag_data_generator(val_ds, batch_size=device_batch_size)

# -----------------------------------------------------------------------------
# Initialize optimizer

optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)

# Scale down the initial LR and remember it for the schedule below
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["initial_lr"] = group["lr"]

# -----------------------------------------------------------------------------
# Training loop

print0("\n" + "="*80)
print0("Starting RAG Fine-Tuning")
print0("="*80 + "\n")

def get_lr_multiplier(it):
    """Linear decay to 0 over the course of training."""
    return 1.0 - it / num_iterations
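# e.g. with num_iterations = 1000 the multiplier is 1.0 at step 0, 0.5 at
# step 500, and 0.001 at the final step 999; it approaches but never reaches
# 0 inside the loop.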

step = 0
train_iter = iter(train_loader)
best_val_loss = float('inf')

for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # Validation
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []

        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)

        val_loss = torch.stack(losses).mean()
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss = val_loss.item()

        if val_loss < best_val_loss:
            best_val_loss = val_loss

        print0(f"Step {step:05d} | Val loss: {val_loss:.6f} | Best: {best_val_loss:.6f}")
        wandb_run.log({"step": step, "val_loss": val_loss, "best_val_loss": best_val_loss})
        model.train()

    if last_step:
        break
    # Training step: accumulate gradients over grad_accum_steps micro-batches
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets = next(train_iter)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        train_loss = loss.detach() # the logged loss comes from the last micro-batch
        loss = loss / grad_accum_steps # scale so accumulated grads average out
        loss.backward()

    # Update: set the scheduled LR, then step all optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm

    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

    # Logging
    if step % 10 == 0:
        train_loss_item = train_loss.item()
        print0(f"Step {step:05d}/{num_iterations:05d} | Train loss: {train_loss_item:.6f} | LR mult: {lrm:.4f}")
        wandb_run.log({"step": step, "train_loss": train_loss_item, "lrm": lrm})

# Save final model
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag_out = f"d{depth}_rag"
    checkpoint_dir = os.path.join(base_dir, "rag_checkpoints", model_tag_out)

    model_config_kwargs = {
        k: v for k, v in model.config.__dict__.items()
        if not k.startswith('_')
    }

    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),
        None, # no optimizer state
        {
            "step": step,
            "val_loss": val_loss,
            "best_val_loss": best_val_loss,
            "model_config": model_config_kwargs,
            "rag_config": {
                "knowledge_base": knowledge_base,
                "retriever_type": retriever_type,
                "top_k": top_k,
                "base_tasks": base_tasks,
            },
        },
    )
    print0(f"\n✅ Saved RAG model to {checkpoint_dir}")

# Log to report
from nanochat.report import get_report
get_report().log(section="RAG Fine-Tuning", data=[
    user_config,
    {
        "Training examples": len(train_ds),
        "Number of iterations": num_iterations,
        "Final val loss": val_loss,
        "Best val loss": best_val_loss,
        "Knowledge base": knowledge_base,
        "Retriever type": retriever_type,
        "Top-k documents": top_k,
    },
])

print0("\n" + "="*80)
print0("RAG Fine-Tuning Complete!")
print0("="*80)

# Cleanup
wandb_run.finish()
compute_cleanup()