add into rustbpe

MadMax129 2025-10-24 18:52:26 -04:00
parent 851810c7d5
commit 0a1059d571
3 changed files with 116 additions and 25 deletions

rustbpe/Cargo.lock (generated, 18 additions)

@@ -186,6 +186,16 @@ version = "0.2.175"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
 
+[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
 [[package]]
 name = "log"
 version = "0.4.28"
@@ -363,7 +373,9 @@ dependencies = [
  "dary_heap",
  "fancy-regex",
  "indexmap",
+ "libloading",
  "log",
+ "once_cell",
  "pyo3",
  "pyo3-log",
  "rayon",
@@ -431,6 +443,12 @@ dependencies = [
  "wit-bindgen",
 ]
 
+[[package]]
+name = "windows-link"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
+
 [[package]]
 name = "wit-bindgen"
 version = "0.45.1"

rustbpe/Cargo.toml

@@ -13,3 +13,5 @@ pyo3-log = "0.12.4"
 ahash = "0.8.12"
 rayon = "1.11.0"
 compact_str = "0.9.0"
+libloading = "0.8.5"
+once_cell = "1.19.0"
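
The two new crates work as a pair in the source change below: libloading opens a shared library at runtime, and once_cell caches the result so the load happens at most once per process. A minimal sketch of that combination, using a hypothetical libm example that is not part of this commit:

use libloading::{Library, Symbol};
use once_cell::sync::OnceCell;

// Cache the library handle; `get_or_init` runs the loader at most once.
static LIBM: OnceCell<Library> = OnceCell::new();

fn libm() -> &'static Library {
    LIBM.get_or_init(|| {
        // Library name is platform-specific (illustration only).
        unsafe { Library::new("libm.so.6") }.expect("failed to load libm")
    })
}

fn cos_via_ffi(x: f64) -> f64 {
    unsafe {
        // Resolve the `cos` symbol from the cached handle and call it.
        let cos: Symbol<unsafe extern "C" fn(f64) -> f64> =
            libm().get(b"cos\0").expect("symbol cos");
        cos(x)
    }
}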

rustbpe/src/lib.rs

@@ -3,11 +3,12 @@ use std::collections::HashMap as StdHashMap;
 use dary_heap::OctonaryHeap;
 use fancy_regex::Regex;
+use libloading::Library;
+use once_cell::sync::OnceCell;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 use ahash::{AHashMap, AHashSet};
 use compact_str::CompactString;
 use rayon::prelude::*;
 
 // Default GPT-4 style regex pattern for splitting text
@@ -15,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{
 
 type Pair = (u32, u32);
 
+#[allow(non_camel_case_types)]
+type c_char = std::os::raw::c_char;
+
+#[repr(C)]
+struct CTokenPos {
+    start: usize,
+    end: usize,
+}
+
+#[repr(C)]
+struct CTokenList {
+    splits: *mut CTokenPos,
+    count: usize,
+    capacity: usize,
+}
+
+type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList);
+
+struct FregexSymbols {
+    _lib: Library,
+    tokenlist_init: FnTokenListInit,
+    tokenlist_free: FnTokenListFree,
+    tokenize_fast: FnTokenizeFast,
+}
+
+static FREGEX: OnceCell<FregexSymbols> = OnceCell::new();
+
+fn load_fregex() -> &'static FregexSymbols {
+    FREGEX.get_or_init(|| {
+        // NOTE: adjust this path per user if needed.
+        let path = "fregex/libfregex.dylib";
+        let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| {
+            panic!("Failed to load libfregex from {}: {}", path, e);
+        });
+        unsafe {
+            let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init");
+            let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free");
+            let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast");
+            println!("rustbpe: loaded libfregex.dylib from {}", path);
+            FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast }
+        }
+    })
+}
+
+fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F)
+where
+    F: FnMut(&'a [u8]),
+{
+    if input.is_empty() {
+        return;
+    }
+    let syms = load_fregex();
+    let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 };
+    let base_ptr = input.as_ptr() as usize;
+    unsafe {
+        (syms.tokenlist_init)(&mut out as *mut CTokenList);
+        (syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList);
+        if !out.splits.is_null() {
+            let slice = std::slice::from_raw_parts(out.splits, out.count);
+            for pos in slice.iter() {
+                let start = pos.start.saturating_sub(base_ptr);
+                let end = pos.end.saturating_sub(base_ptr);
+                if end <= input.len() && start <= end {
+                    on_piece(&input[start..end]);
+                }
+            }
+        }
+        (syms.tokenlist_free)(&mut out as *mut CTokenList);
+    }
+}
+
 /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
 #[pyclass]
 pub struct Tokenizer {
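
Two things worth noting in the hunk above: load_fregex hard-codes the macOS-style path fregex/libfregex.dylib relative to the working directory (the NOTE comment says to adjust it per user), and tokenize_fast apparently writes absolute addresses into CTokenPos, which is why tokenize_with_c_each subtracts base_ptr to recover offsets into input before bounds-checking them. A minimal caller sketch, assuming the dylib is present and loadable; this helper is not part of the commit:

// Hypothetical helper: count the pieces the C tokenizer produces for one
// UTF-8 string via the callback API added above.
fn count_pieces(text: &str) -> usize {
    let mut n = 0;
    tokenize_with_c_each(text.as_bytes(), |_piece| {
        // Each callback sees one split as a borrowed &[u8] slice of `text`;
        // copy it (e.g. _piece.to_vec()) if it must outlive this call.
        n += 1;
    });
    n
}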
@@ -296,8 +370,8 @@ impl Tokenizer {
             pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
         };
 
-        // Global chunk counts
-        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
+        // Global chunk counts: own bytes once per unique chunk (no string copies)
+        let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();
 
         // Temporary buffer we refill under the GIL
         let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
@@ -345,31 +419,28 @@ impl Tokenizer {
             total_sequences += buf.len() as u64;
 
             let pattern = self.compiled_pattern.clone();
-            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
+            // Build per-string local counts that reference the buffer slices (no allocations)
+            let locals: Vec<Vec<(&[u8], i32)>> = py.allow_threads(|| {
                 buf.par_iter()
                     .map(|s| {
-                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
-                        for mat in pattern.find_iter(s) {
-                            let piece = mat.expect("regex match failed").as_str();
-                            *m.entry(CompactString::from(piece)).or_default() += 1;
-                        }
-                        m
+                        let mut m: AHashMap<&[u8], i32> = AHashMap::new();
+                        let bytes = s.as_bytes();
+                        tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; });
+                        // Materialize as Vec to allow merging after parallel section
+                        m.into_iter().collect::<Vec<(&[u8], i32)>>()
                     })
-                    .reduce(
-                        || AHashMap::new(),
-                        |mut a, b| {
-                            for (k, v) in b {
-                                *a.entry(k).or_default() += v;
-                            }
-                            a
-                        },
-                    )
+                    .collect()
             });
 
-            // Merge local into global (single-threaded)
-            for (k, v) in local {
-                *counts.entry(k).or_default() += v;
+            // Merge locals into global (single-threaded) without copying unless inserting new keys
+            for local in locals {
+                for (piece, v) in local {
+                    if let Some(cnt) = counts.get_mut(piece) {
+                        *cnt += v;
+                    } else {
+                        counts.insert(piece.to_vec(), v);
+                    }
+                }
             }
 
             if exhausted {
@@ -381,8 +452,8 @@
 
         // Materialize words & counts
         let mut words = Vec::with_capacity(counts.len());
         let mut cvec = Vec::with_capacity(counts.len());
-        for (chunk, c) in counts.into_iter() {
-            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+        for (chunk_bytes, c) in counts.into_iter() {
+            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
             cvec.push(c);
         }
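
Taken together, the last three hunks implement a borrow-then-own counting pattern: worker threads tally &[u8] slices borrowed from the batch buffer (no per-piece allocation in the hot loop), and the single-threaded merge allocates an owned Vec<u8> key only the first time a chunk is seen; Vec<u8>: Borrow<[u8]> is what lets get_mut accept the borrowed slice directly. A standalone sketch of the same pattern, with whitespace splitting standing in for the FFI tokenizer:

use ahash::AHashMap;
use rayon::prelude::*;

// Count space-delimited pieces across a batch of strings: borrow in
// parallel, then allocate owned keys only on first insertion globally.
fn count_chunks(batch: &[String]) -> AHashMap<Vec<u8>, i32> {
    let locals: Vec<Vec<(&[u8], i32)>> = batch
        .par_iter()
        .map(|s| {
            let mut m: AHashMap<&[u8], i32> = AHashMap::new();
            for piece in s.as_bytes().split(|&b| b == b' ') {
                *m.entry(piece).or_default() += 1;
            }
            m.into_iter().collect()
        })
        .collect();

    let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();
    for local in locals {
        for (piece, v) in local {
            if let Some(c) = counts.get_mut(piece) {
                *c += v;
            } else {
                counts.insert(piece.to_vec(), v); // first sighting: own the bytes
            }
        }
    }
    counts
}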