diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py
index a2146c2e..e6046029 100644
--- a/nanochat/tokenizer.py
+++ b/nanochat/tokenizer.py
@@ -255,6 +255,9 @@ class RustBPETokenizer:
     def decode(self, ids):
         return self.enc.decode(ids)
 
+    def decode_single_token_bytes(self, token_id):
+        return self.enc.decode_single_token_bytes(token_id)
+
     def save(self, tokenizer_dir):
         # save the encoding object to disk
         os.makedirs(tokenizer_dir, exist_ok=True)
diff --git a/scripts/tok_train.py b/scripts/tok_train.py
index 90495b19..56a4fc68 100644
--- a/scripts/tok_train.py
+++ b/scripts/tok_train.py
@@ -74,16 +74,11 @@ assert decoded == test_text
 # allows us to report a loss that is invariant to the vocab size of the tokenizer.
 # The bits per byte on the validation set is then one of the primary metrics we care about.
 vocab_size = tokenizer.get_vocab_size()
-special_set = set(tokenizer.get_special_tokens())
-token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
-token_bytes = []
+special_set = set(tokenizer.encode_special(token) for token in tokenizer.get_special_tokens())
+token_bytes = [len(tokenizer.decode_single_token_bytes(token_id)) for token_id in range(vocab_size)]
 for token_id in range(vocab_size):
-    token_str = token_strings[token_id] # the Python string representation of this token
-    if token_str in special_set:
-        token_bytes.append(0) # special characters are not counted
-    else:
-        id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
-        token_bytes.append(id_bytes)
+    if token_id in special_set:
+        token_bytes[token_id] = 0 # special characters are not counted
 token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
 token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
 with open(token_bytes_path, "wb") as f:
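
The point of the change is that measuring a token's byte length via `len(tokenizer.decode([id]).encode("utf-8"))` is unreliable: a single BPE token can be an incomplete UTF-8 sequence, which decodes to replacement characters and inflates the byte count. Reading the raw token bytes with `decode_single_token_bytes` avoids that, and special tokens are then zeroed out by id rather than by string comparison. The following is a minimal standalone sketch of the resulting logic, using tiktoken's stock "gpt2" encoding as a stand-in for the project's trained tokenizer (the wrapper methods `encode_special` / `get_special_tokens` from the diff are replaced here by the underlying tiktoken calls):

```python
import tiktoken
import torch

# Stand-in for the trained nanochat tokenizer; any tiktoken Encoding works.
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab

# Byte length of each token, taken from the raw token bytes rather than a
# round-trip through a Python string (which mangles partial UTF-8 sequences).
token_bytes = [len(enc.decode_single_token_bytes(t)) for t in range(vocab_size)]

# Special tokens are matched by id and contribute zero bytes to the metric.
special_ids = {enc.encode_single_token(tok) for tok in enc.special_tokens_set}
for t in range(vocab_size):
    if t in special_ids:
        token_bytes[t] = 0

# Same tensor the training script writes to token_bytes.pt.
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
print(token_bytes.shape, int(token_bytes.sum()))
```

The per-token byte counts are what make the bits-per-byte metric vocab-size invariant: the validation loss in nats per token is divided by the number of bytes each token covers, so tokenizers with different vocabularies remain comparable.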