Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Gaynor
81f6a358cb
Merge 09f1f4283e into bc1fca39f3 2025-11-16 17:51:00 -08:00
Alex Gaynor
09f1f4283e
Merge branch 'master' into cleanup 2025-10-29 07:41:54 -04:00
Alex Gaynor
18590307ae Upgrade to pyo3 0.26, fix warnings, remove unsafe usage 2025-10-13 12:45:30 -04:00
3 changed files with 25 additions and 39 deletions

30
rustbpe/Cargo.lock generated
View File

@@ -230,11 +230,10 @@ dependencies = [
[[package]] [[package]]
name = "pyo3" name = "pyo3"
version = "0.23.5" version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383"
dependencies = [ dependencies = [
"cfg-if",
"indoc", "indoc",
"libc", "libc",
"memoffset", "memoffset",
@@ -248,19 +247,18 @@ dependencies = [
[[package]] [[package]]
name = "pyo3-build-config" name = "pyo3-build-config"
version = "0.23.5" version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f"
dependencies = [ dependencies = [
"once_cell",
"target-lexicon", "target-lexicon",
] ]
[[package]] [[package]]
name = "pyo3-ffi" name = "pyo3-ffi"
version = "0.23.5" version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105"
dependencies = [ dependencies = [
"libc", "libc",
"pyo3-build-config", "pyo3-build-config",
@@ -268,9 +266,9 @@ dependencies = [
[[package]] [[package]]
name = "pyo3-log" name = "pyo3-log"
version = "0.12.4" version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" checksum = "d359e20231345f21a3b5b6aea7e73f4dc97e1712ef3bfe2d88997ac6a308d784"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"log", "log",
@@ -279,9 +277,9 @@ dependencies = [
[[package]] [[package]]
name = "pyo3-macros" name = "pyo3-macros"
version = "0.23.5" version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
@@ -291,9 +289,9 @@ dependencies = [
[[package]] [[package]]
name = "pyo3-macros-backend" name = "pyo3-macros-backend"
version = "0.23.5" version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
@@ -400,9 +398,9 @@ dependencies = [
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.16" version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"

View File

@@ -8,8 +8,8 @@ dary_heap = "0.3"
indexmap = "2.2" indexmap = "2.2"
fancy-regex = "0.16.1" fancy-regex = "0.16.1"
log = "0.4.28" log = "0.4.28"
pyo3 = { version = "0.23.3", features = ["extension-module"] } pyo3 = { version = "0.26", features = ["extension-module"] }
pyo3-log = "0.12.4" pyo3-log = "0.13"
ahash = "0.8.12" ahash = "0.8.12"
rayon = "1.11.0" rayon = "1.11.0"
compact_str = "0.9.0" compact_str = "0.9.0"

View File

@@ -277,7 +277,7 @@ impl Tokenizer {
pub fn train_from_iterator( pub fn train_from_iterator(
&mut self, &mut self,
py: pyo3::Python<'_>, py: pyo3::Python<'_>,
iterator: &pyo3::Bound<'_, pyo3::PyAny>, iterator: pyo3::Py<pyo3::types::PyIterator>,
vocab_size: u32, vocab_size: u32,
buffer_size: usize, buffer_size: usize,
pattern: Option<String>, pattern: Option<String>,
@@ -290,11 +290,6 @@ impl Tokenizer {
self.compiled_pattern = Regex::new(&pattern_str) self.compiled_pattern = Regex::new(&pattern_str)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?; .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
// Prepare a true Python iterator object
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
};
// Global chunk counts // Global chunk counts
let mut counts: AHashMap<CompactString, i32> = AHashMap::new(); let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
@@ -307,28 +302,21 @@ impl Tokenizer {
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator. // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise. // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
let refill = |buf: &mut Vec<String>| -> PyResult<bool> { let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
pyo3::Python::with_gil(|py| { pyo3::Python::attach(|py| {
buf.clear(); buf.clear();
let it = py_iter.bind(py); let mut it = iterator.bind(py).clone();
loop { loop {
if buf.len() >= buffer_size { if buf.len() >= buffer_size {
return Ok(false); return Ok(false);
} }
// next(it) match it.next() {
let next_obj = unsafe { Some(Ok(obj)) => {
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
};
match next_obj {
Some(obj) => {
let s: String = obj.extract()?; let s: String = obj.extract()?;
buf.push(s); buf.push(s);
} },
Some(Err(e)) => return Err(e),
None => { None => {
if pyo3::PyErr::occurred(py) { return Ok(true); // exhausted
return Err(pyo3::PyErr::fetch(py));
} else {
return Ok(true); // exhausted
}
} }
} }
} }
@@ -345,7 +333,7 @@ impl Tokenizer {
total_sequences += buf.len() as u64; total_sequences += buf.len() as u64;
let pattern = self.compiled_pattern.clone(); let pattern = self.compiled_pattern.clone();
let local: AHashMap<CompactString, i32> = py.allow_threads(|| { let local: AHashMap<CompactString, i32> = py.detach(|| {
buf.par_iter() buf.par_iter()
.map(|s| { .map(|s| {
let mut m: AHashMap<CompactString, i32> = AHashMap::new(); let mut m: AHashMap<CompactString, i32> = AHashMap::new();