diff --git a/scripts/base_eval.py b/scripts/base_eval.py index bd83ff3..f3d604c 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -96,7 +96,8 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): shuffle_rng = random.Random(1337) shuffle_rng.shuffle(data) if max_per_task > 0: - data = data[:max_per_task] + data_cutoff = max(task_meta['num_fewshot']+1, max_per_task) + data = data[:data_cutoff] # run the evaluation for this task accuracy = evaluate_task(model, tokenizer, data, device, task_meta)