Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-24 20:34:23 +00:00)
When eval language_modeling tasks, be case insensitive to answers

parent: e64aa82620
commit: 8cfa0451f4
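The hunks captured below all belong to a simplification of the HuggingFace model-loading path; the case-insensitive answer matching named in the commit title sits in a part of the diff this page does not show. As a rough illustration only, a hypothetical helper for that kind of matching normalizes both strings before comparing:

    def answers_match(predicted: str, target: str) -> bool:
        # Hypothetical sketch, not the repository's code: compare answers
        # case-insensitively, ignoring surrounding whitespace.
        return predicted.strip().lower() == target.strip().lower()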
@@ -71,6 +71,3 @@ conflicts = [
         { extra = "gpu" },
     ],
 ]
-
-[tool.setuptools]
-packages = ["nanochat"]
@@ -75,7 +75,6 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
     # Evaluate each task
     results = {}
     centered_results = {}
-
     for task in tasks:
         start_time = time.time()
         label = task['label']
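For orientation, evaluate_model iterates over the eval tasks, timing each one and accumulating raw accuracies in results plus baseline-adjusted scores in centered_results. A minimal sketch of that accumulation pattern, where the task list, the per-task evaluation call, and the exact centering formula are stand-ins rather than the repository's code:

    import time

    def evaluate_tasks(tasks, run_task):
        results = {}
        centered_results = {}
        for task in tasks:
            start_time = time.time()
            label = task['label']
            acc = run_task(task)  # stand-in for the actual per-task evaluation
            results[label] = acc
            # Assumed centering: rescale so 0.0 is chance level and 1.0 is perfect
            baseline = task.get('baseline', 0.0)
            centered_results[label] = (acc - baseline) / (1.0 - baseline)
            print(f"{label}: acc={acc:.4f} ({time.time() - start_time:.1f}s)")
        return results, centered_results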
@@ -131,24 +130,17 @@ class ModelWrapper:
         logits = outputs.logits
         return logits
 
-def load_hf_model(hf_path: str, device, max_seq_len=None):
+def load_hf_model(hf_path: str, device):
     print0(f"Loading model from: {hf_path}")
     # Load the model
     from transformers import AutoModelForCausalLM
-    model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(hf_path)
     model.to(device)
     model.eval()
-    # Special case for GPT-2 community model, which can handle 1024 tokens.
-    # If the argument is given, use that instead.
-    if max_seq_len is None and "openai-community/gpt2" in hf_path:
-        max_seq_len = 1024
-
+    max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
     model = ModelWrapper(model, max_seq_len=max_seq_len)
     # Load the tokenizer
-    if os.path.exists(hf_path):
-        tokenizer = HuggingFaceTokenizer.from_directory(hf_path)
-    else:
-        tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
+    tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
     return model, tokenizer
 
 # -----------------------------------------------------------------------------
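The context lines above show only the tail of ModelWrapper's forward pass, which returns outputs.logits. A plausible sketch of such a wrapper follows; the max_seq_len cropping behavior is an assumption for illustration, not the repository's exact logic:

    import torch

    class ModelWrapper:
        """Sketch: adapt a HuggingFace causal LM to a plain "ids in, logits out" call."""
        def __init__(self, model, max_seq_len=None):
            self.model = model
            self.max_seq_len = max_seq_len  # e.g. 1024 for GPT-2, per the hunk above

        @torch.no_grad()
        def __call__(self, input_ids):
            if self.max_seq_len is not None:
                input_ids = input_ids[:, -self.max_seq_len:]  # assumed: crop to the window
            outputs = self.model(input_ids)
            return outputs.logits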
@@ -156,8 +148,6 @@ def main():
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
-    parser.add_argument('--max_seq_len', type=int, default=None,
-                        help='Optional max sequence length for the model')
     parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
     parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
     parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
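A detail that makes this hunk read oddly at first: argparse converts hyphens in long option names to underscores on the parsed namespace, so --hf-path comes back as args.hf_path, and the removed flag could equally have been spelled --max-seq-len. A self-contained illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--hf-path', type=str, default=None)
    args = parser.parse_args(['--hf-path', 'openai-community/gpt2'])
    print(args.hf_path)  # the dash in --hf-path becomes an underscore attribute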
@@ -173,7 +163,7 @@ def main():
         # atm assume that if a path is given, it's a huggingface model path
         hf_path = args.hf_path
         print0(f"Loading huggingface model from: {hf_path}")
-        model, tokenizer = load_hf_model(hf_path, device, max_seq_len=args.max_seq_len)
+        model, tokenizer = load_hf_model(hf_path, device)
         model_name = hf_path # just for logging
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else: