diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index a2146c2e..6ee0603e 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -8,7 +8,6 @@ Two implementations are available: import os import copy -from functools import lru_cache SPECIAL_TOKENS = [ # every document begins with the Beginning of Sequence (BOS) token that delimits documents @@ -165,6 +164,9 @@ class RustBPETokenizer: def __init__(self, enc, bos_token): self.enc = enc + # instance-level cache for special token ids; replaces lru_cache on the + # method (which kept a strong ref to self in the function-level cache) + self._special_id_cache: dict[str, int] = {} self.bos_token_id = self.encode_special(bos_token) @classmethod @@ -215,9 +217,13 @@ class RustBPETokenizer: def id_to_token(self, id): return self.enc.decode([id]) - @lru_cache(maxsize=32) def encode_special(self, text): - return self.enc.encode_single_token(text) + cached = self._special_id_cache.get(text) + if cached is not None: + return cached + v = self.enc.encode_single_token(text) + self._special_id_cache[text] = v + return v def get_bos_token_id(self): return self.bos_token_id @@ -239,8 +245,8 @@ class RustBPETokenizer: elif isinstance(text, list): ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) if prepend is not None: - for ids_row in ids: - ids_row.insert(0, prepend_id) # TODO: same + # avoid O(n) shift per row that insert(0, ...) does + ids = [[prepend_id, *row] for row in ids] if append is not None: for ids_row in ids: ids_row.append(append_id)