This commit is contained in:
陈家名 2026-03-27 20:58:37 +00:00 committed by GitHub
commit 68f7ce5147
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 4 deletions

View File

@ -118,7 +118,11 @@ class KVCache:
def advance(self, num_tokens):
"""Advance the cache position by num_tokens."""
self.cache_seqlens += num_tokens
# Validate that we don't exceed max sequence length
new_seqlens = self.cache_seqlens + num_tokens
if torch.any(new_seqlens > self.max_seq_len):
raise ValueError(f"Cache overflow: attempted to advance beyond max_seq_len={self.max_seq_len}")
self.cache_seqlens.copy_(new_seqlens)
def prefill(self, other):
"""

View File

@ -232,15 +232,16 @@ class RustBPETokenizer:
if isinstance(text, str):
ids = self.enc.encode_ordinary(text)
# Use list concatenation instead of insert(0, ...) for O(1) prepend
if prepend is not None:
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
ids = [prepend_id] + ids
if append is not None:
ids.append(append_id)
elif isinstance(text, list):
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
# Use list concatenation instead of insert(0, ...) for O(1) prepend per row
if prepend is not None:
for ids_row in ids:
ids_row.insert(0, prepend_id) # TODO: same
ids = [[prepend_id] + row for row in ids]
if append is not None:
for ids_row in ids:
ids_row.append(append_id)