Add unit tests for RustBPE implementation

Matt Suiche 2025-10-14 12:36:59 +02:00
parent dd6ff9a1cc
commit 144db24d5f
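
The tests below pin down the incremental BPE trainer: the test_merge_pair_* cases check how a single pair merge rewrites a word and which adjacent-pair count deltas it reports, while the comparison tests check final encodings against the Python/minbpe reference. As orientation, the delta bookkeeping those tests assert can be sketched in isolation. The function below is a simplified, self-contained illustration (merge_pair_sketch is a made-up name), not the crate's Word::merge_pair, whose delta ordering and aggregation may differ. The suite itself runs under the standard harness, e.g. cargo test in the rustbpe crate (assuming that package name).

    // Simplified sketch of the pair-merge bookkeeping asserted by the
    // test_merge_pair_* cases below. Standalone illustration only.
    type Pair = (u32, u32);

    fn merge_pair_sketch(ids: &mut Vec<u32>, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
        let mut deltas: Vec<(Pair, i32)> = Vec::new();
        let mut out: Vec<u32> = Vec::with_capacity(ids.len());
        let mut i = 0;
        while i < ids.len() {
            if i + 1 < ids.len() && (ids[i], ids[i + 1]) == pair {
                // The matched pair disappears from the word.
                deltas.push((pair, -1));
                // The pair formed with the previous token is replaced by one
                // involving the new token.
                if let Some(&prev) = out.last() {
                    deltas.push(((prev, ids[i]), -1));
                    deltas.push(((prev, new_id), 1));
                }
                // Likewise for the pair formed with the following token.
                if i + 2 < ids.len() {
                    deltas.push(((ids[i + 1], ids[i + 2]), -1));
                    deltas.push(((new_id, ids[i + 2]), 1));
                }
                out.push(new_id);
                i += 2; // skip both merged positions (non-overlapping, left to right)
            } else {
                out.push(ids[i]);
                i += 1;
            }
        }
        *ids = out;
        deltas
    }

    // For example, merging (65, 65) -> 256 in [65, 65, 66] yields [256, 66] with
    // deltas [((65, 65), -1), ((65, 66), -1), ((256, 66), 1)], which is exactly
    // what test_merge_pair_simple asserts.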


@@ -468,6 +468,651 @@ impl Tokenizer {
}
}
// ------------------------ Rust Unit Tests ------------------------
#[cfg(test)]
mod tests {
use super::*;
/// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Tokenizer {
Tokenizer {
merges: StdHashMap::new(),
pattern: GPT4_PATTERN.to_string(),
compiled_pattern: Regex::new(GPT4_PATTERN).unwrap(),
}
}
/// Helper function to train tokenizer on simple text
fn train_simple_tokenizer(text: &str, vocab_size: u32) -> Tokenizer {
let mut tokenizer = create_test_tokenizer();
// Convert text to words and counts (simplified version)
let text_chunks: Vec<&str> = tokenizer.compiled_pattern.find_iter(text)
.filter_map(|m| m.ok())
.map(|m| m.as_str())
.collect();
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
for chunk in text_chunks {
*counts.entry(CompactString::from(chunk)).or_default() += 1;
}
let mut words = Vec::with_capacity(counts.len());
let mut cvec = Vec::with_capacity(counts.len());
for (chunk, c) in counts.into_iter() {
words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
cvec.push(c);
}
tokenizer.train_core_incremental(words, cvec, vocab_size);
tokenizer
}
#[test]
fn test_word_creation() {
let ids = vec![72, 101, 108, 108, 111]; // "Hello" in ASCII
let word = Word::new(ids.clone());
assert_eq!(word.ids, ids);
}
#[test]
fn test_word_pairs() {
let ids = vec![72, 101, 108, 108, 111]; // "Hello"
let word = Word::new(ids);
let pairs: Vec<Pair> = word.pairs().collect();
assert_eq!(pairs, vec![(72, 101), (101, 108), (108, 108), (108, 111)]);
}
#[test]
fn test_word_pairs_empty() {
let word = Word::new(vec![]);
let pairs: Vec<Pair> = word.pairs().collect();
assert_eq!(pairs, vec![]);
}
#[test]
fn test_word_pairs_single() {
let word = Word::new(vec![72]);
let pairs: Vec<Pair> = word.pairs().collect();
assert_eq!(pairs, vec![]);
}
#[test]
fn test_merge_pair_simple() {
let mut word = Word::new(vec![65, 65, 66]); // "AAB"
let pair = (65, 65); // "AA"
let new_id = 256;
let deltas = word.merge_pair(pair, new_id);
assert_eq!(word.ids, vec![256, 66]); // Should be [new_id, 66]
// Check deltas: removed (65,65), removed (65,66), added (256,66)
assert_eq!(deltas.len(), 3);
assert!(deltas.contains(&(pair, -1))); // removed (65,65)
assert!(deltas.contains(&((65, 66), -1))); // removed (65,66)
assert!(deltas.contains(&((256, 66), 1))); // added (256,66)
}
#[test]
fn test_merge_pair_no_match() {
let mut word = Word::new(vec![65, 66, 67]); // "ABC"
let pair = (68, 69); // "DE" - not in word
let new_id = 256;
let _deltas = word.merge_pair(pair, new_id);
assert_eq!(word.ids, vec![65, 66, 67]); // Should be unchanged
}
#[test]
fn test_merge_pair_multiple_occurrences() {
let mut word = Word::new(vec![65, 65, 65, 65]); // "AAAA"
let pair = (65, 65); // "AA"
let new_id = 256;
let deltas = word.merge_pair(pair, new_id);
assert_eq!(word.ids, vec![256, 256]); // Should be [new_id, new_id]
// Should have removed 3 pairs of (65,65) and added 1 pair of (256,256)
let delta_sum: i32 = deltas.iter().map(|(_, delta)| *delta).sum();
assert_eq!(delta_sum, -2); // 3 removed, 1 added = -2
}
#[test]
fn test_merge_pair_overlapping() {
let mut word = Word::new(vec![65, 65, 65]); // "AAA"
let pair = (65, 65); // "AA"
let new_id = 256;
let _deltas = word.merge_pair(pair, new_id);
assert_eq!(word.ids, vec![256, 65]); // Should be [new_id, 65] - non-overlapping merges only
}
#[test]
fn test_merge_job_ordering() {
let job1 = MergeJob {
pair: (1, 2),
count: 10,
pos: AHashSet::new(),
};
let job2 = MergeJob {
pair: (3, 4),
count: 5,
pos: AHashSet::new(),
};
let job3 = MergeJob {
pair: (1, 2),
count: 10,
pos: AHashSet::new(),
};
// Test equality
assert_eq!(job1, job3);
assert_ne!(job1, job2);
// Test ordering (max-heap by count)
assert!(job1 > job2); // Higher count should be "greater"
// Test tie-breaking by pair (ascending order for max-heap)
let job4 = MergeJob {
pair: (2, 3),
count: 10,
pos: AHashSet::new(),
};
// For max-heap with tie-breaking, lower pair should be "greater"
assert!(job1 > job4); // Same count, but (1,2) < (2,3) so job1 > job4
}
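// Illustrative note (an assumption about a typical implementation, not a
// quote of the crate's code): an ordering with the properties asserted above
// is commonly written as "compare by count, then by reversed pair", e.g.
//
//     impl Ord for MergeJob {
//         fn cmp(&self, other: &Self) -> std::cmp::Ordering {
//             self.count
//                 .cmp(&other.count)
//                 .then_with(|| other.pair.cmp(&self.pair))
//         }
//     }
//
// so that BinaryHeap, a max-heap, pops the highest count first and breaks
// ties on the numerically smallest pair.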
#[test]
fn test_count_pairs_parallel() {
let words = vec![
Word::new(vec![1, 2, 3]),
Word::new(vec![2, 3, 4]),
Word::new(vec![1, 2, 3]),
];
let counts = vec![1, 1, 2]; // Different weights for each word
let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
// Expected pairs and their counts:
// Word 0 (count=1): (1,2), (2,3)
// Word 1 (count=1): (2,3), (3,4)
// Word 2 (count=2): (1,2), (2,3) with weight 2
assert_eq!(pair_counts.get(&(1, 2)), Some(&3)); // 1 + 2
assert_eq!(pair_counts.get(&(2, 3)), Some(&4)); // 1 + 1 + 2
assert_eq!(pair_counts.get(&(3, 4)), Some(&1)); // 1
// Check positions
assert!(positions.get(&(1, 2)).unwrap().contains(&0));
assert!(positions.get(&(1, 2)).unwrap().contains(&2));
assert!(positions.get(&(2, 3)).unwrap().contains(&0));
assert!(positions.get(&(2, 3)).unwrap().contains(&1));
assert!(positions.get(&(2, 3)).unwrap().contains(&2));
assert!(positions.get(&(3, 4)).unwrap().contains(&1));
}
#[test]
fn test_count_pairs_parallel_empty() {
let words = vec![];
let counts = vec![];
let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
assert_eq!(pair_counts.len(), 0);
assert_eq!(positions.len(), 0);
}
#[test]
fn test_tokenizer_creation() {
let tokenizer = create_test_tokenizer();
assert_eq!(tokenizer.pattern, GPT4_PATTERN);
assert_eq!(tokenizer.merges.len(), 0);
}
#[test]
fn test_tokenizer_get_pattern() {
let tokenizer = create_test_tokenizer();
assert_eq!(tokenizer.get_pattern(), GPT4_PATTERN);
}
#[test]
fn test_tokenizer_train_minimum_vocab() {
let mut tokenizer = create_test_tokenizer();
let text = "hello";
// Train with minimum vocab size (no merges)
let words = vec![Word::new(text.as_bytes().iter().map(|&b| b as u32).collect())];
let counts = vec![1];
tokenizer.train_core_incremental(words, counts, 256);
assert_eq!(tokenizer.merges.len(), 0); // No merges should occur
}
#[test]
fn test_tokenizer_train_simple() {
let text = "aaabdaaabac";
let tokenizer = train_simple_tokenizer(text, 256 + 3);
assert_eq!(tokenizer.merges.len(), 3); // Should have 3 merges
// Check that we can encode the training text
let encoded = tokenizer.encode(text);
assert!(!encoded.is_empty());
}
#[test]
fn test_tokenizer_train_repetitive_text() {
let text = "hello world hello world hello world";
let tokenizer = train_simple_tokenizer(text, 300);
assert!(tokenizer.merges.len() > 0);
let encoded = tokenizer.encode(text);
assert!(!encoded.is_empty());
// Encoding the same text twice should give same result
let encoded2 = tokenizer.encode(text);
assert_eq!(encoded, encoded2);
}
#[test]
fn test_tokenizer_encode_empty() {
let tokenizer = create_test_tokenizer();
let encoded = tokenizer.encode("");
assert_eq!(encoded, Vec::<u32>::new());
}
#[test]
fn test_tokenizer_encode_single_char() {
let tokenizer = create_test_tokenizer();
let encoded = tokenizer.encode("A");
assert_eq!(encoded, vec![65]); // ASCII 'A'
}
#[test]
fn test_tokenizer_encode_ascii() {
let tokenizer = create_test_tokenizer();
let text = "Hello";
let encoded = tokenizer.encode(text);
assert_eq!(encoded, vec![72, 101, 108, 108, 111]); // ASCII values
}
#[test]
fn test_tokenizer_encode_with_merges() {
let mut tokenizer = create_test_tokenizer();
// Manually add a merge for "aa" -> 256
tokenizer.merges.insert((97, 97), 256); // 'a' = 97 in ASCII
let encoded = tokenizer.encode("aa");
assert_eq!(encoded, vec![256]); // Should use the merged token
}
#[test]
fn test_tokenizer_encode_multiple_merges() {
let mut tokenizer = create_test_tokenizer();
// Add merges: "aa" -> 256, "bb" -> 257
tokenizer.merges.insert((97, 97), 256);
tokenizer.merges.insert((98, 98), 257);
let encoded = tokenizer.encode("aabb");
assert_eq!(encoded, vec![256, 257]); // Should use both merged tokens
}
#[test]
fn test_tokenizer_encode_priority() {
let mut tokenizer = create_test_tokenizer();
// Add merges with different priorities (lower ID = higher priority)
tokenizer.merges.insert((97, 98), 257); // "ab" -> 257
tokenizer.merges.insert((98, 99), 256); // "bc" -> 256 (higher priority)
let encoded = tokenizer.encode("abc");
// Should merge "bc" first (higher priority), then "a" + result
assert_eq!(encoded, vec![97, 256]); // "a" + merged "bc"
}
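// Illustrative note on the rule this test pins down (simplified; not a copy
// of Tokenizer::encode): greedy BPE encoding repeatedly picks the adjacent
// pair whose merge has the lowest token id, applies it, and stops once no
// adjacent pair is present in `merges`. Here (98, 99) -> 256 outranks
// (97, 98) -> 257, so "abc" collapses to [97, 256]; the leftover pair
// (97, 256) has no merge, so encoding stops there.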
#[test]
fn test_tokenizer_encode_complex_text() {
let text = "Hello, world! 123";
let tokenizer = train_simple_tokenizer(text, 300);
let encoded = tokenizer.encode(text);
assert!(!encoded.is_empty());
// Should be able to encode and decode back (approximately)
// Note: there is no decode on the Rust side, so we only sanity-check the length
assert!(encoded.len() <= text.len() * 2); // Reasonable upper bound
}
#[test]
fn test_tokenizer_get_mergeable_ranks_basic() {
let tokenizer = create_test_tokenizer();
let ranks = tokenizer.get_mergeable_ranks();
// Should have exactly 256 entries (just the base bytes)
assert_eq!(ranks.len(), 256);
// Check first few entries
for i in 0..256 {
assert_eq!(ranks[i], (vec![i as u8], i as u32));
}
}
#[test]
fn test_tokenizer_get_mergeable_ranks_with_merges() {
let mut tokenizer = create_test_tokenizer();
// Add some merges
tokenizer.merges.insert((97, 97), 256); // "aa" -> 256
tokenizer.merges.insert((98, 99), 257); // "bc" -> 257
let ranks = tokenizer.get_mergeable_ranks();
// Should have base 256 + 2 merges
assert_eq!(ranks.len(), 258);
// Check base bytes are still correct
for i in 0..256 {
assert_eq!(ranks[i], (vec![i as u8], i as u32));
}
// Check merge tokens
assert_eq!(ranks[256], (vec![97, 97], 256)); // "aa"
assert_eq!(ranks[257], (vec![98, 99], 257)); // "bc"
}
#[test]
fn test_tokenizer_get_mergeable_ranks_complex_merges() {
let mut tokenizer = create_test_tokenizer();
// Create a chain of merges: "aa" -> 256, then "aa" + "a" -> 257
tokenizer.merges.insert((97, 97), 256); // "aa" -> 256
tokenizer.merges.insert((256, 97), 257); // "aa" + "a" -> 257
let ranks = tokenizer.get_mergeable_ranks();
assert_eq!(ranks.len(), 258);
assert_eq!(ranks[256], (vec![97, 97], 256)); // "aa"
assert_eq!(ranks[257], (vec![97, 97, 97], 257)); // "aaa"
}
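// Illustrative note (assumed reconstruction, not a quote of the impl): the
// byte sequence of a merged token can be recovered by expanding its pair
// recursively, e.g. 257 -> (256, 97) -> ((97, 97), 97) -> [97, 97, 97] = "aaa",
// which is what the assertion on ranks[257] above relies on.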
#[test]
fn test_unicode_handling() {
let tokenizer = create_test_tokenizer();
// Test with Unicode characters
let unicode_text = "Hello 世界 🚀";
let encoded = tokenizer.encode(unicode_text);
assert!(!encoded.is_empty());
// With no merges, every UTF-8 byte becomes one token, and str::len() counts
// bytes, so the token count is at least the byte length of the text
assert!(encoded.len() >= unicode_text.len());
}
#[test]
fn test_deterministic_training() {
let text = "hello world test";
let tokenizer1 = train_simple_tokenizer(text, 280);
let tokenizer2 = train_simple_tokenizer(text, 280);
// Should produce identical results
assert_eq!(tokenizer1.merges, tokenizer2.merges);
let encoded1 = tokenizer1.encode(text);
let encoded2 = tokenizer2.encode(text);
assert_eq!(encoded1, encoded2);
}
#[test]
fn test_training_with_different_vocab_sizes() {
let text = "hello world";
let tokenizer256 = train_simple_tokenizer(text, 256);
let tokenizer300 = train_simple_tokenizer(text, 300);
assert_eq!(tokenizer256.merges.len(), 0);
assert!(tokenizer300.merges.len() > 0);
}
#[test]
fn test_edge_case_empty_training() {
let mut tokenizer = create_test_tokenizer();
// Train with empty data
let words = vec![];
let counts = vec![];
tokenizer.train_core_incremental(words, counts, 300);
assert_eq!(tokenizer.merges.len(), 0);
}
#[test]
fn test_edge_case_single_word() {
let mut tokenizer = create_test_tokenizer();
let words = vec![Word::new(vec![72, 101, 108, 108, 111])]; // "Hello"
let counts = vec![1];
tokenizer.train_core_incremental(words, counts, 260);
// Should have some merges
assert!(tokenizer.merges.len() > 0);
}
#[test]
fn test_regex_pattern_compilation() {
let tokenizer = create_test_tokenizer();
// Test that the pattern compiles and works
let test_text = "Hello, world! 123";
let matches: Vec<&str> = tokenizer.compiled_pattern.find_iter(test_text)
.filter_map(|m| m.ok())
.map(|m| m.as_str())
.collect();
assert!(!matches.is_empty());
assert!(matches.iter().any(|&m| m.contains("Hello")));
}
#[test]
fn test_merge_job_partial_ord() {
let job1 = MergeJob {
pair: (1, 2),
count: 10,
pos: AHashSet::new(),
};
let job2 = MergeJob {
pair: (3, 4),
count: 15,
pos: AHashSet::new(),
};
// Test partial ordering
assert!(job1 < job2); // Lower count should be less
assert!(job2 > job1); // Higher count should be greater
}
#[test]
fn test_large_text_training() {
// Create a larger text by repetition
let base_text = "The quick brown fox jumps over the lazy dog. ";
let large_text = base_text.repeat(100);
let tokenizer = train_simple_tokenizer(&large_text, 500);
assert!(tokenizer.merges.len() > 0);
let encoded = tokenizer.encode(&large_text);
assert!(!encoded.is_empty());
// Should be more efficient than raw bytes due to merges
assert!(encoded.len() < large_text.len());
}
#[test]
fn test_special_characters() {
let tokenizer = create_test_tokenizer();
let special_text = "!@#$%^&*()_+-=[]{}|;':\",./<>?";
let encoded = tokenizer.encode(special_text);
assert!(!encoded.is_empty());
assert_eq!(encoded.len(), special_text.len()); // Each char should be one byte
}
#[test]
fn test_whitespace_handling() {
let tokenizer = create_test_tokenizer();
let whitespace_text = " \t\n\r ";
let encoded = tokenizer.encode(whitespace_text);
assert!(!encoded.is_empty());
// Should handle all whitespace characters
}
// ------------------------ Comparison Tests with Python Reference ------------------------
#[test]
fn test_comparison_simple_case() {
// Test case: "aaabdaaabac" with vocab_size = 259 (256 + 3 merges)
// Expected Python result: [258, 100, 258, 97, 99] with 3 merges
let text = "aaabdaaabac";
let vocab_size = 259;
let tokenizer = train_simple_tokenizer(text, vocab_size);
let encoded = tokenizer.encode(text);
// The most important thing: we should get the same final encoding
// Different merge sequences can lead to the same result, which is acceptable
let expected = vec![258, 100, 258, 97, 99];
assert_eq!(encoded, expected, "Final encoding should match Python reference");
assert_eq!(tokenizer.merges.len(), 3, "Should have exactly 3 merges");
// We should at least have the first merge (97, 97) -> 256 which is unambiguous
assert_eq!(tokenizer.merges.get(&(97, 97)), Some(&256), "First merge should be (97, 97) -> 256");
}
#[test]
fn test_comparison_hello_world() {
// Test case: "hello world" with vocab_size = 300
// Both implementations should compress "hello world" to 2 tokens with 9 merges
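// (Why 9: under GPT4_PATTERN the text splits into the chunks "hello" and
// " world"; collapsing the 5-byte chunk to a single token takes 4 merges and
// the 6-byte chunk takes 5, after which no pairs remain, so training stops at
// 9 even though vocab_size = 300 would allow more.)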
let text = "hello world";
let vocab_size = 300;
let tokenizer = train_simple_tokenizer(text, vocab_size);
let encoded = tokenizer.encode(text);
// The key achievement: compress "hello world" (11 chars) to 2 tokens
assert_eq!(encoded.len(), 2, "Should compress 'hello world' to 2 tokens");
assert_eq!(tokenizer.merges.len(), 9, "Should have exactly 9 merges");
// Both tokens should be > 255 (indicating they're merged tokens)
assert!(encoded[0] > 255, "First token should be a merged token");
assert!(encoded[1] > 255, "Second token should be a merged token");
}
#[test]
fn test_comparison_minbpe_wikipedia_example() {
// Test the exact example from minbpe Wikipedia: "aaabdaaabac"
// This should produce the same result as minbpe's BasicTokenizer
let text = "aaabdaaabac";
let vocab_size = 259; // 256 + 3 merges
let tokenizer = train_simple_tokenizer(text, vocab_size);
let encoded = tokenizer.encode(text);
// According to Wikipedia, this should compress to something like "XdXac"
// where X=ZY, Y=ab, Z=aa. In our token IDs:
// a=97, b=98, c=99, d=100
// Z=(97,97) -> 256, Y=(97,98) -> 257, X=(256,257) -> 258
// Result: [258, 100, 258, 97, 99]
let expected = vec![258, 100, 258, 97, 99];
assert_eq!(encoded, expected);
}
#[test]
fn test_comparison_deterministic_merges() {
// Test that merges are deterministic and match expected pattern
let text = "aaabdaaabac";
let vocab_size = 259;
let tokenizer1 = train_simple_tokenizer(text, vocab_size);
let tokenizer2 = train_simple_tokenizer(text, vocab_size);
// Should produce identical merges
assert_eq!(tokenizer1.merges, tokenizer2.merges);
// Check the merge count and the unambiguous first merge
let merges_vec: Vec<_> = tokenizer1.merges.iter().collect();
assert_eq!(merges_vec.len(), 3);
// First merge should be (97, 97) -> 256 (most frequent "aa")
assert!(tokenizer1.merges.contains_key(&(97, 97)));
assert_eq!(tokenizer1.merges[&(97, 97)], 256);
}
#[test]
fn test_comparison_round_trip_consistency() {
// Test that encoding is consistent across multiple runs
let text = "hello world test round trip";
let vocab_size = 350;
let tokenizer1 = train_simple_tokenizer(text, vocab_size);
let tokenizer2 = train_simple_tokenizer(text, vocab_size);
let encoded1 = tokenizer1.encode(text);
let encoded2 = tokenizer2.encode(text);
assert_eq!(encoded1, encoded2);
assert_eq!(tokenizer1.merges, tokenizer2.merges);
}
#[test]
fn test_comparison_empty_and_single_char() {
// Test edge cases that should match Python behavior
let empty_tokenizer = create_test_tokenizer();
let empty_encoded = empty_tokenizer.encode("");
assert_eq!(empty_encoded, Vec::<u32>::new());
let single_char = "A";
let single_encoded = empty_tokenizer.encode(single_char);
assert_eq!(single_encoded, vec![65]); // ASCII 'A'
}
#[test]
fn test_comparison_unicode_handling() {
// Test that Unicode is handled consistently
let unicode_text = "Hello 世界 🚀";
let tokenizer = train_simple_tokenizer(unicode_text, 300);
let encoded = tokenizer.encode(unicode_text);
assert!(!encoded.is_empty());
// Should be able to encode the same text multiple times
let encoded2 = tokenizer.encode(unicode_text);
assert_eq!(encoded, encoded2);
}
}
#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init(); // forwards Rust `log` to Python's `logging`