From 09bdfd66281dfdaade9aaeb0e9871cdc602cc87a Mon Sep 17 00:00:00 2001
From: rehman
Date: Wed, 6 May 2026 01:29:20 +0500
Subject: [PATCH] fix: truncate conversation tokens to model context window in
 chat_cli

Fixes #581.

When conversation_tokens grows beyond model.config.sequence_len,
engine.generate() receives a zero-dimension tensor and crashes with a
matmul shape error. Add a sliding-window guard before each generate()
call that keeps the most recent (sequence_len - max_new_tokens) tokens,
re-inserts bos to preserve a well-formed sequence, and notifies the user
when truncation occurs.
---
 scripts/chat_cli.py              |  8 +++++-
 tests/test_context_truncation.py | 47 ++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_context_truncation.py

diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index 2bcc8aad..91a11f4b 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -75,13 +75,19 @@ while True:
 
     # Kick off the assistant
     conversation_tokens.append(assistant_start)
+    max_new_tokens = 256
     generate_kwargs = {
         "num_samples": 1,
-        "max_tokens": 256,
+        "max_tokens": max_new_tokens,
         "temperature": args.temperature,
         "top_k": args.top_k,
     }
     response_tokens = []
+    # Reserve max_new_tokens slots for the response; cap history to the remaining space and re-insert bos to keep the sequence well-formed
+    max_ctx = model.config.sequence_len - max_new_tokens
+    if len(conversation_tokens) > max_ctx:
+        conversation_tokens = [bos] + conversation_tokens[-(max_ctx - 1):]
+        print("[Context window full - older messages have been dropped]", flush=True)
     print("\nAssistant: ", end="", flush=True)
     for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
         token = token_column[0]  # pop the batch dimension (num_samples=1)
diff --git a/tests/test_context_truncation.py b/tests/test_context_truncation.py
new file mode 100644
index 00000000..bf917238
--- /dev/null
+++ b/tests/test_context_truncation.py
@@ -0,0 +1,47 @@
+SEQUENCE_LEN = 64