mirror of https://github.com/karpathy/nanochat.git
synced 2026-02-05 10:09:57 +00:00

Merge 8cfa0451f4 into 41bb2eac32
This commit is contained in: commit 744d2aea86
.gitignore (vendored)
@@ -11,3 +11,4 @@ eval_bundle/
 # Local setup
 CLAUDE.md
 wandb/
+*.egg-info/
@@ -201,6 +201,9 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
     for t, s, e in zip(tokens, start_idxs, end_idxs):
         if len(t) > max_tokens:
             num_to_crop = len(t) - max_tokens
+            # Take the last max_tokens tokens instead of the first ones.
+            # The overly long questions are usually the few-shot contexts. They are placed
+            # at the beginning of the sequence, so cropping from the start should be ok.
             new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
             new_start_idxs.append(s - num_to_crop) # shift the indices down
             new_end_idxs.append(e - num_to_crop)
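The left-crop above only works because the scored span sits at the end of the sequence. Below is a minimal standalone sketch of the same crop-and-shift logic; the function and variable names are illustrative, not from the repo.

def crop_left(tokens, start, end, max_tokens):
    # Keep the last max_tokens tokens; shift the scored span [start, end) to match.
    if len(tokens) <= max_tokens:
        return tokens, start, end
    num_to_crop = len(tokens) - max_tokens
    # The span marks the continuation being scored; it sits at the end of the
    # sequence, so a left-crop preserves it as long as start >= num_to_crop.
    assert start >= num_to_crop, "scored span would be cropped away"
    return tokens[-max_tokens:], start - num_to_crop, end - num_to_crop

# 10 tokens, scored span is the last two, crop to 6 tokens:
t, s, e = crop_left(list(range(10)), start=8, end=10, max_tokens=6)
print(t, s, e)  # [4, 5, 6, 7, 8, 9] 4 6 -- t[s:e] is still [8, 9]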
@@ -228,7 +231,11 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
         # predictions[i] predict input_ids[i+1] autoregressively
         predicted_tokens = predictions[0, si-1:ei-1]
         actual_tokens = input_ids[0, si:ei]
-        is_correct = torch.all(predicted_tokens == actual_tokens).item()
+        # Make the matching case-insensitive for LM tasks
+        predicted_text = tokenizer.decode(predicted_tokens.cpu().tolist()).lower()
+        actual_text = tokenizer.decode(actual_tokens.cpu().tolist()).lower()
+        # is_correct = torch.all(predicted_tokens == actual_tokens).item()
+        is_correct = (predicted_text == actual_text)
     elif task_type in ['multiple_choice', 'schema']:
         # For MC/schema: find the option with lowest average loss
         mean_losses = [losses[i, si-1:ei-1].mean().item()
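The motivation for this change: two token sequences can have different ids while decoding to answers that differ only in case, so exact token-level equality rejects answers a text-level comparison would accept. A quick illustration using tiktoken's gpt2 encoding as a stand-in tokenizer (nanochat ships its own tokenizer; pip install tiktoken is assumed here):

import tiktoken

# tiktoken used purely for illustration; this is not nanochat's tokenizer
enc = tiktoken.get_encoding("gpt2")
predicted = enc.encode(" paris")  # model emitted lowercase
actual = enc.encode(" Paris")     # reference answer is capitalized
print(predicted == actual)        # False: the token ids differ
# Decoding and lowercasing both sides accepts the answer anyway:
print(enc.decode(predicted).lower() == enc.decode(actual).lower())  # True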
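For the multiple_choice/schema branch, the decision rule is an argmin over per-option mean losses. A hedged sketch with made-up shapes and spans (in the real code, si/ei come from the tokenized option spans):

import torch

# losses[i] holds per-token losses for candidate continuation i (values made up)
losses = torch.tensor([
    [2.1, 1.9, 2.3, 0.0],  # option 0
    [0.7, 0.9, 0.8, 0.0],  # option 1 (lowest average loss)
    [1.5, 1.8, 1.4, 0.0],  # option 2
])
spans = [(1, 4), (1, 4), (1, 4)]  # (si, ei) per option, made up
mean_losses = [losses[i, si-1:ei-1].mean().item() for i, (si, ei) in enumerate(spans)]
predicted_option = mean_losses.index(min(mean_losses))
print(mean_losses, predicted_option)  # option 1 is selected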