diff --git a/.gitignore b/.gitignore
index 3e92824..e2c6f53 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ eval_bundle/
 # Local setup
 CLAUDE.md
 wandb/
+*.egg-info/
diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py
index f3c9a9f..ff63c3e 100644
--- a/nanochat/core_eval.py
+++ b/nanochat/core_eval.py
@@ -201,6 +201,9 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
     for t, s, e in zip(tokens, start_idxs, end_idxs):
         if len(t) > max_tokens:
             num_to_crop = len(t) - max_tokens
+            # Take the last max_tokens tokens instead of the first ones.
+            # The overly long questions are usually the few-shot contexts. They are placed
+            # at the beginning of the sequence, so cropping from the start should be ok.
             new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
             new_start_idxs.append(s - num_to_crop) # shift the indices down
             new_end_idxs.append(e - num_to_crop)
@@ -228,7 +231,11 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
         # predictions[i] predict input_ids[i+1] autoregressively
         predicted_tokens = predictions[0, si-1:ei-1]
         actual_tokens = input_ids[0, si:ei]
-        is_correct = torch.all(predicted_tokens == actual_tokens).item()
+        # Make the matching case-insensitive for LM tasks
+        predicted_text = tokenizer.decode(predicted_tokens.cpu().tolist()).lower()
+        actual_text = tokenizer.decode(actual_tokens.cpu().tolist()).lower()
+        # is_correct = torch.all(predicted_tokens == actual_tokens).item()
+        is_correct = (predicted_text == actual_text)
     elif task_type in ['multiple_choice', 'schema']:
         # For MC/schema: find the option with lowest average loss
         mean_losses = [losses[i, si-1:ei-1].mean().item()