mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 09:50:28 +00:00
Small touch-ups to the eval script: re-order items, etc.; cosmetic
This commit is contained in:
parent
72b9064f9d
commit
8ebc14b348
|
|
@ -73,7 +73,7 @@ def load_hf_model(hf_path: str, device):
|
|||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
max_seq_len = 1024 if "gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
|
@ -180,7 +180,7 @@ def evaluate_core(model, tokenizer, device, max_per_task=-1):
|
|||
def main():
|
||||
parser = argparse.ArgumentParser(description="Base model evaluation")
|
||||
parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)')
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2)')
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)')
|
||||
parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory')
|
||||
parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)')
|
||||
|
|
@ -225,48 +225,6 @@ def main():
|
|||
samples = []
|
||||
unconditioned_samples = []
|
||||
|
||||
# --- CORE evaluation ---
|
||||
if 'core' in eval_modes:
|
||||
print0("\n" + "="*80)
|
||||
print0("CORE Evaluation")
|
||||
print0("="*80)
|
||||
with autocast_ctx:
|
||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write CSV output
|
||||
if ddp_rank == 0:
|
||||
base_dir = get_base_dir()
|
||||
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
|
||||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||||
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in core_results["results"]:
|
||||
acc = core_results["results"][label]
|
||||
centered = core_results["centered_results"][label]
|
||||
f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
|
||||
f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
|
||||
print0(f"\nResults written to: {output_csv_path}")
|
||||
print0(f"CORE metric: {core_results['core_metric']:.4f}")
|
||||
|
||||
# --- BPB evaluation ---
|
||||
if 'bpb' in eval_modes:
|
||||
print0("\n" + "="*80)
|
||||
print0("BPB Evaluation")
|
||||
print0("="*80)
|
||||
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
|
||||
if args.split_tokens % tokens_per_step != 0:
|
||||
# Round down to a multiple of tokens_per_step
|
||||
args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step
|
||||
print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"{split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# --- Sampling ---
|
||||
if 'sample' in eval_modes and not is_hf_model:
|
||||
print0("\n" + "="*80)
|
||||
|
|
@ -305,6 +263,48 @@ def main():
|
|||
elif 'sample' in eval_modes and is_hf_model:
|
||||
print0("\nSkipping sampling for HuggingFace models (not supported)")
|
||||
|
||||
# --- BPB evaluation ---
|
||||
if 'bpb' in eval_modes:
|
||||
print0("\n" + "="*80)
|
||||
print0("BPB Evaluation")
|
||||
print0("="*80)
|
||||
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
|
||||
if args.split_tokens % tokens_per_step != 0:
|
||||
# Round down to a multiple of tokens_per_step
|
||||
args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step
|
||||
print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"{split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# --- CORE evaluation ---
|
||||
if 'core' in eval_modes:
|
||||
print0("\n" + "="*80)
|
||||
print0("CORE Evaluation")
|
||||
print0("="*80)
|
||||
with autocast_ctx:
|
||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write CSV output
|
||||
if ddp_rank == 0:
|
||||
base_dir = get_base_dir()
|
||||
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
|
||||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||||
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in core_results["results"]:
|
||||
acc = core_results["results"][label]
|
||||
centered = core_results["centered_results"][label]
|
||||
f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
|
||||
f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
|
||||
print0(f"\nResults written to: {output_csv_path}")
|
||||
print0(f"CORE metric: {core_results['core_metric']:.4f}")
|
||||
|
||||
# --- Log to report ---
|
||||
from nanochat.report import get_report
|
||||
report_data = [{"model": model_name}]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user