diff --git a/rustbpe/Cargo.lock b/rustbpe/Cargo.lock index 69f8754..80dbd7f 100644 --- a/rustbpe/Cargo.lock +++ b/rustbpe/Cargo.lock @@ -230,11 +230,10 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.23.5" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ - "cfg-if", "indoc", "libc", "memoffset", @@ -248,19 +247,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.23.5" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" dependencies = [ - "once_cell", "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.23.5" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" dependencies = [ "libc", "pyo3-build-config", @@ -268,9 +266,9 @@ dependencies = [ [[package]] name = "pyo3-log" -version = "0.12.4" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +checksum = "d359e20231345f21a3b5b6aea7e73f4dc97e1712ef3bfe2d88997ac6a308d784" dependencies = [ "arc-swap", "log", @@ -279,9 +277,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.23.5" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -291,9 +289,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.23.5" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" dependencies = [ "heck", "proc-macro2", @@ -400,9 +398,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.16" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" [[package]] name = "unicode-ident" diff --git a/rustbpe/Cargo.toml b/rustbpe/Cargo.toml index 392a828..3e8efd3 100644 --- a/rustbpe/Cargo.toml +++ b/rustbpe/Cargo.toml @@ -8,8 +8,8 @@ dary_heap = "0.3" indexmap = "2.2" fancy-regex = "0.16.1" log = "0.4.28" -pyo3 = { version = "0.23.3", features = ["extension-module"] } -pyo3-log = "0.12.4" +pyo3 = { version = "0.26", features = ["extension-module"] } +pyo3-log = "0.13" ahash = "0.8.12" rayon = "1.11.0" compact_str = "0.9.0" diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs index 273d7f2..81d6670 100644 --- a/rustbpe/src/lib.rs +++ b/rustbpe/src/lib.rs @@ -277,7 +277,7 @@ impl Tokenizer { pub fn train_from_iterator( &mut self, py: pyo3::Python<'_>, - iterator: &pyo3::Bound<'_, pyo3::PyAny>, + iterator: pyo3::Py, vocab_size: u32, buffer_size: usize, pattern: Option, @@ -290,11 +290,6 @@ impl Tokenizer { self.compiled_pattern = Regex::new(&pattern_str) .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?; - // Prepare a true Python iterator object - let py_iter: pyo3::Py = unsafe { - pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))? - }; - // Global chunk counts let mut counts: AHashMap = AHashMap::new(); @@ -307,28 +302,21 @@ impl Tokenizer { // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator. // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise. let refill = |buf: &mut Vec| -> PyResult { - pyo3::Python::with_gil(|py| { + pyo3::Python::attach(|py| { buf.clear(); - let it = py_iter.bind(py); + let mut it = iterator.bind(py).clone(); loop { if buf.len() >= buffer_size { return Ok(false); } - // next(it) - let next_obj = unsafe { - pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr())) - }; - match next_obj { - Some(obj) => { + match it.next() { + Some(Ok(obj)) => { let s: String = obj.extract()?; buf.push(s); - } + }, + Some(Err(e)) => return Err(e), None => { - if pyo3::PyErr::occurred(py) { - return Err(pyo3::PyErr::fetch(py)); - } else { - return Ok(true); // exhausted - } + return Ok(true); // exhausted } } } @@ -345,7 +333,7 @@ impl Tokenizer { total_sequences += buf.len() as u64; let pattern = self.compiled_pattern.clone(); - let local: AHashMap = py.allow_threads(|| { + let local: AHashMap = py.detach(|| { buf.par_iter() .map(|s| { let mut m: AHashMap = AHashMap::new();