From bf067e2a664e7286b9590abe5c5eb8378210cca6 Mon Sep 17 00:00:00 2001
From: askerlee
Date: Wed, 14 Jan 2026 14:19:20 +0800
Subject: [PATCH] Add max_seq_len argument for gpt2

---
 scripts/base_eval.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index 672faec..0b6c888 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
     # Evaluate each task
     results = {}
     centered_results = {}
+
     for task in tasks:
         start_time = time.time()
         label = task['label']
@@ -130,14 +131,18 @@ class ModelWrapper:
         logits = outputs.logits
         return logits
 
-def load_hf_model(hf_path: str, device):
+def load_hf_model(hf_path: str, device, max_seq_len=None):
     print0(f"Loading model from: {hf_path}")
     # Load the model
     from transformers import AutoModelForCausalLM
     model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
     model.to(device)
     model.eval()
-    max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
+    # Special case for GPT-2 community model, which can handle 1024 tokens.
+    # If the argument is given, use that instead.
+    if max_seq_len is None and "openai-community/gpt2" in hf_path:
+        max_seq_len = 1024
+
     model = ModelWrapper(model, max_seq_len=max_seq_len)
     # Load the tokenizer
     if os.path.exists(hf_path):
@@ -151,6 +156,8 @@ def main():
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
+    parser.add_argument('--max_seq_len', type=int, default=None,
+                        help='Optional max sequence length for the model')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
     parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
     parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
@@ -166,7 +173,7 @@ def main():
         # atm assume that if a path is given, it's a huggingface model path
         hf_path = args.hf_path
         print0(f"Loading huggingface model from: {hf_path}")
-        model, tokenizer = load_hf_model(hf_path, device)
+        model, tokenizer = load_hf_model(hf_path, device, max_seq_len=args.max_seq_len)
         model_name = hf_path # just for logging
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else:
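
Note, not part of the patch: a minimal standalone sketch of the precedence rule this change introduces. An explicit max_seq_len always wins; otherwise the GPT-2 community model falls back to its 1024-token context window, and any other model keeps None (no cap). The resolve_max_seq_len helper below is hypothetical and written only to illustrate the behavior of the edited load_hf_model; it does not exist in scripts/base_eval.py.

# Hypothetical illustration of the max_seq_len precedence in load_hf_model.
def resolve_max_seq_len(hf_path, max_seq_len=None):
    # An explicit --max_seq_len always wins; otherwise GPT-2 gets its 1024-token
    # context window and any other model keeps None (no cap).
    if max_seq_len is None and "openai-community/gpt2" in hf_path:
        max_seq_len = 1024
    return max_seq_len

assert resolve_max_seq_len("openai-community/gpt2") == 1024
assert resolve_max_seq_len("openai-community/gpt2", max_seq_len=512) == 512
assert resolve_max_seq_len("some-org/some-model") is None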