mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-29 16:15:13 +00:00
Add max_seq_len argument for gpt2
This commit is contained in:
parent
d6829284c4
commit
bf067e2a66
|
|
@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
# Evaluate each task
|
||||
results = {}
|
||||
centered_results = {}
|
||||
|
||||
for task in tasks:
|
||||
start_time = time.time()
|
||||
label = task['label']
|
||||
|
|
@ -130,14 +131,18 @@ class ModelWrapper:
|
|||
logits = outputs.logits
|
||||
return logits
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
def load_hf_model(hf_path: str, device, max_seq_len=None):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
# Load the model
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
# Special case for GPT-2 community model, which can handle 1024 tokens.
|
||||
# If the argument is given, use that instead.
|
||||
if max_seq_len is None and "openai-community/gpt2" in hf_path:
|
||||
max_seq_len = 1024
|
||||
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
# Load the tokenizer
|
||||
if os.path.exists(hf_path):
|
||||
|
|
@ -151,6 +156,8 @@ def main():
|
|||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max_seq_len', type=int, default=None,
|
||||
help='Optional max sequence length for the model')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
|
||||
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
|
||||
|
|
@ -166,7 +173,7 @@ def main():
|
|||
# atm assume that if a path is given, it's a huggingface model path
|
||||
hf_path = args.hf_path
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model, tokenizer = load_hf_model(hf_path, device, max_seq_len=args.max_seq_len)
|
||||
model_name = hf_path # just for logging
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user