nanochat/scripts/chat_cli.py
rehman 09bdfd6628 fix: truncate conversation tokens to model context window in chat_cli
Fixes #581. When conversation_tokens grows beyond model.config.sequence_len,
engine.generate() received a zero-dimension tensor and crashed with a matmul
shape error. Add a sliding window guard before each generate() call that keeps
the most recent (sequence_len - max_new_tokens) tokens, re-inserts bos to
preserve a well-formed sequence, and notifies the user when truncation occurs.
2026-05-06 01:29:20 +05:00

107 lines
4.3 KiB
Python

"""
New and upgraded chat mode because a lot of the code has changed since the last one.
Intended to be run single GPU only atm:
python -m scripts.chat_cli
"""
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
# Command-line interface. Each entry is (flag names, add_argument keyword options);
# the entries are registered in the order listed, which is also the order shown by --help.
parser = argparse.ArgumentParser(description='Chat with the model')
_cli_options = (
    (('-i', '--source'), dict(type=str, default="sft", help="Source of the model: sft|rl")),
    (('-g', '--model-tag'), dict(type=str, default=None, help='Model tag to load')),
    (('-s', '--step'), dict(type=int, default=None, help='Step to load')),
    (('-p', '--prompt'), dict(type=str, default='', help='Prompt the model, get a single response back')),
    (('-t', '--temperature'), dict(type=float, default=0.6, help='Temperature for generation')),
    (('-k', '--top-k'), dict(type=int, default=50, help='Top-k sampling parameter')),
    (
        ('--device-type',),
        dict(
            type=str,
            default='',
            choices=['cuda', 'cpu', 'mps'],
            help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
        ),
    ),
)
for _flag_names, _flag_options in _cli_options:
    parser.add_argument(*_flag_names, **_flag_options)
# Parses the process command line; results are read below as args.<name>
args = parser.parse_args()
# Init the model and tokenizer.
# An empty --device-type (the default) means "autodetect"; any explicit choice is used as given.
device_type = args.device_type or autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model(
    args.source,
    device,
    phase="eval",
    model_tag=args.model_tag,
    step=args.step,
)

# Special tokens for the chat state machine: one id per conversation delimiter
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")

# Create Engine for efficient generation
engine = Engine(model, tokenizer)

# Startup banner listing the commands the chat loop understands
banner_rule = "-" * 50
print(
    "\nNanoChat Interactive Mode",
    banner_rule,
    "Type 'quit' or 'exit' to end the conversation",
    "Type 'clear' to start a new conversation",
    banner_rule,
    sep="\n",
)
# ---------------------------------------------------------------------------
# Chat loop.
#
# The whole conversation is kept as one flat list of token ids, starting with
# bos. Each turn appends:
#   <|user_start|> ...user text... <|user_end|> <|assistant_start|>
# then streams the assistant's tokens from the engine and appends them too,
# always terminated by <|assistant_end|>.
# ---------------------------------------------------------------------------

# Per-turn generation budget and sampling settings. None of this changes
# between turns, so it is built once instead of on every loop iteration.
max_new_tokens = 256
generate_kwargs = {
    "num_samples": 1,
    "max_tokens": max_new_tokens,
    "temperature": args.temperature,
    "top_k": args.top_k,
}

# Reserve max_new_tokens slots for the response; the history may use what is left.
# Clamp to at least 2 (bos + one recent token): the truncation below slices with
# `-(max_ctx - 1)`, and for max_ctx <= 1 that start index would be 0 or positive,
# i.e. it would keep (nearly) the whole history instead of only the most recent part.
max_ctx = max(2, model.config.sequence_len - max_new_tokens)

conversation_tokens = [bos]
while True:
    if args.prompt:
        # Get the prompt from the launch command
        user_input = args.prompt
    else:
        # Get the prompt interactively from the console
        try:
            user_input = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

    # Handle special commands
    command = user_input.lower()
    if command in ('quit', 'exit'):
        print("Goodbye!")
        break
    if command == 'clear':
        conversation_tokens = [bos]
        print("Conversation cleared.")
        if args.prompt:
            # One-shot mode re-reads the same --prompt on every pass, so a plain
            # `continue` here would loop forever. There is nothing left to do.
            break
        continue
    if not user_input:
        continue

    # Add User message to the conversation
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(user_input))
    conversation_tokens.append(user_end)
    # Kick off the assistant
    conversation_tokens.append(assistant_start)

    # Cap the history to the space left after reserving the response budget,
    # keeping the most recent tokens and re-inserting bos at the front.
    # NOTE(review): the cut is by token count, so the retained window can start
    # in the middle of an older message rather than on a turn boundary.
    if len(conversation_tokens) > max_ctx:
        conversation_tokens = [bos] + conversation_tokens[-(max_ctx - 1):]
        print("[Context window full - older messages have been dropped]", flush=True)

    # Stream the response, printing each token as soon as it is produced
    response_tokens = []
    print("\nAssistant: ", end="", flush=True)
    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
        token = token_column[0]  # pop the batch dimension (num_samples=1)
        response_tokens.append(token)
        print(tokenizer.decode([token]), end="", flush=True)
    print()

    # We have to ensure that the assistant end token is the last token, so even if
    # generation ends due to max tokens (or produced no tokens at all, in which case
    # there is no last element to inspect), append it to the end.
    if not response_tokens or response_tokens[-1] != assistant_end:
        response_tokens.append(assistant_end)
    conversation_tokens.extend(response_tokens)

    # In the prompt mode, we only want a single response and exit
    if args.prompt:
        break