This commit is contained in:
Alex Gaynor 2025-11-19 11:42:12 -05:00 committed by GitHub
commit 4450ddf07e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 39 deletions

30
rustbpe/Cargo.lock generated
View File

@@ -230,11 +230,10 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.23.5"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
@@ -248,19 +247,18 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105"
dependencies = [
"libc",
"pyo3-build-config",
@@ -268,9 +266,9 @@ dependencies = [
[[package]]
name = "pyo3-log"
version = "0.12.4"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
checksum = "d359e20231345f21a3b5b6aea7e73f4dc97e1712ef3bfe2d88997ac6a308d784"
dependencies = [
"arc-swap",
"log",
@@ -279,9 +277,9 @@ dependencies = [
[[package]]
name = "pyo3-macros"
version = "0.23.5"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
@@ -291,9 +289,9 @@ dependencies = [
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf"
dependencies = [
"heck",
"proc-macro2",
@@ -400,9 +398,9 @@ dependencies = [
[[package]]
name = "target-lexicon"
version = "0.12.16"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c"
[[package]]
name = "unicode-ident"

View File

@@ -8,8 +8,8 @@ dary_heap = "0.3"
indexmap = "2.2"
fancy-regex = "0.16.1"
log = "0.4.28"
pyo3 = { version = "0.23.3", features = ["extension-module"] }
pyo3-log = "0.12.4"
pyo3 = { version = "0.26", features = ["extension-module"] }
pyo3-log = "0.13"
ahash = "0.8.12"
rayon = "1.11.0"
compact_str = "0.9.0"

View File

@@ -277,7 +277,7 @@ impl Tokenizer {
pub fn train_from_iterator(
&mut self,
py: pyo3::Python<'_>,
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
iterator: pyo3::Py<pyo3::types::PyIterator>,
vocab_size: u32,
buffer_size: usize,
pattern: Option<String>,
@@ -290,11 +290,6 @@ impl Tokenizer {
self.compiled_pattern = Regex::new(&pattern_str)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
// Prepare a true Python iterator object
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
};
// Global chunk counts
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
@@ -307,28 +302,21 @@ impl Tokenizer {
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
pyo3::Python::with_gil(|py| {
pyo3::Python::attach(|py| {
buf.clear();
let it = py_iter.bind(py);
let mut it = iterator.bind(py).clone();
loop {
if buf.len() >= buffer_size {
return Ok(false);
}
// next(it)
let next_obj = unsafe {
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
};
match next_obj {
Some(obj) => {
match it.next() {
Some(Ok(obj)) => {
let s: String = obj.extract()?;
buf.push(s);
}
},
Some(Err(e)) => return Err(e),
None => {
if pyo3::PyErr::occurred(py) {
return Err(pyo3::PyErr::fetch(py));
} else {
return Ok(true); // exhausted
}
return Ok(true); // exhausted
}
}
}
@@ -345,7 +333,7 @@ impl Tokenizer {
total_sequences += buf.len() as u64;
let pattern = self.compiled_pattern.clone();
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
let local: AHashMap<CompactString, i32> = py.detach(|| {
buf.par_iter()
.map(|s| {
let mut m: AHashMap<CompactString, i32> = AHashMap::new();