add rust batch encode as a faster option over encode

parent d5759400f9
commit 790f3be65c
@@ -465,6 +465,22 @@ impl Tokenizer {
         all_ids
     }
 
+    /// Encode multiple texts in parallel using rayon.
+    /// Returns a list of token ID vectors, one per input text.
+    #[pyo3(signature = (texts))]
+    #[pyo3(text_signature = "(self, texts)")]
+    pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
+        // Release Python GIL and encode in parallel using rayon
+        let results = py.allow_threads(|| {
+            texts
+                .par_iter()
+                .map(|text| self.encode(text))
+                .collect::<Vec<Vec<u32>>>()
+        });
+
+        Ok(results)
+    }
 }
 
 #[pymodule]
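For orientation, here is a minimal sketch of how the new binding is meant to be called from Python. The `rustbpe.Tokenizer`, `train_from_iterator`, `encode`, and `batch_encode` names come from this diff; the sample texts and vocab size are illustrative, and the sketch assumes the compiled extension module is importable.

import rustbpe

# Train a small tokenizer (vocab size is arbitrary for this sketch).
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator(["some training text"], 512)

texts = ["Hello world", "The quick brown fox"]

# Sequential baseline: one FFI call per text.
ids_seq = [tokenizer.encode(t) for t in texts]

# Parallel path: a single call that releases the GIL and fans the texts
# out across rayon's worker threads (one per logical core by default).
ids_batch = tokenizer.batch_encode(texts)

assert ids_seq == ids_batch  # same token IDs, produced in parallel

Releasing the GIL via `py.allow_threads` also lets other Python threads make progress while the rayon pool does the pure-Rust encoding work.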
@@ -633,3 +633,84 @@ def test_interface(enwik8_small):
     ids_reloaded = tok_reloaded.encode(encode_text)
     assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
     print("✅ Save/load through temporary directory OK")
+
+
+def test_batch_encode_correctness(enwik8_small):
+    """Quick correctness test for batch_encode()"""
+    text = enwik8_small
+    vocab_size = 512
+
+    tokenizer = rustbpe.Tokenizer()
+    tokenizer.train_from_iterator([text], vocab_size)
+
+    # Test with various batch sizes and edge cases
+    test_texts = [
+        "Hello world",
+        "The quick brown fox",
+        "jumps over the lazy dog",
+        "",  # empty string
+        "a",  # single char
+    ]
+
+    # Compare batch vs individual encoding
+    individual = [tokenizer.encode(t) for t in test_texts]
+    batched = tokenizer.batch_encode(test_texts)
+
+    assert individual == batched, "Batch encoding should match individual encoding"
+    print("✅ batch_encode() correctness verified")
+
+
+@pytest.mark.slow
+def test_batch_encode_performance(enwik8_large):
+    """
+    Benchmark batch_encode() vs sequential encode() loop.
+    Demonstrates parallelization speedup.
+    """
+    # Setup
+    text = enwik8_large  # 10MB dataset
+    vocab_size = 2048
+
+    # Train tokenizer
+    print("\nTraining tokenizer...")
+    tokenizer = rustbpe.Tokenizer()
+    tokenizer.train_from_iterator([text], vocab_size)
+
+    # Create test batch: split text into chunks
+    chunk_size = 50_000  # ~50KB per chunk
+    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
+    chunks = chunks[:20]  # Use first 20 chunks (~1MB total)
+
+    print(f"\nBatch encoding benchmark:")
+    print(f"  Number of texts: {len(chunks)}")
+    print(f"  Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
+
+    # Benchmark 1: Sequential encoding (baseline)
+    print("\n  [1/3] Sequential encode() loop...")
+    sequential_results, sequential_time = time_function(
+        lambda: [tokenizer.encode(chunk) for chunk in chunks]
+    )
+    print(f"    Time: {sequential_time:.4f}s")
+
+    # Benchmark 2: Parallel batch_encode()
+    print("  [2/3] Parallel batch_encode()...")
+    batch_results, batch_time = time_function(
+        tokenizer.batch_encode, chunks
+    )
+    print(f"    Time: {batch_time:.4f}s")
+
+    # Verify correctness
+    print("  [3/3] Verifying correctness...")
+    assert len(batch_results) == len(sequential_results), "Result count mismatch"
+    for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
+        assert seq == batch, f"Mismatch at index {i}"
+    print("    ✓ All results match")
+
+    # Report speedup
+    speedup = sequential_time / batch_time
+    print(f"\n  Performance Results:")
+    print(f"    Sequential: {sequential_time:.4f}s")
+    print(f"    Batch:      {batch_time:.4f}s")
+    print(f"    Speedup:    {speedup:.2f}x")
+
+    # Assert meaningful speedup (at least 1.5x on multi-core)
+    assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x"
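One note on the benchmark: `time_function` is called but not defined in this hunk, so it presumably lives elsewhere in the test file. A plausible reconstruction, matching the call sites above (it takes a callable plus its arguments and returns the callable's result together with the elapsed wall-clock time), might look like:

import time

def time_function(fn, *args, **kwargs):
    # Hypothetical helper: the real one is defined elsewhere in the test file.
    # perf_counter() is a monotonic clock well suited to short benchmarks.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed

The 1.5x threshold at the end is a deliberately loose bound: the exact speedup depends on core count and chunk sizes, so the test asserts only that parallelism helps at all rather than pinning a specific ratio.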