Mirror of https://github.com/karpathy/nanochat.git
fix: truncate conversation tokens to model context window in chat_cli
Fixes #581. When conversation_tokens grew beyond model.config.sequence_len, engine.generate() received a zero-dimension tensor and crashed with a matmul shape error. Add a sliding-window guard before each generate() call that keeps the most recent (sequence_len - max_new_tokens) tokens, re-inserts bos to preserve a well-formed sequence, and notifies the user when truncation occurs.
commit 09bdfd6628
parent dc54a1a307
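In isolation, the guard this commit adds amounts to the following (a minimal sketch, not the literal patch; truncate_to_context is a hypothetical name used only for illustration):

    def truncate_to_context(conversation_tokens, bos, sequence_len, max_new_tokens):
        # Reserve max_new_tokens slots for the reply, keep the newest history
        # tokens that still fit, and prepend bos so the sequence stays well-formed.
        max_ctx = sequence_len - max_new_tokens
        if len(conversation_tokens) > max_ctx:
            return [bos] + conversation_tokens[-(max_ctx - 1):]
        return conversation_tokens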
scripts/chat_cli.py
@@ -75,13 +75,19 @@ while True:
     # Kick off the assistant
     conversation_tokens.append(assistant_start)
+    max_new_tokens = 256
     generate_kwargs = {
         "num_samples": 1,
-        "max_tokens": 256,
+        "max_tokens": max_new_tokens,
         "temperature": args.temperature,
         "top_k": args.top_k,
     }
     response_tokens = []
+    # Reserve max_new_tokens slots for the response; cap history to the remaining space and re-insert bos to keep the sequence well-formed
+    max_ctx = model.config.sequence_len - max_new_tokens
+    if len(conversation_tokens) > max_ctx:
+        conversation_tokens = [bos] + conversation_tokens[-(max_ctx - 1):]
+        print("[Context window full - older messages have been dropped]", flush=True)
     print("\nAssistant: ", end="", flush=True)
     for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
         token = token_column[0]  # pop the batch dimension (num_samples=1)
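The slice is -(max_ctx - 1) rather than -max_ctx because the re-inserted bos occupies one of the max_ctx slots. A quick standalone check of that arithmetic (made-up values; bos here is just a placeholder token id, not the real tokenizer's):

    bos = 0                     # placeholder id, stands in for the real bos token
    sequence_len = 64           # stands in for model.config.sequence_len
    max_new_tokens = 10
    max_ctx = sequence_len - max_new_tokens       # 54 slots left for history

    conversation_tokens = list(range(1, 81))      # 80-token history, over the limit
    if len(conversation_tokens) > max_ctx:
        conversation_tokens = [bos] + conversation_tokens[-(max_ctx - 1):]

    assert len(conversation_tokens) == max_ctx    # exactly 54, bos included
    assert len(conversation_tokens) + max_new_tokens == sequence_len
    assert conversation_tokens[0] == bos

With the guard in place, history plus the reserved reply budget can never exceed the model's window.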
tests/test_context_truncation.py (new file)
@@ -0,0 +1,47 @@
+SEQUENCE_LEN = 64
+MAX_TOKENS = 10
+MAX_CTX = SEQUENCE_LEN - MAX_TOKENS  # 54
+
+
+def apply_sliding_window(conversation_tokens, sequence_len, max_tokens):
+    # Local mirror of the guard added to scripts/chat_cli.py (length cap only; the bos re-insertion is not mirrored here).
+    # TODO: once Task 3 lands, import the real function instead of duplicating.
+    max_ctx = sequence_len - max_tokens
+    if len(conversation_tokens) > max_ctx:
+        return conversation_tokens[-max_ctx:]
+    return conversation_tokens
+
+
+def test_truncation_removes_oldest_tokens_when_over_limit():
+    long_tokens = list(range(80))
+
+    result = apply_sliding_window(long_tokens, SEQUENCE_LEN, MAX_TOKENS)
+
+    assert len(result) == MAX_CTX
+    assert result == long_tokens[-MAX_CTX:]
+
+
+def test_truncation_leaves_short_conversation_unchanged():
+    short_tokens = list(range(20))
+
+    result = apply_sliding_window(short_tokens, SEQUENCE_LEN, MAX_TOKENS)
+
+    assert result == short_tokens
+
+
+def test_truncation_at_exact_boundary_leaves_unchanged():
+    boundary_tokens = list(range(MAX_CTX))
+
+    result = apply_sliding_window(boundary_tokens, SEQUENCE_LEN, MAX_TOKENS)
+
+    assert result == boundary_tokens
+
+
+def test_truncation_one_over_boundary_drops_oldest_token():
+    tokens = list(range(MAX_CTX + 1))
+
+    result = apply_sliding_window(tokens, SEQUENCE_LEN, MAX_TOKENS)
+
+    assert len(result) == MAX_CTX
+    assert result[0] == 1
+    assert result[-1] == MAX_CTX
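If broader coverage ever seems worthwhile, a single sweep over history lengths would pin down the same length invariant in one test (a hypothetical addition against the local apply_sliding_window mirror, not part of this commit):

    def test_truncation_keeps_newest_tokens_for_any_length():
        # Sweep history lengths from empty to well past the window.
        for n in range(3 * SEQUENCE_LEN):
            tokens = list(range(n))
            result = apply_sliding_window(tokens, SEQUENCE_LEN, MAX_TOKENS)
            assert len(result) == min(n, MAX_CTX)
            # The survivors are always the newest tokens.
            assert result == tokens[len(tokens) - len(result):]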