diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs
index 273d7f2..f9c8494 100644
--- a/rustbpe/src/lib.rs
+++ b/rustbpe/src/lib.rs
@@ -465,6 +465,22 @@ impl Tokenizer {
 
         all_ids
     }
+
+    /// Encode multiple texts in parallel using rayon.
+    /// Returns a list of token ID vectors, one per input text.
+    #[pyo3(signature = (texts))]
+    #[pyo3(text_signature = "(self, texts)")]
+    pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
+        // Release Python GIL and encode in parallel using rayon
+        let results = py.allow_threads(|| {
+            texts
+                .par_iter()
+                .map(|text| self.encode(text))
+                .collect::<Vec<Vec<u32>>>()
+        });
+
+        Ok(results)
+    }
 }
 
 #[pymodule]
diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py
index aca67fc..482ea20 100644
--- a/tests/test_rustbpe.py
+++ b/tests/test_rustbpe.py
@@ -633,3 +633,84 @@ def test_interface(enwik8_small):
     ids_reloaded = tok_reloaded.encode(encode_text)
     assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
     print("✅ Save/load through temporary directory OK")
+
+
+def test_batch_encode_correctness(enwik8_small):
+    """Quick correctness test for batch_encode()"""
+    text = enwik8_small
+    vocab_size = 512
+
+    tokenizer = rustbpe.Tokenizer()
+    tokenizer.train_from_iterator([text], vocab_size)
+
+    # Test with various batch sizes and edge cases
+    test_texts = [
+        "Hello world",
+        "The quick brown fox",
+        "jumps over the lazy dog",
+        "",  # empty string
+        "a",  # single char
+    ]
+
+    # Compare batch vs individual encoding
+    individual = [tokenizer.encode(t) for t in test_texts]
+    batched = tokenizer.batch_encode(test_texts)
+
+    assert individual == batched, "Batch encoding should match individual encoding"
+    print("✅ batch_encode() correctness verified")
+
+
+@pytest.mark.slow
+def test_batch_encode_performance(enwik8_large):
+    """
+    Benchmark batch_encode() vs sequential encode() loop.
+    Demonstrates parallelization speedup.
+    """
+    # Setup
+    text = enwik8_large  # 10MB dataset
+    vocab_size = 2048
+
+    # Train tokenizer
+    print("\nTraining tokenizer...")
+    tokenizer = rustbpe.Tokenizer()
+    tokenizer.train_from_iterator([text], vocab_size)
+
+    # Create test batch: split text into chunks
+    chunk_size = 50_000  # ~50KB per chunk
+    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
+    chunks = chunks[:20]  # Use first 20 chunks (~1MB total)
+
+    print(f"\nBatch encoding benchmark:")
+    print(f"  Number of texts: {len(chunks)}")
+    print(f"  Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
+
+    # Benchmark 1: Sequential encoding (baseline)
+    print("\n  [1/3] Sequential encode() loop...")
+    sequential_results, sequential_time = time_function(
+        lambda: [tokenizer.encode(chunk) for chunk in chunks]
+    )
+    print(f"    Time: {sequential_time:.4f}s")
+
+    # Benchmark 2: Parallel batch_encode()
+    print("  [2/3] Parallel batch_encode()...")
+    batch_results, batch_time = time_function(
+        tokenizer.batch_encode, chunks
+    )
+    print(f"    Time: {batch_time:.4f}s")
+
+    # Verify correctness
+    print("  [3/3] Verifying correctness...")
+    assert len(batch_results) == len(sequential_results), "Result count mismatch"
+    for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
+        assert seq == batch, f"Mismatch at index {i}"
+    print("    ✓ All results match")
+
+    # Report speedup
+    speedup = sequential_time / batch_time
+    print(f"\n  Performance Results:")
+    print(f"    Sequential: {sequential_time:.4f}s")
+    print(f"    Batch: {batch_time:.4f}s")
+    print(f"    Speedup: {speedup:.2f}x")
+
+    # Assert meaningful speedup (at least 1.5x on multi-core)
+    assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x"