mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Replace lru_cache with instance-level cache for special token ids
This commit is contained in:
parent
f8ca0b5c21
commit
990a26332c
|
|
@ -8,7 +8,6 @@ Two implementations are available:
|
|||
|
||||
import os
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
|
||||
SPECIAL_TOKENS = [
|
||||
# every document begins with the Beginning of Sequence (BOS) token that delimits documents
|
||||
|
|
@ -165,6 +164,9 @@ class RustBPETokenizer:
|
|||
|
||||
def __init__(self, enc, bos_token):
|
||||
self.enc = enc
|
||||
# instance-level cache for special token ids; replaces lru_cache on the
|
||||
# method (which kept a strong ref to self in the function-level cache)
|
||||
self._special_id_cache: dict[str, int] = {}
|
||||
self.bos_token_id = self.encode_special(bos_token)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -215,9 +217,13 @@ class RustBPETokenizer:
|
|||
def id_to_token(self, id):
|
||||
return self.enc.decode([id])
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def encode_special(self, text):
|
||||
return self.enc.encode_single_token(text)
|
||||
cached = self._special_id_cache.get(text)
|
||||
if cached is not None:
|
||||
return cached
|
||||
v = self.enc.encode_single_token(text)
|
||||
self._special_id_cache[text] = v
|
||||
return v
|
||||
|
||||
def get_bos_token_id(self):
|
||||
return self.bos_token_id
|
||||
|
|
@ -239,8 +245,8 @@ class RustBPETokenizer:
|
|||
elif isinstance(text, list):
|
||||
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
||||
if prepend is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.insert(0, prepend_id) # TODO: same
|
||||
# build each row as a new list with prepend_id first (still an O(n) copy per row, as insert(0, ...) was)
|
||||
ids = [[prepend_id, *row] for row in ids]
|
||||
if append is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.append(append_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user