# nanochat/tests/test_rustbpe.py
# 2025-11-03 21:27:12 +01:00

# 636 lines
# 26 KiB
# Python

"""
Comparing the training of:
1. (very slow) Python reference implementation
2. Optimized Python implementation
3. HuggingFace tokenizers training implementation
4. Our own custom RustBPE training implementation
All of these should calculate the same merges and produce
the same vocabulary and tokenizations.
Finally, for inference we will use tiktoken for efficiency.
So we want to make sure we can export our rustbpe tokenizer
into tiktoken and use it for inference with identical results.
Run with:
python -m pytest tests/test_rustbpe.py -v -s
-v is verbose, -s is show prints
"""
import regex as re
from collections import Counter, defaultdict
import time
import rustbpe
import tiktoken
import pytest
# Split pattern used to cut raw text into chunks before any BPE merging happens.
# It is the default pattern of the Python tokenizers in this file and is also
# wrapped in a HuggingFace `Regex` for the HuggingFace pre-tokenizer.
# NOTE: it uses \p{...} Unicode property classes; this file imports the
# `regex` package under the name `re` to compile it.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
def get_stats(ids, counts=None):
    """
    Count how often each pair of adjacent items occurs in `ids`.
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    If `counts` is given, it is updated in place and returned, which lets
    callers accumulate counts over several sequences.
    """
    if counts is None:
        counts = {}
    successors = ids[1:]
    # pair every element with the one right after it
    for left, right in zip(ids, successors):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts
def merge(ids, pair, idx):
    """
    Return a new list in which every non-overlapping, left-to-right
    occurrence of `pair` in `ids` is replaced by the single token `idx`.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    The input list is left untouched.
    """
    merged = []
    total = len(ids)
    cursor = 0
    while cursor < total:
        has_next = cursor + 1 < total
        if has_next and ids[cursor] == pair[0] and ids[cursor + 1] == pair[1]:
            # consume both halves of the pair and emit the merged token
            merged.append(idx)
            cursor += 2
            continue
        merged.append(ids[cursor])
        cursor += 1
    return merged
