# nanochat/scripts/chat_tool_rl.py
"""
Lightweight RL stage for tool-use tuning on local ToolJSON datasets.
1 GPU:
python -m scripts.chat_tool_rl --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
8 GPUs:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_tool_rl -- --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
"""
import argparse
import itertools
import os
import torch
import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from tasks.tool_json import ToolJSON
parser = argparse.ArgumentParser(description="RL tuning on local tool-use JSONL tasks")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--model-tag", type=str, default=None, help="SFT model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="SFT model step to load from")
parser.add_argument("--train-jsonl", type=str, required=True, help="Training ToolJSONL file")
parser.add_argument("--eval-jsonl", type=str, default=None, help="Evaluation ToolJSONL file")
parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs over ToolJSON")
parser.add_argument("--device-batch-size", type=int, default=8, help="Max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=8, help="Examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=8, help="Number of samples per example")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate")
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling (0 disables)")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="Embedding LR")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="Unembedding LR")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="Matrix LR")
parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="Initial LR multiplier")
parser.add_argument("--eval-every", type=int, default=40, help="Evaluate every N steps")
parser.add_argument("--eval-examples", type=int, default=64, help="Maximum eval examples")
parser.add_argument("--save-every", type=int, default=40, help="Save every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-tool-rl", name=args.run, config=user_config)
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer)
train_task = ToolJSON(filepath=args.train_jsonl)
eval_task = ToolJSON(filepath=args.eval_jsonl or args.train_jsonl)
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
def get_batch():
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    # Each rank cycles over a disjoint slice of the training examples.
    rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
    for example_idx in itertools.cycle(rank_indices):
        conversation = train_task[example_idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)
        model.eval()
        generated_token_sequences = []
        masks = []
        if args.num_samples % args.device_batch_size != 0:
            raise ValueError("num_samples must be divisible by device_batch_size")
        num_sampling_steps = args.num_samples // args.device_batch_size
        for sampling_step in range(num_sampling_steps):
            # `step` is the global training-loop variable; the generator is only
            # advanced from inside that loop, so it is always bound here.
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
            generated_batch, mask_batch = engine.generate_batch(
                tokens,
                num_samples=args.device_batch_size,
                max_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                seed=seed,
            )
            generated_token_sequences.extend(generated_batch)
            masks.extend(mask_batch)
        # Score each sampled completion with the task reward.
        rewards = []
        for sample_tokens in generated_token_sequences:
            generated_text = tokenizer.decode(sample_tokens[prefix_length:])
            rewards.append(train_task.reward(conversation, generated_text))
        # Right-pad all samples to a common length; padded positions are masked
        # out of the loss below via targets of -1.
        max_length = max(len(seq) for seq in generated_token_sequences)
        padded_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
        padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
        ids = torch.tensor(padded_sequences, dtype=torch.long, device=device)
        mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone()
        targets[mask_ids[:, 1:] == 0] = -1
        rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device)
        # Mean-reward baseline: advantages sum to zero across the sample group.
        advantages = rewards_tensor - rewards_tensor.mean()
        yield generated_token_sequences, inputs, targets, rewards_tensor, advantages
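# Worked example of the baseline above: with num_samples=4 and rewards
# [1.0, 0.0, 0.0, 1.0], the mean is 0.5 and the advantages are
# [+0.5, -0.5, -0.5, +0.5]; a group with uniform rewards gets all-zero
# advantages and contributes no gradient for that example.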
@torch.no_grad()
def run_tool_eval(task, max_examples):
    """Greedy-decode each example once; return (mean reward, pass rate)."""
    max_examples = min(max_examples, len(task))
    total = 0
    reward_sum = 0.0
    passed = 0
    for idx in range(ddp_rank, max_examples, ddp_world_size):
        conversation = task[idx]
        tokens = tokenizer.render_for_completion(conversation)
        generated_sequences, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=args.max_new_tokens, temperature=0.0, top_k=args.top_k)
        generated_text = tokenizer.decode(generated_sequences[0][len(tokens):])
        reward_sum += task.reward(conversation, generated_text)
        passed += task.evaluate(conversation, generated_text)
        total += 1
    # Sum the per-rank partial counts before averaging.
    reward_tensor = torch.tensor([reward_sum], dtype=torch.float, device=device)
    passed_tensor = torch.tensor([passed], dtype=torch.long, device=device)
    total_tensor = torch.tensor([total], dtype=torch.long, device=device)
    if ddp:
        dist.all_reduce(reward_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
    total = max(total_tensor.item(), 1)
    return reward_tensor.item() / total, passed_tensor.item() / total
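# e.g. with the default eval_examples=64 on 8 ranks, each rank greedily decodes
# 8 examples and the all-reduce averages reward and pass rate over all 64.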
optimizer = model.setup_optimizer(
    unembedding_lr=args.unembedding_lr,
    embedding_lr=args.embedding_lr,
    matrix_lr=args.matrix_lr,
    weight_decay=args.weight_decay,
)
# Scale all learning rates down at init; get_lr_multiplier then decays them
# linearly toward zero over training.
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]

def get_lr_multiplier(it):
    return 1.0 - it / max(num_steps, 1)
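# e.g. with num_steps=40: multiplier 1.0 at step 0, 0.5 at step 20, and
# 0.025 at the final step 39 (linear decay toward zero).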
assert args.examples_per_step % ddp_world_size == 0, "examples_per_step must be divisible by number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size
batch_iterator = get_batch()
for step in range(num_steps):
    # Periodic evaluation on the held-out (or training) ToolJSON set.
    if step % args.eval_every == 0:
        model.eval()
        mean_reward, pass_rate = run_tool_eval(eval_task, args.eval_examples)
        print0(f"Step {step} | tool_eval_reward={mean_reward:.4f} | tool_eval_pass={pass_rate:.4f}")
        wandb_run.log({"step": step, "tool_eval_reward": mean_reward, "tool_eval_pass": pass_rate})
    rewards_list = []
    for example_step in range(examples_per_rank):
        _, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
        model.train()
        # The num_samples rollouts may exceed device_batch_size, so accumulate
        # gradients over multiple forward/backward passes.
        assert inputs_all.size(0) % args.device_batch_size == 0
        num_passes = inputs_all.size(0) // args.device_batch_size
        for pass_idx in range(num_passes):
            b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
            inputs = inputs_all[b0:b1]
            targets = targets_all[b0:b1]
            rewards = rewards_all[b0:b1]
            advantages = advantages_all[b0:b1]
            # Per-token cross-entropy is -log p(target), so negate to get log probs.
            logp = -model(inputs, targets, loss_reduction="none").view_as(inputs)
            # REINFORCE objective: advantage-weighted log-likelihood, averaged
            # over valid target tokens and all accumulation passes this step.
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            num_valid = (targets >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            loss = -pg_obj
            loss.backward()
            print0(
                f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} "
                f"| loss={loss.item():.6f} | reward={rewards.mean().item():.4f}"
            )
        rewards_list.append(rewards_all.mean().item())
    # Linearly decay the learning rate, take the optimizer step, clear grads.
    lrm = get_lr_multiplier(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
    optimizer.step()
    model.zero_grad(set_to_none=True)
    wandb_run.log({"step": step, "lrm": lrm, "mean_reward": sum(rewards_list) / max(len(rewards_list), 1)})
    # Save a checkpoint periodically and at the final step (rank 0 only).
    if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
        base_dir = get_base_dir()
        output_dirname = args.model_tag if args.model_tag else f"d{model.config.n_layer}"
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            model.state_dict(),
            None,  # no optimizer state is saved
            {
                "step": step,
                "stage": "tool_rl",
                "model_config": model.config.__dict__,
                "user_config": user_config,
                "tool_schema": DEFAULT_TOOL_SCHEMA,
                "source_hf_repo": meta.get("source_hf_repo"),
                "train_jsonl": args.train_jsonl,
                "eval_jsonl": args.eval_jsonl,
            },
        )
        print0(f"Saved tool RL checkpoint to {checkpoint_dir}")
wandb_run.finish()
compute_cleanup()