Replace lru_cache with instance-level cache for tokens

This commit is contained in:
EFE AYDIN 2026-05-15 23:21:38 +03:00 committed by GitHub
parent f8ca0b5c21
commit 990a26332c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,6 @@ Two implementations are available:
import os
import copy
from functools import lru_cache
SPECIAL_TOKENS = [
# every document begins with the Beginning of Sequence (BOS) token that delimits documents
@ -165,6 +164,9 @@ class RustBPETokenizer:
def __init__(self, enc, bos_token):
self.enc = enc
# instance-level cache for special token ids; replaces lru_cache on the
# method (which kept a strong ref to self in the function-level cache)
self._special_id_cache: dict[str, int] = {}
self.bos_token_id = self.encode_special(bos_token)
@classmethod
@ -215,9 +217,13 @@ class RustBPETokenizer:
def id_to_token(self, id):
return self.enc.decode([id])
@lru_cache(maxsize=32)
def encode_special(self, text):
return self.enc.encode_single_token(text)
cached = self._special_id_cache.get(text)
if cached is not None:
return cached
v = self.enc.encode_single_token(text)
self._special_id_cache[text] = v
return v
def get_bos_token_id(self):
return self.bos_token_id
@ -239,8 +245,8 @@ class RustBPETokenizer:
elif isinstance(text, list):
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
if prepend is not None:
for ids_row in ids:
ids_row.insert(0, prepend_id) # TODO: same
# avoid O(n) shift per row that insert(0, ...) does
ids = [[prepend_id, *row] for row in ids]
if append is not None:
for ids_row in ids:
ids_row.append(append_id)