mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-08 00:39:50 +00:00
220 lines
10 KiB
Python
220 lines
10 KiB
Python
"""
|
|
Lightweight RL stage for tool-use tuning on local ToolJSON datasets.
|
|
|
|
1 GPU:
|
|
python -m scripts.chat_tool_rl --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
|
|
|
|
8 GPUs:
|
|
torchrun --standalone --nproc_per_node=8 -m scripts.chat_tool_rl -- --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
|
|
"""
|
|
|
|
import argparse
|
|
import itertools
|
|
import os
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import wandb
|
|
|
|
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
|
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
|
from nanochat.engine import Engine
|
|
from nanochat.tools import DEFAULT_TOOL_SCHEMA
|
|
from tasks.tool_json import ToolJSON
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Command-line interface.
# Every flag is declared as a (flag, add_argument-kwargs) row so the whole CLI
# can be read as a table; rows are registered in order, which is also the order
# they appear in `--help`.
parser = argparse.ArgumentParser(description="RL tuning on local tool-use JSONL tasks")
_CLI_FLAGS = (
    # Logging / runtime
    ("--run", dict(type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")),
    ("--device-type", dict(type=str, default="", help="cuda|cpu|mps (empty = autodetect)")),
    # Which SFT checkpoint to start from
    ("--model-tag", dict(type=str, default=None, help="SFT model tag to load from")),
    ("--model-step", dict(type=int, default=None, help="SFT model step to load from")),
    # Data
    ("--train-jsonl", dict(type=str, required=True, help="Training ToolJSONL file")),
    ("--eval-jsonl", dict(type=str, default=None, help="Evaluation ToolJSONL file")),
    # Training horizon and batching
    ("--num-epochs", dict(type=int, default=1, help="Number of epochs over ToolJSON")),
    ("--device-batch-size", dict(type=int, default=8, help="Max batch size per forward pass")),
    ("--examples-per-step", dict(type=int, default=8, help="Examples per optimization step across all ranks")),
    ("--num-samples", dict(type=int, default=8, help="Number of samples per example")),
    # Sampling
    ("--max-new-tokens", dict(type=int, default=256, help="Max tokens to generate")),
    ("--temperature", dict(type=float, default=0.7, help="Sampling temperature")),
    ("--top-k", dict(type=int, default=50, help="Top-k sampling (0 disables)")),
    # Optimization
    ("--embedding-lr", dict(type=float, default=0.2, help="Embedding LR")),
    ("--unembedding-lr", dict(type=float, default=0.004, help="Unembedding LR")),
    ("--matrix-lr", dict(type=float, default=0.02, help="Matrix LR")),
    ("--weight-decay", dict(type=float, default=0.0, help="Weight decay")),
    ("--init-lr-frac", dict(type=float, default=0.05, help="Initial LR multiplier")),
    # Evaluation / checkpoint cadence
    ("--eval-every", dict(type=int, default=40, help="Evaluate every N steps")),
    ("--eval-examples", dict(type=int, default=64, help="Maximum eval examples")),
    ("--save-every", dict(type=int, default=40, help="Save every N steps")),
)
for _flag, _options in _CLI_FLAGS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
# Shallow snapshot of the parsed flags, kept for wandb and checkpoint metadata.
user_config = dict(vars(args))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Runtime initialisation: device / process layout, logging, model, and data.

# An empty --device-type means "pick one for me".
if args.device_type == "":
    device_type = autodetect_device_type()
else:
    device_type = args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0

# Real wandb logging happens only on rank 0 and only for a non-"dummy" run name.
use_dummy_wandb = args.run == "dummy" or not master_process
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-tool-rl", name=args.run, config=user_config)

# Start from the SFT checkpoint; the Engine wraps it for batched sampling.
model, tokenizer, meta = load_model(
    "sft",
    device,
    phase="eval",
    model_tag=args.model_tag,
    step=args.model_step,
)
engine = Engine(model, tokenizer)

# Without --eval-jsonl, evaluation falls back to the training file.
train_task = ToolJSON(filepath=args.train_jsonl)
eval_task = ToolJSON(filepath=args.eval_jsonl or args.train_jsonl)

# Each optimisation step consumes `examples_per_step` examples across all ranks;
# a trailing partial step in an epoch is dropped by the floor division.
steps_per_epoch = len(train_task) // args.examples_per_step
num_steps = steps_per_epoch * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
|
|
|
|
|
|
@torch.no_grad()
def get_batch():
    """Endlessly yield one rollout batch per training example owned by this rank.

    Examples are striped across ranks (``ddp_rank``, ``ddp_rank + world_size``,
    ...) and cycled forever. For each example, ``args.num_samples`` completions
    are sampled in chunks of ``args.device_batch_size``, scored with
    ``train_task.reward``, and padded to a common length.

    Yields:
        tuple: ``(generated_token_sequences, inputs, targets, rewards_tensor,
        advantages)`` where

        * ``generated_token_sequences`` is the unpadded list of sampled token
          lists (prompt included),
        * ``inputs`` / ``targets`` are the padded ids shifted by one position
          (``ids[:, :-1]`` / ``ids[:, 1:]``), with ``targets`` set to ``-1``
          wherever the generation mask is 0,
        * ``rewards_tensor`` holds one float reward per sample,
        * ``advantages`` is ``rewards_tensor`` minus its mean.

    Raises:
        ValueError: if ``args.num_samples`` is not a multiple of
            ``args.device_batch_size``.

    Note:
        The sampling seed is derived from the module-level ``step`` variable
        set by the training loop, so this generator must only be advanced from
        inside that loop.
    """
    # Loop-invariant configuration: validate and compute it once, before any
    # tokenisation or generation work, instead of once per example.
    if args.num_samples % args.device_batch_size != 0:
        raise ValueError("num_samples must be divisible by device_batch_size")
    num_sampling_steps = args.num_samples // args.device_batch_size

    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
    for example_idx in itertools.cycle(rank_indices):
        conversation = train_task[example_idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)

        # Sampling runs in eval mode; the training loop switches back to train().
        model.eval()
        generated_token_sequences = []
        masks = []
        for sampling_step in range(num_sampling_steps):
            # Distinct seed per (step, example, chunk), masked to a non-negative
            # 31-bit integer. `step` is the training loop's global counter.
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
            generated_batch, mask_batch = engine.generate_batch(
                tokens,
                num_samples=args.device_batch_size,
                max_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                seed=seed,
            )
            generated_token_sequences.extend(generated_batch)
            masks.extend(mask_batch)

        # Score only the completion, i.e. everything after the rendered prompt.
        rewards = [
            train_task.reward(conversation, tokenizer.decode(sample_tokens[prefix_length:]))
            for sample_tokens in generated_token_sequences
        ]

        # Right-pad to a rectangle. Padding carries mask 0, so it can never
        # become a training target below.
        max_length = max(len(seq) for seq in generated_token_sequences)
        padded_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
        padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
        ids = torch.tensor(padded_sequences, dtype=torch.long, device=device)
        mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)

        # Next-token setup: position t predicts token t + 1.
        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone()
        # -1 marks positions excluded from the objective (the training loop
        # counts `targets >= 0` as valid). NOTE(review): presumably -1 is also
        # the model loss's ignore index -- confirm.
        targets[mask_ids[:, 1:] == 0] = -1

        rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device)
        # Mean-centred rewards across this example's samples.
        advantages = rewards_tensor - rewards_tensor.mean()
        yield generated_token_sequences, inputs, targets, rewards_tensor, advantages
