diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 57f9fd4..a54674d 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -156,7 +156,8 @@ def evaluate_core(model, tokenizer, device, max_per_task=-1): shuffle_rng = random.Random(1337) shuffle_rng.shuffle(data) if max_per_task > 0: - data = data[:max_per_task] + data_cutoff = max(task_meta['num_fewshot']+1, max_per_task) + data = data[:data_cutoff] accuracy = evaluate_task(model, tokenizer, data, device, task_meta) results[label] = accuracy