Add max_seq_len argument for gpt2

This commit is contained in:
askerlee 2026-01-14 14:19:20 +08:00
parent d6829284c4
commit bf067e2a66

View File

@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
# Evaluate each task
results = {}
centered_results = {}
for task in tasks:
start_time = time.time()
label = task['label']
@ -130,14 +131,18 @@ class ModelWrapper:
logits = outputs.logits
return logits
def load_hf_model(hf_path: str, device):
def load_hf_model(hf_path: str, device, max_seq_len=None):
print0(f"Loading model from: {hf_path}")
# Load the model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
model.to(device)
model.eval()
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
# Special case for GPT-2 community model, which can handle 1024 tokens.
# If the argument is given, use that instead.
if max_seq_len is None and "openai-community/gpt2" in hf_path:
max_seq_len = 1024
model = ModelWrapper(model, max_seq_len=max_seq_len)
# Load the tokenizer
if os.path.exists(hf_path):
@ -151,6 +156,8 @@ def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max_seq_len', type=int, default=None,
help='Optional max sequence length for the model')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
@ -166,7 +173,7 @@ def main():
# atm assume that if a path is given, it's a huggingface model path
hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model, tokenizer = load_hf_model(hf_path, device, max_seq_len=args.max_seq_len)
model_name = hf_path # just for logging
model_slug = hf_path.replace("/", "-") # for the output csv file
else: