Ellen Xu 2026-05-05 14:40:48 -07:00 committed by GitHub
commit 586e734873
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 9 deletions


@@ -255,6 +255,9 @@ class RustBPETokenizer:
     def decode(self, ids):
         return self.enc.decode(ids)
 
+    def decode_single_token_bytes(self, token_id):
+        return self.enc.decode_single_token_bytes(token_id)
+
     def save(self, tokenizer_dir):
         # save the encoding object to disk
         os.makedirs(tokenizer_dir, exist_ok=True)
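
The new decode_single_token_bytes method simply forwards to the underlying tiktoken-style Encoding and exposes the raw bytes behind a single token id. A minimal usage sketch, not part of this commit (the module path, the from_pretrained loader, and the encode() call are assumptions for illustration):

from nanochat.tokenizer import RustBPETokenizer  # assumed module path

tokenizer = RustBPETokenizer.from_pretrained("tokenizer_dir")  # hypothetical loader, for illustration only
ids = tokenizer.encode("hello world")  # assumed encode() counterpart to decode()
for token_id in ids:
    raw = tokenizer.decode_single_token_bytes(token_id)  # raw bytes behind this one token id
    print(token_id, raw, len(raw))  # a multi-byte UTF-8 token reports more bytes than characters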


@@ -74,16 +74,11 @@ assert decoded == test_text
 # allows us to report a loss that is invariant to the vocab size of the tokenizer.
 # The bits per byte on the validation set is then one of the primary metrics we care about.
 vocab_size = tokenizer.get_vocab_size()
-special_set = set(tokenizer.get_special_tokens())
-token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
-token_bytes = []
+special_set = set(tokenizer.encode_special(token) for token in tokenizer.get_special_tokens())
+token_bytes = [len(tokenizer.decode_single_token_bytes(token_id)) for token_id in range(vocab_size)]
 for token_id in range(vocab_size):
-    token_str = token_strings[token_id] # the Python string representation of this token
-    if token_str in special_set:
-        token_bytes.append(0) # special characters are not counted
-    else:
-        id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
-        token_bytes.append(id_bytes)
+    if token_id in special_set:
+        token_bytes[token_id] = 0 # special characters are not counted
 token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
 token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
 with open(token_bytes_path, "wb") as f:
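
The per-token byte counts exist so the evaluation can report bits per byte: the summed cross-entropy, converted from nats to bits, divided by the number of UTF-8 bytes the target tokens represent, with special tokens contributing zero bytes. A rough downstream sketch under that assumption (the function and tensor names are illustrative, not nanochat's actual API):

import math
import torch

def bits_per_byte(nll_nats: torch.Tensor, target_ids: torch.Tensor, token_bytes: torch.Tensor) -> float:
    # nll_nats: per-token negative log-likelihoods in nats, aligned with target_ids
    # token_bytes: the int32 tensor saved above (zero for special tokens)
    num_bytes = token_bytes[target_ids].sum().item()  # total UTF-8 bytes the targets stand for
    total_bits = nll_nats.sum().item() / math.log(2)  # nats -> bits
    return total_bits / max(num_bytes, 1)  # guard against a batch of only special tokens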