fix: truncate conversation tokens to model context window in chat_cli

Fixes #581. When conversation_tokens grows beyond model.config.sequence_len,
engine.generate() received a zero-dimension tensor and crashed with a matmul
shape error. Add a sliding window guard before each generate() call that keeps
the most recent (sequence_len - max_new_tokens) tokens, re-inserts bos to
preserve a well-formed sequence, and notifies the user when truncation occurs.
This commit is contained in:
rehman 2026-05-06 01:29:20 +05:00
parent dc54a1a307
commit 09bdfd6628
2 changed files with 54 additions and 1 deletions

View File

@ -75,13 +75,19 @@ while True:
# Kick off the assistant
conversation_tokens.append(assistant_start)
max_new_tokens = 256
generate_kwargs = {
"num_samples": 1,
"max_tokens": 256,
"max_tokens": max_new_tokens,
"temperature": args.temperature,
"top_k": args.top_k,
}
response_tokens = []
# Reserve max_new_tokens slots for the response; cap history to the remaining space and re-insert bos to keep the sequence well-formed
max_ctx = model.config.sequence_len - max_new_tokens
if len(conversation_tokens) > max_ctx:
conversation_tokens = [bos] + conversation_tokens[-(max_ctx - 1):]
print("[Context window full - older messages have been dropped]", flush=True)
print("\nAssistant: ", end="", flush=True)
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
token = token_column[0] # pop the batch dimension (num_samples=1)

View File

@ -0,0 +1,47 @@
# Small fixed stand-ins for the model config used by the sliding-window tests.
SEQUENCE_LEN = 64
# Slots reserved for the generated response.
MAX_TOKENS = 10
# Largest history the guard may keep: sequence_len minus the response budget.
MAX_CTX = SEQUENCE_LEN - MAX_TOKENS # 54
def apply_sliding_window(conversation_tokens, sequence_len, max_tokens):
    """Keep only the newest tokens that fit the model's context budget.

    Local mirror of the guard added to scripts/chat_cli.py.
    TODO: once Task 3 lands, import the real function instead of duplicating.
    NOTE(review): the chat_cli guard also re-inserts bos at the front after
    truncating; this mirror does not — confirm the duplication is intentional.
    """
    budget = sequence_len - max_tokens
    if len(conversation_tokens) <= budget:
        return conversation_tokens
    return conversation_tokens[-budget:]
def test_truncation_removes_oldest_tokens_when_over_limit():
    """An over-long conversation is cut down to the newest MAX_CTX tokens."""
    history = list(range(80))
    trimmed = apply_sliding_window(history, SEQUENCE_LEN, MAX_TOKENS)
    assert len(trimmed) == MAX_CTX
    assert trimmed == history[-MAX_CTX:]
def test_truncation_leaves_short_conversation_unchanged():
    """A conversation well under the budget passes through untouched."""
    history = list(range(20))
    assert apply_sliding_window(history, SEQUENCE_LEN, MAX_TOKENS) == history
def test_truncation_at_exact_boundary_leaves_unchanged():
    """Exactly MAX_CTX tokens is within budget, so nothing is dropped."""
    history = list(range(MAX_CTX))
    assert apply_sliding_window(history, SEQUENCE_LEN, MAX_TOKENS) == history
def test_truncation_one_over_boundary_drops_oldest_token():
    """One token past the limit evicts exactly the single oldest token."""
    history = list(range(MAX_CTX + 1))
    trimmed = apply_sliding_window(history, SEQUENCE_LEN, MAX_TOKENS)
    assert len(trimmed) == MAX_CTX
    # Token 0 is gone; the window now spans tokens 1 .. MAX_CTX inclusive.
    assert trimmed[0] == 1
    assert trimmed[-1] == MAX_CTX