This commit is contained in:
askerlee 2026-01-29 16:14:53 +01:00 committed by GitHub
commit 744d2aea86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 1 deletions

1
.gitignore vendored
View File

@ -11,3 +11,4 @@ eval_bundle/
# Local setup
CLAUDE.md
wandb/
*.egg-info/

View File

@ -201,6 +201,9 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
for t, s, e in zip(tokens, start_idxs, end_idxs):
if len(t) > max_tokens:
num_to_crop = len(t) - max_tokens
# Take the last max_tokens tokens instead of the first ones.
# The overly long questions are usually the few-shot contexts. They are placed
# at the beginning of the sequence, so cropping from the start should be ok.
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
new_end_idxs.append(e - num_to_crop)
@ -228,7 +231,11 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
# predictions[i] predict input_ids[i+1] autoregressively
predicted_tokens = predictions[0, si-1:ei-1]
actual_tokens = input_ids[0, si:ei]
is_correct = torch.all(predicted_tokens == actual_tokens).item()
# Make the matching case-insensitive for LM tasks
predicted_text = tokenizer.decode(predicted_tokens.cpu().tolist()).lower()
actual_text = tokenizer.decode(actual_tokens.cpu().tolist()).lower()
# is_correct = torch.all(predicted_tokens == actual_tokens).item()
is_correct = (predicted_text == actual_text)
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [losses[i, si-1:ei-1].mean().item()