mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
delete pandas dependency in base_eval; use csv instead
This commit is contained in:
parent
ad39db5a23
commit
7d2c4a3d95
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Evlauate the CORE metric for a given model.
|
Evaluate the CORE metric for a given model.
|
||||||
|
|
||||||
Run on a single GPU:
|
Run on a single GPU:
|
||||||
python base_eval.py
|
python base_eval.py
|
||||||
|
|
@ -10,14 +10,13 @@ torchrun --nproc_per_node=8 base_eval.py
|
||||||
The script will print the CORE metric to the console.
|
The script will print the CORE metric to the console.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import sys
|
import csv
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import yaml
|
import yaml
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
||||||
|
|
@ -26,13 +25,12 @@ from nanochat.checkpoint_manager import load_model
|
||||||
from nanochat.core_eval import evaluate_task
|
from nanochat.core_eval import evaluate_task
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# nanoChat specific function dealing with I/O etc.
|
# nanochat specific function dealing with I/O etc.
|
||||||
|
|
||||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||||
"""
|
"""
|
||||||
Evaluate a base model on the CORE benchmark.
|
Evaluate a base model on the CORE benchmark.
|
||||||
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
||||||
TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
|
|
||||||
"""
|
"""
|
||||||
# Load config and task metadata
|
# Load config and task metadata
|
||||||
base_dir = get_base_dir()
|
base_dir = get_base_dir()
|
||||||
|
|
@ -43,7 +41,15 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
tasks = config['icl_tasks']
|
tasks = config['icl_tasks']
|
||||||
eval_metadata = pd.read_csv(eval_meta_data)
|
|
||||||
|
# Load random baseline values from eval metadata
|
||||||
|
random_baselines = {}
|
||||||
|
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
task_name = row['Eval Task']
|
||||||
|
random_baseline = row['Random baseline']
|
||||||
|
random_baselines[task_name] = float(random_baseline)
|
||||||
|
|
||||||
# Evaluate each task
|
# Evaluate each task
|
||||||
results = {}
|
results = {}
|
||||||
|
|
@ -75,8 +81,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||||
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
||||||
|
|
||||||
results[label] = accuracy
|
results[label] = accuracy
|
||||||
row = eval_metadata[eval_metadata["Eval Task"] == label]
|
random_baseline = random_baselines[label]
|
||||||
random_baseline = row["Random baseline"].values[0]
|
|
||||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||||
centered_results[label] = centered_result
|
centered_results[label] = centered_result
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user