fix(chat_cli,chat_eval): add --source choices and validate --task-name to avoid KeyError

This commit is contained in:
dhunganapramod9 2026-03-01 09:02:57 -05:00
parent 20c385e8f7
commit 37d23102b6
2 changed files with 6 additions and 3 deletions

View File

@ -12,7 +12,7 @@ from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-i', '--source', type=str, default="sft", choices=["sft", "rl"], help="Source of the model: sft|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')

View File

@ -183,7 +183,7 @@ if __name__ == "__main__":
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
parser.add_argument('-i', '--source', type=str, required=True, choices=["sft", "rl"], help="Source of the model: sft|rl")
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0)
@ -215,7 +215,10 @@ if __name__ == "__main__":
'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
}
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
task_names = all_tasks if args.task_name is None else [t.strip() for t in args.task_name.split('|')]
for task_name in task_names:
if task_name not in all_tasks:
raise ValueError(f"Invalid task name: {task_name!r}. Choose from: {', '.join(all_tasks)}")
# Run all the task evaluations sequentially
results = {}