|
|
|
|
|
|
@torch.no_grad()
def run_tool_eval(task, max_examples):
    """Evaluate the current model on the first ``max_examples`` examples of ``task``.

    Examples are striped across ranks; each one gets a single completion
    sampled at temperature 0.0. Per-rank sums are combined with an all-reduce
    when running distributed.

    Args:
        task: dataset supporting ``len``, indexing, ``reward(conversation, text)``
            and ``evaluate(conversation, text)``.
        max_examples: upper bound on the number of examples evaluated (capped
            at ``len(task)``).

    Returns:
        tuple: ``(mean_reward, pass_rate)`` over all evaluated examples; both
        are 0 when nothing was evaluated.
    """
    limit = min(max_examples, len(task))
    local_reward = 0.0
    local_passed = 0
    local_count = 0
    for example_idx in range(ddp_rank, limit, ddp_world_size):
        conversation = task[example_idx]
        prompt_tokens = tokenizer.render_for_completion(conversation)
        sequences, _ = engine.generate_batch(
            prompt_tokens,
            num_samples=1,
            max_tokens=args.max_new_tokens,
            temperature=0.0,
            top_k=args.top_k,
        )
        # Strip the prompt so only the completion is scored.
        completion = tokenizer.decode(sequences[0][len(prompt_tokens):])
        local_reward += task.reward(conversation, completion)
        local_passed += task.evaluate(conversation, completion)
        local_count += 1

    # One-element tensors so the per-rank sums can be all-reduced.
    reward_tensor = torch.tensor([local_reward], dtype=torch.float, device=device)
    passed_tensor = torch.tensor([local_passed], dtype=torch.long, device=device)
    total_tensor = torch.tensor([local_count], dtype=torch.long, device=device)
    if ddp:
        for stat in (reward_tensor, passed_tensor, total_tensor):
            dist.all_reduce(stat, op=dist.ReduceOp.SUM)

    # Guard the division when no example was evaluated.
    denominator = max(total_tensor.item(), 1)
    mean_reward = reward_tensor.item() / denominator
    pass_rate = passed_tensor.item() / denominator
    return mean_reward, pass_rate
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Optimizer: per-parameter-family learning rates come straight from the CLI.
_optimizer_settings = {
    "unembedding_lr": args.unembedding_lr,
    "embedding_lr": args.embedding_lr,
    "matrix_lr": args.matrix_lr,
    "weight_decay": args.weight_decay,
}
optimizer = model.setup_optimizer(**_optimizer_settings)

