diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 1a09962..ad557b9 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default import argparse import os import itertools -import re import wandb import torch import torch.distributed as dist @@ -174,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine, tokens = tokenizer.render_for_completion(conversation) prefix_length = len(tokens) # Generate k samples using batched generation inside the Engine - assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... + assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not... generated_token_sequences, masks = engine.generate_batch( tokens, num_samples=num_samples,