diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs
index b43fb6c..04c2298 100644
--- a/rustbpe/src/lib.rs
+++ b/rustbpe/src/lib.rs
@@ -468,6 +468,651 @@ impl Tokenizer {
     }
 }
+// ------------------------ Rust Unit Tests ------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Helper function to create a simple tokenizer for testing
+    fn create_test_tokenizer() -> Tokenizer {
+        Tokenizer {
+            merges: StdHashMap::new(),
+            pattern: GPT4_PATTERN.to_string(),
+            compiled_pattern: Regex::new(GPT4_PATTERN).unwrap(),
+        }
+    }
+
+    /// Helper function to train a tokenizer on simple text
+    fn train_simple_tokenizer(text: &str, vocab_size: u32) -> Tokenizer {
+        let mut tokenizer = create_test_tokenizer();
+
+        // Convert text to words and counts (simplified version)
+        let text_chunks: Vec<&str> = tokenizer.compiled_pattern.find_iter(text)
+            .filter_map(|m| m.ok())
+            .map(|m| m.as_str())
+            .collect();
+
+        // Count chunk occurrences; the count value type is inferred from train_core_incremental
+        let mut counts: AHashMap<CompactString, _> = AHashMap::new();
+        for chunk in text_chunks {
+            *counts.entry(CompactString::from(chunk)).or_default() += 1;
+        }
+
+        let mut words = Vec::with_capacity(counts.len());
+        let mut cvec = Vec::with_capacity(counts.len());
+        for (chunk, c) in counts.into_iter() {
+            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+            cvec.push(c);
+        }
+
+        tokenizer.train_core_incremental(words, cvec, vocab_size);
+        tokenizer
+    }
+
+    #[test]
+    fn test_word_creation() {
+        let ids = vec![72, 101, 108, 108, 111]; // "Hello" in ASCII
+        let word = Word::new(ids.clone());
+
+        assert_eq!(word.ids, ids);
+    }
+
+    #[test]
+    fn test_word_pairs() {
+        let ids = vec![72, 101, 108, 108, 111]; // "Hello"
+        let word = Word::new(ids);
+
+        let pairs: Vec<(u32, u32)> = word.pairs().collect();
+        assert_eq!(pairs, vec![(72, 101), (101, 108), (108, 108), (108, 111)]);
+    }
+
+    #[test]
+    fn test_word_pairs_empty() {
+        let word = Word::new(vec![]);
+        let pairs: Vec<(u32, u32)> = word.pairs().collect();
+        assert_eq!(pairs, vec![]);
+    }
+
+    #[test]
+    fn test_word_pairs_single() {
+        let word = Word::new(vec![72]);
+        let pairs: Vec<(u32, u32)> = word.pairs().collect();
+        assert_eq!(pairs, vec![]);
+    }
+
+    #[test]
+    fn test_merge_pair_simple() {
+        let mut word = Word::new(vec![65, 65, 66]); // "AAB"
+        let pair = (65, 65); // "AA"
+        let new_id = 256;
+
+        let deltas = word.merge_pair(pair, new_id);
+
+        assert_eq!(word.ids, vec![256, 66]); // Should be [new_id, 66]
+
+        // Check deltas: removed (65,65), removed (65,66), added (256,66)
+        assert_eq!(deltas.len(), 3);
+        assert!(deltas.contains(&(pair, -1))); // removed (65,65)
+        assert!(deltas.contains(&((65, 66), -1))); // removed (65,66)
+        assert!(deltas.contains(&((256, 66), 1))); // added (256,66)
+    }
+
+    #[test]
+    fn test_merge_pair_no_match() {
+        let mut word = Word::new(vec![65, 66, 67]); // "ABC"
+        let pair = (68, 69); // "DE" - not in word
+        let new_id = 256;
+
+        let _deltas = word.merge_pair(pair, new_id);
+
+        assert_eq!(word.ids, vec![65, 66, 67]); // Should be unchanged
+    }
+
+    #[test]
+    fn test_merge_pair_multiple_occurrences() {
+        let mut word = Word::new(vec![65, 65, 65, 65]); // "AAAA"
+        let pair = (65, 65); // "AA"
+        let new_id = 256;
+
+        let deltas = word.merge_pair(pair, new_id);
+
+        assert_eq!(word.ids, vec![256, 256]); // Should be [new_id, new_id]
+
+        // Should have removed 3 pairs of (65,65) and added 1 pair of (256,256)
+        let delta_sum: i32 = deltas.iter().map(|(_, delta)| *delta).sum();
+        assert_eq!(delta_sum, -2); // 3 removed, 1 added = -2
+    }
+
+    #[test]
+    fn test_merge_pair_overlapping() {
+        let mut word = Word::new(vec![65, 65, 65]); // "AAA"
+        let pair = (65, 65); // "AA"
+        let new_id = 256;
+
+        let _deltas = word.merge_pair(pair, new_id);
+
+        assert_eq!(word.ids, vec![256, 65]); // Should be [new_id, 65] - non-overlapping merges only
+    }
+
+    #[test]
+    fn test_merge_job_ordering() {
+        let job1 = MergeJob {
+            pair: (1, 2),
+            count: 10,
+            pos: AHashSet::new(),
+        };
+
+        let job2 = MergeJob {
+            pair: (3, 4),
+            count: 5,
+            pos: AHashSet::new(),
+        };
+
+        let job3 = MergeJob {
+            pair: (1, 2),
+            count: 10,
+            pos: AHashSet::new(),
+        };
+
+        // Test equality
+        assert_eq!(job1, job3);
+        assert_ne!(job1, job2);
+
+        // Test ordering (max-heap by count)
+        assert!(job1 > job2); // Higher count should be "greater"
+
+        // Test tie-breaking by pair (ascending order for max-heap)
+        let job4 = MergeJob {
+            pair: (2, 3),
+            count: 10,
+            pos: AHashSet::new(),
+        };
+
+        // For max-heap with tie-breaking, lower pair should be "greater"
+        assert!(job1 > job4); // Same count, but (1,2) < (2,3) so job1 > job4
+    }
+
+    #[test]
+    fn test_count_pairs_parallel() {
+        let words = vec![
+            Word::new(vec![1, 2, 3]),
+            Word::new(vec![2, 3, 4]),
+            Word::new(vec![1, 2, 3]),
+        ];
+        let counts = vec![1, 1, 2]; // Different weights for each word
+
+        let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
+
+        // Expected pairs and their counts:
+        // Word 0 (count=1): (1,2), (2,3)
+        // Word 1 (count=1): (2,3), (3,4)
+        // Word 2 (count=2): (1,2), (2,3) with weight 2
+
+        assert_eq!(pair_counts.get(&(1, 2)), Some(&3)); // 1 + 2
+        assert_eq!(pair_counts.get(&(2, 3)), Some(&4)); // 1 + 1 + 2
+        assert_eq!(pair_counts.get(&(3, 4)), Some(&1)); // 1
+
+        // Check positions
+        assert!(positions.get(&(1, 2)).unwrap().contains(&0));
+        assert!(positions.get(&(1, 2)).unwrap().contains(&2));
+        assert!(positions.get(&(2, 3)).unwrap().contains(&0));
+        assert!(positions.get(&(2, 3)).unwrap().contains(&1));
+        assert!(positions.get(&(2, 3)).unwrap().contains(&2));
+        assert!(positions.get(&(3, 4)).unwrap().contains(&1));
+    }
+
+    #[test]
+    fn test_count_pairs_parallel_empty() {
+        let words = vec![];
+        let counts = vec![];
+
+        let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
+
+        assert_eq!(pair_counts.len(), 0);
+        assert_eq!(positions.len(), 0);
+    }
+
+    #[test]
+    fn test_tokenizer_creation() {
+        let tokenizer = create_test_tokenizer();
+
+        assert_eq!(tokenizer.pattern, GPT4_PATTERN);
+        assert_eq!(tokenizer.merges.len(), 0);
+    }
+
+    #[test]
+    fn test_tokenizer_get_pattern() {
+        let tokenizer = create_test_tokenizer();
+        assert_eq!(tokenizer.get_pattern(), GPT4_PATTERN);
+    }
+
+    #[test]
+    fn test_tokenizer_train_minimum_vocab() {
+        let mut tokenizer = create_test_tokenizer();
+        let text = "hello";
+
+        // Train with minimum vocab size (no merges)
+        let words = vec![Word::new(text.as_bytes().iter().map(|&b| b as u32).collect())];
+        let counts = vec![1];
+
+        tokenizer.train_core_incremental(words, counts, 256);
+
+        assert_eq!(tokenizer.merges.len(), 0); // No merges should occur
+    }
+
+    #[test]
+    fn test_tokenizer_train_simple() {
+        let text = "aaabdaaabac";
+        let tokenizer = train_simple_tokenizer(text, 256 + 3);
+
+        assert_eq!(tokenizer.merges.len(), 3); // Should have 3 merges
+
+        // Check that we can encode the training text
+        let encoded = tokenizer.encode(text);
+        assert!(!encoded.is_empty());
+    }
+
+    #[test]
+    fn test_tokenizer_train_repetitive_text() {
+        let text = "hello world hello world hello world";
+        let tokenizer = train_simple_tokenizer(text, 300);
+
+        assert!(tokenizer.merges.len() > 0);
+
+        let encoded = tokenizer.encode(text);
+        assert!(!encoded.is_empty());
+
+        // Encoding the same text twice should give the same result
+        let encoded2 = tokenizer.encode(text);
+        assert_eq!(encoded, encoded2);
+    }
+
+    #[test]
+    fn test_tokenizer_encode_empty() {
+        let tokenizer = create_test_tokenizer();
+        let encoded = tokenizer.encode("");
+        assert_eq!(encoded, Vec::<u32>::new());
+    }
+
+    #[test]
+    fn test_tokenizer_encode_single_char() {
+        let tokenizer = create_test_tokenizer();
+        let encoded = tokenizer.encode("A");
+        assert_eq!(encoded, vec![65]); // ASCII 'A'
+    }
+
+    #[test]
+    fn test_tokenizer_encode_ascii() {
+        let tokenizer = create_test_tokenizer();
+        let text = "Hello";
+        let encoded = tokenizer.encode(text);
+        assert_eq!(encoded, vec![72, 101, 108, 108, 111]); // ASCII values
+    }
+
+    #[test]
+    fn test_tokenizer_encode_with_merges() {
+        let mut tokenizer = create_test_tokenizer();
+
+        // Manually add a merge for "aa" -> 256
+        tokenizer.merges.insert((97, 97), 256); // 'a' = 97 in ASCII
+
+        let encoded = tokenizer.encode("aa");
+        assert_eq!(encoded, vec![256]); // Should use the merged token
+    }
+
+    #[test]
+    fn test_tokenizer_encode_multiple_merges() {
+        let mut tokenizer = create_test_tokenizer();
+
+        // Add merges: "aa" -> 256, "bb" -> 257
+        tokenizer.merges.insert((97, 97), 256);
+        tokenizer.merges.insert((98, 98), 257);
+
+        let encoded = tokenizer.encode("aabb");
+        assert_eq!(encoded, vec![256, 257]); // Should use both merged tokens
+    }
+
+    #[test]
+    fn test_tokenizer_encode_priority() {
+        let mut tokenizer = create_test_tokenizer();
+
+        // Add merges with different priorities (lower ID = higher priority)
+        tokenizer.merges.insert((97, 98), 257); // "ab" -> 257
+        tokenizer.merges.insert((98, 99), 256); // "bc" -> 256 (higher priority)
+
+        let encoded = tokenizer.encode("abc");
+        // Should merge "bc" first (higher priority), then "a" + result
+        assert_eq!(encoded, vec![97, 256]); // "a" + merged "bc"
+    }
+
+    #[test]
+    fn test_tokenizer_encode_complex_text() {
123"; + let tokenizer = train_simple_tokenizer(text, 300); + + let encoded = tokenizer.encode(text); + assert!(!encoded.is_empty()); + + // Should be able to encode and decode back (approximately) + // Note: We don't have decode in Rust, so we just check it's reasonable + assert!(encoded.len() <= text.len() * 2); // Reasonable upper bound + } + + #[test] + fn test_tokenizer_get_mergeable_ranks_basic() { + let tokenizer = create_test_tokenizer(); + let ranks = tokenizer.get_mergeable_ranks(); + + // Should have exactly 256 entries (just the base bytes) + assert_eq!(ranks.len(), 256); + + // Check first few entries + for i in 0..256 { + assert_eq!(ranks[i], (vec![i as u8], i as u32)); + } + } + + #[test] + fn test_tokenizer_get_mergeable_ranks_with_merges() { + let mut tokenizer = create_test_tokenizer(); + + // Add some merges + tokenizer.merges.insert((97, 97), 256); // "aa" -> 256 + tokenizer.merges.insert((98, 99), 257); // "bc" -> 257 + + let ranks = tokenizer.get_mergeable_ranks(); + + // Should have base 256 + 2 merges + assert_eq!(ranks.len(), 258); + + // Check base bytes are still correct + for i in 0..256 { + assert_eq!(ranks[i], (vec![i as u8], i as u32)); + } + + // Check merge tokens + assert_eq!(ranks[256], (vec![97, 97], 256)); // "aa" + assert_eq!(ranks[257], (vec![98, 99], 257)); // "bc" + } + + #[test] + fn test_tokenizer_get_mergeable_ranks_complex_merges() { + let mut tokenizer = create_test_tokenizer(); + + // Create a chain of merges: "aa" -> 256, then "aa" + "a" -> 257 + tokenizer.merges.insert((97, 97), 256); // "aa" -> 256 + tokenizer.merges.insert((256, 97), 257); // "aa" + "a" -> 257 + + let ranks = tokenizer.get_mergeable_ranks(); + + assert_eq!(ranks.len(), 258); + assert_eq!(ranks[256], (vec![97, 97], 256)); // "aa" + assert_eq!(ranks[257], (vec![97, 97, 97], 257)); // "aaa" + } + + #[test] + fn test_unicode_handling() { + let tokenizer = create_test_tokenizer(); + + // Test with Unicode characters + let unicode_text = "Hello δΈ–η•Œ πŸš€"; + let encoded = tokenizer.encode(unicode_text); + + assert!(!encoded.is_empty()); + // Each Unicode character should be encoded as one or more bytes + assert!(encoded.len() >= unicode_text.len()); + } + + #[test] + fn test_deterministic_training() { + let text = "hello world test"; + + let tokenizer1 = train_simple_tokenizer(text, 280); + let tokenizer2 = train_simple_tokenizer(text, 280); + + // Should produce identical results + assert_eq!(tokenizer1.merges, tokenizer2.merges); + + let encoded1 = tokenizer1.encode(text); + let encoded2 = tokenizer2.encode(text); + assert_eq!(encoded1, encoded2); + } + + #[test] + fn test_training_with_different_vocab_sizes() { + let text = "hello world"; + + let tokenizer256 = train_simple_tokenizer(text, 256); + let tokenizer300 = train_simple_tokenizer(text, 300); + + assert_eq!(tokenizer256.merges.len(), 0); + assert!(tokenizer300.merges.len() > 0); + } + + #[test] + fn test_edge_case_empty_training() { + let mut tokenizer = create_test_tokenizer(); + + // Train with empty data + let words = vec![]; + let counts = vec![]; + + tokenizer.train_core_incremental(words, counts, 300); + + assert_eq!(tokenizer.merges.len(), 0); + } + + #[test] + fn test_edge_case_single_word() { + let mut tokenizer = create_test_tokenizer(); + + let words = vec![Word::new(vec![72, 101, 108, 108, 111])]; // "Hello" + let counts = vec![1]; + + tokenizer.train_core_incremental(words, counts, 260); + + // Should have some merges + assert!(tokenizer.merges.len() > 0); + } + + #[test] + fn 
+    fn test_regex_pattern_compilation() {
+        let tokenizer = create_test_tokenizer();
+
+        // Test that the pattern compiles and works
+        let test_text = "Hello, world! 123";
+        let matches: Vec<&str> = tokenizer.compiled_pattern.find_iter(test_text)
+            .filter_map(|m| m.ok())
+            .map(|m| m.as_str())
+            .collect();
+
+        assert!(!matches.is_empty());
+        assert!(matches.iter().any(|&m| m.contains("Hello")));
+    }
+
+    #[test]
+    fn test_merge_job_partial_ord() {
+        let job1 = MergeJob {
+            pair: (1, 2),
+            count: 10,
+            pos: AHashSet::new(),
+        };
+
+        let job2 = MergeJob {
+            pair: (3, 4),
+            count: 15,
+            pos: AHashSet::new(),
+        };
+
+        // Test partial ordering
+        assert!(job1 < job2); // Lower count should be less
+        assert!(job2 > job1); // Higher count should be greater
+    }
+
+    #[test]
+    fn test_large_text_training() {
+        // Create a larger text by repetition
+        let base_text = "The quick brown fox jumps over the lazy dog. ";
+        let large_text = base_text.repeat(100);
+
+        let tokenizer = train_simple_tokenizer(&large_text, 500);
+
+        assert!(tokenizer.merges.len() > 0);
+
+        let encoded = tokenizer.encode(&large_text);
+        assert!(!encoded.is_empty());
+
+        // Should be more efficient than raw bytes due to merges
+        assert!(encoded.len() < large_text.len());
+    }
+
+    #[test]
+    fn test_special_characters() {
+        let tokenizer = create_test_tokenizer();
+
+        let special_text = "!@#$%^&*()_+-=[]{}|;':\",./<>?";
+        let encoded = tokenizer.encode(special_text);
+
+        assert!(!encoded.is_empty());
+        assert_eq!(encoded.len(), special_text.len()); // Each char should be one byte
+    }
+
+    #[test]
+    fn test_whitespace_handling() {
+        let tokenizer = create_test_tokenizer();
+
+        let whitespace_text = " \t\n\r ";
+        let encoded = tokenizer.encode(whitespace_text);
+
+        assert!(!encoded.is_empty());
+        // Should handle all whitespace characters
+    }
+
+    // ------------------------ Comparison Tests with Python Reference ------------------------
+
+    #[test]
+    fn test_comparison_simple_case() {
+        // Test case: "aaabdaaabac" with vocab_size = 259 (256 + 3 merges)
+        // Expected Python result: [258, 100, 258, 97, 99] with 3 merges
+        let text = "aaabdaaabac";
+        let vocab_size = 259;
+
+        let tokenizer = train_simple_tokenizer(text, vocab_size);
+        let encoded = tokenizer.encode(text);
+
+        // The most important thing: we should get the same final encoding
+        // Different merge sequences can lead to the same result, which is acceptable
+        let expected = vec![258, 100, 258, 97, 99];
+        assert_eq!(encoded, expected, "Final encoding should match Python reference");
+        assert_eq!(tokenizer.merges.len(), 3, "Should have exactly 3 merges");
+
+        // We should at least have the first merge (97, 97) -> 256 which is unambiguous
+        assert_eq!(tokenizer.merges.get(&(97, 97)), Some(&256), "First merge should be (97, 97) -> 256");
+    }
+
+    #[test]
+    fn test_comparison_hello_world() {
+        // Test case: "hello world" with vocab_size = 300
+        // Both implementations should compress "hello world" to 2 tokens with 9 merges
+        let text = "hello world";
+        let vocab_size = 300;
+
+        let tokenizer = train_simple_tokenizer(text, vocab_size);
+        let encoded = tokenizer.encode(text);
+
+        // The key achievement: compress "hello world" (11 chars) to 2 tokens
+        assert_eq!(encoded.len(), 2, "Should compress 'hello world' to 2 tokens");
+        assert_eq!(tokenizer.merges.len(), 9, "Should have exactly 9 merges");
+
+        // Both tokens should be > 255 (indicating they're merged tokens)
+        assert!(encoded[0] > 255, "First token should be a merged token");
token"); + } + + #[test] + fn test_comparison_minbpe_wikipedia_example() { + // Test the exact example from minbpe Wikipedia: "aaabdaaabac" + // This should produce the same result as minbpe's BasicTokenizer + let text = "aaabdaaabac"; + let vocab_size = 259; // 256 + 3 merges + + let tokenizer = train_simple_tokenizer(text, vocab_size); + let encoded = tokenizer.encode(text); + + // According to Wikipedia, this should compress to something like "XdXac" + // where X=ZY, Y=ab, Z=aa. In our token IDs: + // a=97, b=98, c=99, d=100 + // Z=(97,97) -> 256, Y=(256,98) -> 257, X=(257,97) -> 258 + // Result: [258, 100, 258, 97, 99] + let expected = vec![258, 100, 258, 97, 99]; + assert_eq!(encoded, expected); + } + + #[test] + fn test_comparison_deterministic_merges() { + // Test that merges are deterministic and match expected pattern + let text = "aaabdaaabac"; + let vocab_size = 259; + + let tokenizer1 = train_simple_tokenizer(text, vocab_size); + let tokenizer2 = train_simple_tokenizer(text, vocab_size); + + // Should produce identical merges + assert_eq!(tokenizer1.merges, tokenizer2.merges); + + // Check the specific merge order + let merges_vec: Vec<_> = tokenizer1.merges.iter().collect(); + assert_eq!(merges_vec.len(), 3); + + // First merge should be (97, 97) -> 256 (most frequent "aa") + assert!(tokenizer1.merges.contains_key(&(97, 97))); + assert_eq!(tokenizer1.merges[&(97, 97)], 256); + } + + #[test] + fn test_comparison_round_trip_consistency() { + // Test that encoding is consistent across multiple runs + let text = "hello world test round trip"; + let vocab_size = 350; + + let tokenizer1 = train_simple_tokenizer(text, vocab_size); + let tokenizer2 = train_simple_tokenizer(text, vocab_size); + + let encoded1 = tokenizer1.encode(text); + let encoded2 = tokenizer2.encode(text); + + assert_eq!(encoded1, encoded2); + assert_eq!(tokenizer1.merges, tokenizer2.merges); + } + + #[test] + fn test_comparison_empty_and_single_char() { + // Test edge cases that should match Python behavior + let empty_tokenizer = create_test_tokenizer(); + let empty_encoded = empty_tokenizer.encode(""); + assert_eq!(empty_encoded, Vec::::new()); + + let single_char = "A"; + let single_encoded = empty_tokenizer.encode(single_char); + assert_eq!(single_encoded, vec![65]); // ASCII 'A' + } + + #[test] + fn test_comparison_unicode_handling() { + // Test that Unicode is handled consistently + let unicode_text = "Hello δΈ–η•Œ πŸš€"; + let tokenizer = train_simple_tokenizer(unicode_text, 300); + + let encoded = tokenizer.encode(unicode_text); + assert!(!encoded.is_empty()); + + // Should be able to encode the same text multiple times + let encoded2 = tokenizer.encode(unicode_text); + assert_eq!(encoded, encoded2); + } +} + #[pymodule] fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); // forwards Rust `log` to Python's `logging`