add rust batch_encode() as a faster option over encode()

Barış Özmen 2025-12-18 19:17:59 +03:00
parent d5759400f9
commit 790f3be65c
2 changed files with 97 additions and 0 deletions


@@ -465,6 +465,22 @@ impl Tokenizer {
        all_ids
    }

    /// Encode multiple texts in parallel using rayon.
    /// Returns a list of token ID vectors, one per input text.
    #[pyo3(signature = (texts))]
    #[pyo3(text_signature = "(self, texts)")]
    pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
        // Release Python GIL and encode in parallel using rayon
        let results = py.allow_threads(|| {
            texts
                .par_iter()
                .map(|text| self.encode(text))
                .collect::<Vec<Vec<u32>>>()
        });
        Ok(results)
    }
}

#[pymodule]
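From Python, the new method takes a list of strings and returns one list of token IDs per input, matching what a sequential encode() loop would produce. A minimal usage sketch, assuming the extension is built and importable as rustbpe (the module name used by the tests below); the placeholder corpus string is illustrative only:

import rustbpe

corpus = "..."  # any reasonably sized training text (placeholder)
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator([corpus], 512)

texts = ["Hello world", "The quick brown fox", ""]
batched = tokenizer.batch_encode(texts)                 # encoded in parallel, GIL released
assert batched == [tokenizer.encode(t) for t in texts]  # same IDs as a sequential loop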


@@ -633,3 +633,84 @@ def test_interface(enwik8_small):
    ids_reloaded = tok_reloaded.encode(encode_text)
    assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
    print("✅ Save/load through temporary directory OK")


def test_batch_encode_correctness(enwik8_small):
    """Quick correctness test for batch_encode()"""
    text = enwik8_small
    vocab_size = 512
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([text], vocab_size)

    # Test with various batch sizes and edge cases
    test_texts = [
        "Hello world",
        "The quick brown fox",
        "jumps over the lazy dog",
        "",   # empty string
        "a",  # single char
    ]

    # Compare batch vs individual encoding
    individual = [tokenizer.encode(t) for t in test_texts]
    batched = tokenizer.batch_encode(test_texts)
    assert individual == batched, "Batch encoding should match individual encoding"
    print("✅ batch_encode() correctness verified")


@pytest.mark.slow
def test_batch_encode_performance(enwik8_large):
    """
    Benchmark batch_encode() vs sequential encode() loop.
    Demonstrates parallelization speedup.
    """
    # Setup
    text = enwik8_large  # 10MB dataset
    vocab_size = 2048

    # Train tokenizer
    print("\nTraining tokenizer...")
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([text], vocab_size)

    # Create test batch: split text into chunks
    chunk_size = 50_000  # ~50KB per chunk
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
    chunks = chunks[:20]  # Use first 20 chunks (~1MB total)

    print(f"\nBatch encoding benchmark:")
    print(f"  Number of texts: {len(chunks)}")
    print(f"  Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")

    # Benchmark 1: Sequential encoding (baseline)
    print("\n  [1/3] Sequential encode() loop...")
    sequential_results, sequential_time = time_function(
        lambda: [tokenizer.encode(chunk) for chunk in chunks]
    )
    print(f"    Time: {sequential_time:.4f}s")

    # Benchmark 2: Parallel batch_encode()
    print("  [2/3] Parallel batch_encode()...")
    batch_results, batch_time = time_function(
        tokenizer.batch_encode, chunks
    )
    print(f"    Time: {batch_time:.4f}s")

    # Verify correctness
    print("  [3/3] Verifying correctness...")
    assert len(batch_results) == len(sequential_results), "Result count mismatch"
    for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
        assert seq == batch, f"Mismatch at index {i}"
    print("    ✓ All results match")

    # Report speedup
    speedup = sequential_time / batch_time
    print(f"\n  Performance Results:")
    print(f"    Sequential: {sequential_time:.4f}s")
    print(f"    Batch:      {batch_time:.4f}s")
    print(f"    Speedup:    {speedup:.2f}x")

    # Assert meaningful speedup (at least 1.5x on multi-core)
    assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x"
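The benchmark calls a time_function helper that is defined elsewhere in the test file and is not part of this diff. Judging from its call sites above, it takes a callable plus optional arguments and returns the result together with the elapsed time in seconds; a minimal sketch of such a helper, given as an assumption rather than the repository's actual implementation:

import time

def time_function(fn, *args, **kwargs):
    # Run fn once and return (result, elapsed_seconds)
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed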