mirror of https://github.com/karpathy/nanochat.git
synced 2026-05-16 20:57:33 +00:00
Merge e9b44f62c2 into dc54a1a307
commit 586e734873
@@ -255,6 +255,9 @@ class RustBPETokenizer:
     def decode(self, ids):
         return self.enc.decode(ids)
 
+    def decode_single_token_bytes(self, token_id):
+        return self.enc.decode_single_token_bytes(token_id)
+
     def save(self, tokenizer_dir):
         # save the encoding object to disk
         os.makedirs(tokenizer_dir, exist_ok=True)
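The new decode_single_token_bytes method is a straight pass-through to the underlying encoding. The byte-level path matters: a single BPE token is often not valid UTF-8 on its own, so round-tripping one token through decode() yields replacement characters whose UTF-8 length need not match the token's true byte length. A minimal sketch of the difference, assuming the wrapped encoder behaves like a tiktoken Encoding (the gpt2 encoding below is only for illustration):

import tiktoken

enc = tiktoken.get_encoding("gpt2")  # stand-in for the tokenizer's self.enc

# a multi-byte character is typically split across tokens that are not
# individually valid UTF-8
ids = enc.encode("🙂")
for token_id in ids:
    raw = enc.decode_single_token_bytes(token_id)  # exact bytes of this token
    via_str = enc.decode([token_id])               # lossy: invalid bytes become U+FFFD
    print(token_id, len(raw), len(via_str.encode("utf-8")))  # counts can disagree

This mismatch is presumably what the second hunk addresses: instead of measuring len(token_str.encode("utf-8")) on a decoded string, it counts len(decode_single_token_bytes(token_id)) directly.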
@@ -74,16 +74,11 @@ assert decoded == test_text
 # allows us to report a loss that is invariant to the vocab size of the tokenizer.
 # The bits per byte on the validation set is then one of the primary metrics we care about.
 vocab_size = tokenizer.get_vocab_size()
-special_set = set(tokenizer.get_special_tokens())
-token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
-token_bytes = []
+special_set = set(tokenizer.encode_special(token) for token in tokenizer.get_special_tokens())
+token_bytes = [len(tokenizer.decode_single_token_bytes(token_id)) for token_id in range(vocab_size)]
 for token_id in range(vocab_size):
-    token_str = token_strings[token_id] # the Python string representation of this token
-    if token_str in special_set:
-        token_bytes.append(0) # special characters are not counted
-    else:
-        id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
-        token_bytes.append(id_bytes)
+    if token_id in special_set:
+        token_bytes[token_id] = 0 # special characters are not counted
 token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
 token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
 with open(token_bytes_path, "wb") as f:
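For context on how this table gets used: the comments above describe bits per byte, where a summed next-token loss (in nats) is converted to bits and normalized by the number of bytes the target tokens represent; special tokens contribute zero bytes and thus drop out of the denominator. A toy sketch with made-up numbers (token_bytes, target_ids, and loss_sum are illustrative names, not ones from the nanochat codebase):

import math
import torch

token_bytes = torch.tensor([0, 1, 3, 2], dtype=torch.int32)  # toy table; id 0 plays a special token
target_ids = torch.tensor([1, 2, 2, 3])                      # toy validation targets
loss_sum = 5.0                                               # summed cross-entropy over the targets, in nats

total_bytes = token_bytes[target_ids].sum().item()           # 1 + 3 + 3 + 2 = 9 bytes
bpb = loss_sum / (math.log(2) * total_bytes)                 # nats -> bits, then per byte
print(f"bits per byte: {bpb:.3f}")                           # ≈0.801 for these toy numbers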