mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
Merge ed565be892 into a445144d39
This commit is contained in:
commit
68f7ce5147
|
|
@ -118,7 +118,11 @@ class KVCache:
|
|||
|
||||
def advance(self, num_tokens):
|
||||
"""Advance the cache position by num_tokens."""
|
||||
self.cache_seqlens += num_tokens
|
||||
# Validate that we don't exceed max sequence length
|
||||
new_seqlens = self.cache_seqlens + num_tokens
|
||||
if torch.any(new_seqlens > self.max_seq_len):
|
||||
raise ValueError(f"Cache overflow: attempted to advance beyond max_seq_len={self.max_seq_len}")
|
||||
self.cache_seqlens.copy_(new_seqlens)
|
||||
|
||||
def prefill(self, other):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -232,15 +232,16 @@ class RustBPETokenizer:
|
|||
|
||||
if isinstance(text, str):
|
||||
ids = self.enc.encode_ordinary(text)
|
||||
# Use list concatenation instead of insert(0, ...) for O(1) prepend
|
||||
if prepend is not None:
|
||||
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
|
||||
ids = [prepend_id] + ids
|
||||
if append is not None:
|
||||
ids.append(append_id)
|
||||
elif isinstance(text, list):
|
||||
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
||||
# Use list concatenation instead of insert(0, ...) for O(1) prepend per row
|
||||
if prepend is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.insert(0, prepend_id) # TODO: same
|
||||
ids = [[prepend_id] + row for row in ids]
|
||||
if append is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.append(append_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user