# Scale every group's LR by init_lr_frac and remember the result as
# "initial_lr"; the training loop multiplies that base by the decay schedule.
for param_group in optimizer.param_groups:
    scaled_lr = param_group["lr"] * args.init_lr_frac
    param_group["lr"] = scaled_lr
    param_group["initial_lr"] = scaled_lr
|
|
|
|
|
|
def get_lr_multiplier(it):
    """Return the linear-decay LR multiplier for optimisation step ``it``.

    The multiplier is 1.0 at ``it == 0`` and falls by ``1 / num_steps`` per
    step. ``num_steps`` is clamped to at least 1 to avoid dividing by zero.
    """
    total_steps = max(num_steps, 1)
    fraction_done = it / total_steps
    return 1.0 - fraction_done
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Training loop.

# examples_per_step is a global count, so every rank must take an equal share.
# Raised explicitly (not asserted) so the check survives `python -O`.
if args.examples_per_step % ddp_world_size != 0:
    raise ValueError("examples_per_step must be divisible by number of ranks")
examples_per_rank = args.examples_per_step // ddp_world_size
batch_iterator = get_batch()

# NOTE: the loop variable must stay named `step`; get_batch() reads this
# module-level name to derive its sampling seeds.
for step in range(num_steps):
    # Periodic evaluation; step 0 gives a baseline before any update.
    if step % args.eval_every == 0:
        model.eval()
        mean_reward, pass_rate = run_tool_eval(eval_task, args.eval_examples)
        print0(f"Step {step} | tool_eval_reward={mean_reward:.4f} | tool_eval_pass={pass_rate:.4f}")
        wandb_run.log({"step": step, "tool_eval_reward": mean_reward, "tool_eval_pass": pass_rate})

    # Accumulate gradients over this rank's share of examples.
    rewards_list = []
    for example_step in range(examples_per_rank):
        _, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
        # get_batch() leaves the model in eval mode after sampling.
        model.train()
        # Internal invariant: get_batch() validates num_samples against device_batch_size.
        assert inputs_all.size(0) % args.device_batch_size == 0
        num_passes = inputs_all.size(0) // args.device_batch_size
        for pass_idx in range(num_passes):
            b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
            inputs = inputs_all[b0:b1]
            targets = targets_all[b0:b1]
            rewards = rewards_all[b0:b1]
            advantages = advantages_all[b0:b1]
            # NOTE(review): presumably the model returns per-token NLL here, so
            # the negation is a per-token log-probability, with masked (-1)
            # targets contributing zero -- confirm against the model's forward.
            logp = -model(inputs, targets, loss_reduction="none").view_as(inputs)
            # Policy-gradient objective: log-probs weighted by per-sample advantage.
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            # Normalise by valid tokens and by the number of accumulated
            # backward passes in this optimisation step.
            num_valid = (targets >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            loss = -pg_obj  # maximise the objective by minimising its negation
            loss.backward()
            print0(
                f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} "
                f"| loss={loss.item():.6f} | reward={rewards.mean().item():.4f}"
            )
        rewards_list.append(rewards_all.mean().item())

    # Linear LR decay applied on top of each group's scaled base LR.
    lrm = get_lr_multiplier(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
    # NOTE(review): there is no explicit gradient all-reduce in this loop;
    # presumably the optimizer from model.setup_optimizer() synchronises
    # gradients across ranks -- confirm.
    optimizer.step()
    model.zero_grad(set_to_none=True)

    # Average the training reward over all ranks so the logged value covers the
    # whole step rather than only rank 0's shard (mirrors run_tool_eval). Every
    # rank sees the same number of examples, so a mean of per-rank means is the
    # global mean. SUM + divide is used instead of ReduceOp.AVG for portability
    # across backends.
    train_mean_reward = sum(rewards_list) / max(len(rewards_list), 1)
    if ddp:
        reward_stat = torch.tensor([train_mean_reward], dtype=torch.float, device=device)
        dist.all_reduce(reward_stat, op=dist.ReduceOp.SUM)
        train_mean_reward = reward_stat.item() / ddp_world_size
    wandb_run.log({"step": step, "lrm": lrm, "mean_reward": train_mean_reward})

    # Checkpoint from rank 0 only: periodically (never at step 0) and always
    # after the final step.
    is_periodic_save = step > 0 and step % args.save_every == 0
    is_last_step = step == num_steps - 1
    if master_process and (is_periodic_save or is_last_step):
        base_dir = get_base_dir()
        output_dirname = args.model_tag if args.model_tag else f"d{model.config.n_layer}"
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            model.state_dict(),
            None,  # optimizer state is not saved
            {
                "step": step,
                "stage": "tool_rl",
                "model_config": model.config.__dict__,
                "user_config": user_config,
                "tool_schema": DEFAULT_TOOL_SCHEMA,
                "source_hf_repo": meta.get("source_hf_repo"),
                "train_jsonl": args.train_jsonl,
                "eval_jsonl": args.eval_jsonl,
            },
        )
        print0(f"Saved tool RL checkpoint to {checkpoint_dir}")

wandb_run.finish()
compute_cleanup()
|