mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge e19b8b8fe1 into 5019accc5b
This commit is contained in:
commit
9b7f533c2f
|
|
@ -11,7 +11,7 @@ from nanochat.engine import Engine
|
|||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
parser = argparse.ArgumentParser(description='Chat with the model')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", choices=["sft", "rl"], help="Source of the model: sft|rl")
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||
parser.add_argument('-i', '--source', type=str, required=True, choices=["sft", "rl"], help="Source of the model: sft|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
|
|
@ -210,7 +210,10 @@ if __name__ == "__main__":
|
|||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
task_names = all_tasks if args.task_name is None else [t.strip() for t in args.task_name.split('|')]
|
||||
for task_name in task_names:
|
||||
if task_name not in all_tasks:
|
||||
raise ValueError(f"Invalid task name: {task_name!r}. Choose from: {', '.join(all_tasks)}")
|
||||
|
||||
# Run all the task evaluations sequentially
|
||||
results = {}
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ MAX_MAX_TOKENS = 4096
|
|||
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", choices=["sft", "rl"], help="Source of the model: sft|rl")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||
|
|
@ -352,7 +352,7 @@ async def chat_completions(request: ChatRequest):
|
|||
top_k=request.top_k
|
||||
):
|
||||
# Accumulate response for logging
|
||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||
chunk_data = json.loads(chunk.removeprefix("data: ").strip())
|
||||
if "token" in chunk_data:
|
||||
response_tokens.append(chunk_data["token"])
|
||||
yield chunk
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user