class RegexTokenizer:
    """
    Slow reference byte-level BPE tokenizer.

    Text is first split into chunks with a regex pattern; merges are then
    learned and applied within each chunk independently.
    """

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)

        Special tokens are not a constructor argument: `special_tokens` and
        `inverse_special_tokens` simply start out empty.
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.merges = {} # (int, int) -> int
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Return the idx -> bytes table derived from `self.merges` and `self.special_tokens`."""
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # relies on merges being stored in creation order, so that both
        # halves of a merge are already present in vocab
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """
        Learn up to `vocab_size - 256` merges from `text`.

        Stores the result in `self.merges` and `self.vocab`, and returns True
        if at any step the most frequent pair was tied with another pair
        (i.e. the merge order depends on tie-breaking), False otherwise.
        Training stops early when no adjacent pair is left to merge.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        # keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
        ambiguous = False
        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)
        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]
        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # every chunk is down to a single token: nothing left to merge.
            # (max() on an empty dict would raise ValueError.)
            if not stats:
                break
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # check if the merge is ambiguous - i.e. the max value is not unique
            pair_count = stats[pair]
            num_top_pairs = sum(1 for count in stats.values() if count == pair_count)
            if num_top_pairs > 1:
                ambiguous = True
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()
        return ambiguous

    def _encode_chunk(self, text_bytes):
        """Encode the bytes of a single chunk into a list of token ids."""
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids
# -----------------------------------------------------------------------------
# Faster Python tokenizer, optimized version of the reference tokenizer
def fast_merge_inplace(ids, pair, idx):
    """
    Replace every left-to-right occurrence of `pair` in the list `ids` with
    the token `idx`, mutating `ids` itself, and return that same list.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    cursor = 0
    while cursor + 1 < len(ids):
        if (ids[cursor], ids[cursor + 1]) == (pair[0], pair[1]):
            # overwrite the left half and drop the right half; the cursor
            # stays put so the freshly written slot is examined again
            ids[cursor] = idx
            del ids[cursor + 1]
            continue
        cursor += 1
    return ids
class FastRegexTokenizer:
    """
    Optimized Python version of the reference regex + byte-level BPE tokenizer.

    Training works on unique chunks weighted by their frequency and updates
    the pair counts incrementally instead of recounting after every merge.
    """

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)

        Special tokens are not a constructor argument; use
        register_special_tokens() to set them.
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.merges = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Return the idx -> bytes table derived from `self.merges` and `self.special_tokens`."""
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """
        Learn up to `vocab_size - 256` merges from `text` and store them in
        `self.merges` / `self.vocab`. Returns None.

        A number of optimizations are introduced:
        - delete function call overhead by inlining functions
        - modifying list of ids in place with .pop() instead of creating a new list
        - collapse identical chunks to just the unique ones
        - update counts more cleverly - only around the affected chunks

        Pairs whose count drops to zero are removed from the bookkeeping, so
        training stops as soon as no real pair is left instead of minting
        merges for pairs that no longer occur anywhere.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)
        # many, many chunks are identical, so we can "collapse" them to just the unique ones
        counts = Counter(text_chunks)
        unique_chunks = list(counts)
        chunk_counts = list(counts.values())
        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        # Initial count: build stats and position tracking
        stats = defaultdict(int)
        positions = defaultdict(set) # pair -> set of chunk indices that contain this pair
        for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
            for pair in zip(chunk_ids, chunk_ids[1:]):
                stats[pair] += count
                positions[pair].add(chunk_idx)
        for i in range(num_merges):
            if not stats:
                break
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            pair_count = stats[pair]
            # mint a new token: assign it the next available id
            idx = 256 + i
            # Get chunks that contain this pair
            affected_chunks = positions[pair]
            # Track count changes for incremental update
            count_changes = defaultdict(int)
            # Replace all occurrences of pair in affected chunks only
            for chunk_idx in affected_chunks:
                chunk_ids = ids[chunk_idx]
                chunk_count = chunk_counts[chunk_idx]
                ix = 0
                while ix < len(chunk_ids) - 1:
                    if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
                        # Track what pairs are being removed/added
                        # Remove: (prev, A), (A, B), (B, next)
                        if ix > 0:
                            old_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[old_left] -= chunk_count
                        # The merged pair disappears
                        count_changes[pair] -= chunk_count
                        if ix + 2 < len(chunk_ids):
                            old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
                            count_changes[old_right] -= chunk_count
                        # Apply the merge
                        chunk_ids[ix] = idx
                        chunk_ids.pop(ix+1)
                        # Add: (prev, C), (C, next)
                        if ix > 0:
                            new_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[new_left] += chunk_count
                        if ix + 1 < len(chunk_ids):
                            new_right = (chunk_ids[ix], chunk_ids[ix+1])
                            count_changes[new_right] += chunk_count
                    else:
                        ix += 1
            # Pairs present in each affected chunk after the merge, computed
            # once per merge rather than once per changed pair
            pairs_after_merge = {
                chunk_idx: set(zip(ids[chunk_idx], ids[chunk_idx][1:]))
                for chunk_idx in affected_chunks
            }
            # Apply incremental changes to stats and positions
            for changed_pair, delta in count_changes.items():
                if changed_pair == pair:
                    # The merged pair should disappear completely
                    continue
                new_count = stats.get(changed_pair, 0) + delta
                if new_count <= 0:
                    # The pair no longer occurs anywhere: forget it, otherwise
                    # a zero-count entry could later be picked by max() and
                    # recorded as a merge that never happens in the text
                    stats.pop(changed_pair, None)
                    positions.pop(changed_pair, None)
                    continue
                stats[changed_pair] = new_count
                # Update positions for changed pairs - only check affected chunks
                pair_positions = positions[changed_pair]
                for chunk_idx, present_pairs in pairs_after_merge.items():
                    if changed_pair in present_pairs:
                        pair_positions.add(chunk_idx)
                    else:
                        pair_positions.discard(chunk_idx)
            # Remove the merged pair completely
            stats.pop(pair, None)
            positions.pop(pair, None)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints (same format as the reference tokenizer)
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {pair_count} occurrences")
        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()

    def register_special_tokens(self, special_tokens):
        """Store the special tokens and the reverse (id -> string) mapping."""
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        """
        Given ids (list of integers), return a Python string.
        Raises ValueError for an id that is neither in the vocab nor a
        registered special token. Invalid UTF-8 is decoded with errors="replace".
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        """Encode the bytes of a single chunk into a list of token ids."""
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = fast_merge_inplace(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids
# -----------------------------------------------------------------------------
# HuggingFace tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
    """Thin convenience wrapper around a HuggingFace `tokenizers` Tokenizer."""

    def __init__(self, tokenizer):
        # the wrapped tokenizer object; only its encode() is used below
        self.tokenizer = tokenizer

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a BPE tokenizer on the texts yielded by `text_iterator` and wrap it."""
        # Model: BPE (byte_fallback is flagged as required by the original author)
        bpe_model = BPE(
            byte_fallback=True,
            unk_token=None,
            fuse_unk=False,
        )
        tokenizer = HFTokenizer(bpe_model)
        # No normalization step
        tokenizer.normalizer = None
        # Pre-tokenization: GPT-4 style split followed by byte-level mapping.
        # The pattern has to be wrapped in HuggingFace's Regex type.
        split_step = pre_tokenizers.Split(
            pattern=Regex(GPT4_SPLIT_PATTERN),
            behavior="isolated",
            invert=False,
        )
        byte_level_step = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([split_step, byte_level_step])
        # The ByteLevel decoder is the counterpart of the ByteLevel pre-tokenizer
        tokenizer.decoder = decoders.ByteLevel()
        # No post-processing step
        tokenizer.post_processor = None
        # Trainer: BPE with no minimum frequency and no special tokens
        trainer_options = {
            "vocab_size": vocab_size,
            "show_progress": True,
            "min_frequency": 0,
            "initial_alphabet": pre_tokenizers.ByteLevel.alphabet(),
            "special_tokens": [],
        }
        trainer = BpeTrainer(**trainer_options)
        # Run the training and hand back the wrapped result
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def encode_ordinary(self, text):
        """Return the token ids for `text`, without adding special tokens."""
        encoding = self.tokenizer.encode(text, add_special_tokens=False)
        return encoding.ids
# -----------------------------------------------------------------------------
# Test all of the above
@pytest.fixture(scope="module")
def enwik8_path():
    """
    Fixture returning the local path of the enwik8 dataset.

    The file is looked up inside the directory returned by
    nanochat.common.get_base_dir(); if it is missing, the zip archive is
    downloaded, extracted there, and the archive is deleted again.
    Raises requests.HTTPError if the download does not succeed.
    """
    import os
    import zipfile
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    # download and unzip enwik8 to .cache directory
    enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
    enwik8_local_path = os.path.join(base_dir, "enwik8")
    enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip")
    if not os.path.exists(enwik8_local_path):
        print(f"Downloading enwik8 to {enwik8_local_path_zip}")
        import requests
        # a timeout keeps a stalled connection from hanging the test run
        response = requests.get(enwik8_url, timeout=60)
        # fail here on an HTTP error instead of saving an error page as the zip
        response.raise_for_status()
        with open(enwik8_local_path_zip, "wb") as f:
            f.write(response.content)
        try:
            with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref:
                zip_ref.extractall(base_dir)
            print(f"Unzipped enwik8 to {enwik8_local_path}")
        finally:
            # remove the archive even when extraction fails, so a corrupt
            # download is not left behind
            os.remove(enwik8_local_path_zip)
            print(f"Removed {enwik8_local_path_zip}")
    else:
        print(f"Using existing enwik8 at {enwik8_local_path}")
    return enwik8_local_path
@pytest.fixture(scope="module")
def enwik8_small(enwik8_path):
    """Fixture providing the first 100,000 characters of enwik8 (roughly 100KB) for quick tests."""
    num_chars = 100_000
    # the file is opened in text mode, so the limit counts characters
    with open(enwik8_path, mode="r", encoding="utf-8") as handle:
        prefix = handle.read(num_chars)
    return prefix
@pytest.fixture(scope="module")
def enwik8_large(enwik8_path):
    """Fixture providing the first 10,000,000 characters of enwik8 (roughly 10MB) for performance tests."""
    num_chars = 10**7
    # the file is opened in text mode, so the limit counts characters
    with open(enwik8_path, mode="r", encoding="utf-8") as handle:
        prefix = handle.read(num_chars)
    return prefix
def time_function(func, *args, **kwargs):
    """
    Call `func(*args, **kwargs)` and measure how long the call takes.

    Returns a `(result, elapsed)` tuple where `elapsed` is in seconds.
    Uses time.perf_counter(), a monotonic high-resolution clock, so the
    measurement is unaffected by system clock adjustments.
    """
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed
def test_correctness(enwik8_small):
    """
    Test that all tokenizer implementations produce the same results.

    Trains the slow reference, the fast Python version, HuggingFace and
    rustbpe on the same text with the same vocab size, encodes that text
    with each of them, and checks the token ids agree. Finally the rustbpe
    tokenizer is exported to tiktoken and checked the same way.
    """
    text = enwik8_small
    encode_text = text
    vocab_size = 256 + 20 # 20 merges
    # Train slow reference
    print("\nTraining slow reference...")
    slow_reference_tokenizer = RegexTokenizer()
    ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
    slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
    print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
    print(slow_reference_ids[:20])
    if ambiguous_flag:
        print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
        print("The implementation could be correct but we might see different results below")
    else:
        print("✅ Merge order is NOT ambiguous")
    # Train fast reference
    print("\nTraining fast reference...")
    fast_reference_tokenizer = FastRegexTokenizer()
    _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
    fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
    print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
    print(fast_reference_ids[:20])
    # Assert fast equals slow
    assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
    print("✅ Fast == Slow")
    # Train HuggingFace
    print("\nTraining HuggingFace...")
    hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
    hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
    print(hf_ids[:20])
    # HuggingFace has a different byte order, so we need custom matching
    def custom_match(ids1, ids2):
        """
        True if ids1 equals ids2 up to a consistent renumbering of the
        byte-level tokens (ids below 256); merged tokens must be identical.
        """
        # zip() would silently ignore a longer tail, so compare lengths first
        if len(ids1) != len(ids2):
            return False
        forward = {} # byte id in ids1 -> byte id in ids2
        backward = {} # byte id in ids2 -> byte id in ids1
        for x, y in zip(ids1, ids2):
            if x >= 256:
                if x != y:
                    return False
                continue
            # a byte token may only correspond to another byte token
            if y >= 256:
                return False
            # the renumbering has to be consistent in both directions (one-to-one)
            if forward.setdefault(x, y) != y or backward.setdefault(y, x) != x:
                return False
        return True
    assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
    print("✅ HuggingFace == Fast")
    # Finally use our own Rust implementation
    print("\nTraining rustbpe...")
    rustbpe_tokenizer = rustbpe.Tokenizer()
    _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
    rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
    print(rustbpe_ids[:20])
    assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference"
    print("✅ RustBPE == Fast")
    # Now export rustbpe to tiktoken for more efficient inference
    print("\nTesting tiktoken export...")
    pattern = rustbpe_tokenizer.get_pattern()
    mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks()
    mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens={},
    )
    tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text)
    print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s")
    print(tiktoken_ids[:20])
    assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE"
    print("✅ Tiktoken == RustBPE")
@pytest.mark.slow
def test_training_performance(enwik8_large):
    """Train rustbpe and HuggingFace on the larger text sample and report how their training times compare."""
    corpus = enwik8_large
    target_vocab_size = 2048
    print(f"\nText length: {len(corpus)}")
    # NOTE: the optimized Python tokenizer (FastRegexTokenizer) is deliberately
    # not timed here; the original author found it far too slow to matter.

    # rustbpe
    print("\nTraining rustbpe...")
    rust_tok = rustbpe.Tokenizer()
    rust_outcome = time_function(rust_tok.train_from_iterator, [corpus], target_vocab_size)
    rustbpe_train_time = rust_outcome[1]
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    assert rustbpe_train_time > 0, "Training should take some time"

    # HuggingFace
    print("\nTraining HuggingFace...")
    hf_outcome = time_function(HuggingFaceTokenizer.train_from_iterator, [corpus], target_vocab_size)
    hf_train_time = hf_outcome[1]
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    assert hf_train_time > 0, "Training should take some time"

    # Side-by-side summary; the speedup is HuggingFace time relative to rustbpe time
    speedup = hf_train_time/rustbpe_train_time
    print(f"\n📊 Performance comparison:")
    print(f" RustBPE: {rustbpe_train_time:.4f}s")
    print(f" HuggingFace: {hf_train_time:.4f}s")
    print(f" Speedup: {speedup:.2f}x")
def test_interface(enwik8_small):
    """Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
    import tempfile
    from nanochat.tokenizer import RustBPETokenizer
    # Simple train test
    vocab_size = 300
    tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
    assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
    print(f"✅ Trained tokenizer with vocab size {vocab_size}")
    # Encode/decode text
    encode_text = "Hello world! How are you? 🙃"
    ids = tok.encode(encode_text)
    print(f"\nInput text: {encode_text}")
    print(f"IDs: {ids}")
    decoded = tok.decode(ids)
    print(f"Decoded: {decoded}")
    assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
    print("✅ Encode/decode test passed")
    # Encode batch test
    ids_new = tok.encode([encode_text, encode_text])
    # all() alone is vacuously true for an empty result, so pin the length too
    assert len(ids_new) == 2, "Batch encoding should return one result per input"
    assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
    print("✅ Encode batch OK")
    # append/prepend functionality
    ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
    bos_token_id = tok.encode_special("<|bos|>")
    assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
    print("✅ append/prepend OK")
    # Save/load test through a temporary directory
    with tempfile.TemporaryDirectory() as tmp_dir:
        tok.save(tmp_dir)
        tok_reloaded = RustBPETokenizer.from_directory(tmp_dir)
        ids_reloaded = tok_reloaded.encode(encode_text)
        assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
        print("✅ Save/load through temporary directory OK")