mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
add into rustbpe
This commit is contained in:
parent 851810c7d5
commit 0a1059d571
rustbpe/Cargo.lock (generated): 18 additions
@@ -186,6 +186,16 @@ version = "0.2.175"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
 
+[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
 [[package]]
 name = "log"
 version = "0.4.28"
@@ -363,7 +373,9 @@ dependencies = [
  "dary_heap",
  "fancy-regex",
  "indexmap",
+ "libloading",
  "log",
+ "once_cell",
  "pyo3",
  "pyo3-log",
  "rayon",
@@ -431,6 +443,12 @@ dependencies = [
  "wit-bindgen",
 ]
 
+[[package]]
+name = "windows-link"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
+
 [[package]]
 name = "wit-bindgen"
 version = "0.45.1"

rustbpe/Cargo.toml: 2 additions
@@ -13,3 +13,5 @@ pyo3-log = "0.12.4"
 ahash = "0.8.12"
 rayon = "1.11.0"
 compact_str = "0.9.0"
+libloading = "0.8.5"
+once_cell = "1.19.0"

rustbpe/src/lib.rs
@@ -3,11 +3,12 @@ use std::collections::HashMap as StdHashMap;
 
 use dary_heap::OctonaryHeap;
 use fancy_regex::Regex;
+use libloading::Library;
+use once_cell::sync::OnceCell;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 
 use ahash::{AHashMap, AHashSet};
-use compact_str::CompactString;
 use rayon::prelude::*;
 
 // Default GPT-4 style regex pattern for splitting text
@@ -15,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{
 
 type Pair = (u32, u32);
 
+#[allow(non_camel_case_types)]
+type c_char = std::os::raw::c_char;
+
+#[repr(C)]
+struct CTokenPos {
+    start: usize,
+    end: usize,
+}
+
+#[repr(C)]
+struct CTokenList {
+    splits: *mut CTokenPos,
+    count: usize,
+    capacity: usize,
+}
+
+type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList);
+
+struct FregexSymbols {
+    _lib: Library,
+    tokenlist_init: FnTokenListInit,
+    tokenlist_free: FnTokenListFree,
+    tokenize_fast: FnTokenizeFast,
+}
+
+static FREGEX: OnceCell<FregexSymbols> = OnceCell::new();
+
+fn load_fregex() -> &'static FregexSymbols {
+    FREGEX.get_or_init(|| {
+        // NOTE: adjust this path per user if needed.
+        let path = "fregex/libfregex.dylib";
+        let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| {
+            panic!("Failed to load libfregex from {}: {}", path, e);
+        });
+        unsafe {
+            let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init");
+            let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free");
+            let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast");
+            println!("rustbpe: loaded libfregex.dylib from {}", path);
+            FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast }
+        }
+    })
+}
+
+fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F)
+where
+    F: FnMut(&'a [u8]),
+{
+    if input.is_empty() {
+        return;
+    }
+    let syms = load_fregex();
+    let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 };
+    let base_ptr = input.as_ptr() as usize;
+    unsafe {
+        (syms.tokenlist_init)(&mut out as *mut CTokenList);
+        (syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList);
+        if !out.splits.is_null() {
+            let slice = std::slice::from_raw_parts(out.splits, out.count);
+            for pos in slice.iter() {
+                let start = pos.start.saturating_sub(base_ptr);
+                let end = pos.end.saturating_sub(base_ptr);
+                if end <= input.len() && start <= end {
+                    on_piece(&input[start..end]);
+                }
+            }
+        }
+        (syms.tokenlist_free)(&mut out as *mut CTokenList);
+    }
+}
+
 /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
 #[pyclass]
 pub struct Tokenizer {
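The hunk above loads three symbols (tokenlist_init, tokenize_fast, tokenlist_free) from libfregex at runtime and keeps the Library handle alive inside FregexSymbols, so the resolved function pointers remain valid; the OnceCell makes the load happen once per process. For orientation only, here is a minimal pure-Rust sketch of what tokenize_fast is expected to produce: the byte spans of the GPT-4-style regex splits, delivered through the same FnMut(&[u8]) callback shape as tokenize_with_c_each. The helper name is hypothetical and is not part of this commit.

use fancy_regex::Regex;

// Hypothetical regex-based stand-in for tokenize_with_c_each: split `input`
// with the compiled GPT4_PATTERN and hand each piece to `on_piece` as a byte slice.
// Assumes `input` is valid UTF-8, which holds for callers that pass str::as_bytes().
fn tokenize_with_regex_each<'a, F>(pattern: &Regex, input: &'a [u8], mut on_piece: F)
where
    F: FnMut(&'a [u8]),
{
    let Ok(text) = std::str::from_utf8(input) else { return };
    for mat in pattern.find_iter(text) {
        let mat = mat.expect("regex match failed");
        // start()/end() are byte offsets into `text`, so they index `input` directly,
        // mirroring the pointer arithmetic done on CTokenPos in the C fast path.
        on_piece(&input[mat.start()..mat.end()]);
    }
}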
@@ -296,8 +370,8 @@ impl Tokenizer {
             pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
         };
 
-        // Global chunk counts
-        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
+        // Global chunk counts: own bytes once per unique chunk (no string copies)
+        let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();
 
         // Temporary buffer we refill under the GIL
        let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
@@ -345,31 +419,28 @@ impl Tokenizer {
 
             total_sequences += buf.len() as u64;
 
-            let pattern = self.compiled_pattern.clone();
-            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
+            // Build per-string local counts that reference the buffer slices (no allocations)
+            let locals: Vec<Vec<(&[u8], i32)>> = py.allow_threads(|| {
                 buf.par_iter()
                     .map(|s| {
-                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
-                        for mat in pattern.find_iter(s) {
-                            let piece = mat.expect("regex match failed").as_str();
-                            *m.entry(CompactString::from(piece)).or_default() += 1;
-                        }
-                        m
+                        let mut m: AHashMap<&[u8], i32> = AHashMap::new();
+                        let bytes = s.as_bytes();
+                        tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; });
+                        // Materialize as Vec to allow merging after parallel section
+                        m.into_iter().collect::<Vec<(&[u8], i32)>>()
                     })
-                    .reduce(
-                        || AHashMap::new(),
-                        |mut a, b| {
-                            for (k, v) in b {
-                                *a.entry(k).or_default() += v;
-                            }
-                            a
-                        },
-                    )
+                    .collect()
             });
 
-            // Merge local into global (single-threaded)
-            for (k, v) in local {
-                *counts.entry(k).or_default() += v;
+            // Merge locals into global (single-threaded) without copying unless inserting new keys
+            for local in locals {
+                for (piece, v) in local {
+                    if let Some(cnt) = counts.get_mut(piece) {
+                        *cnt += v;
+                    } else {
+                        counts.insert(piece.to_vec(), v);
+                    }
+                }
             }
 
             if exhausted {
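A note on the merge loop above: Vec<u8> implements Borrow<[u8]>, so the global counts map keyed by Vec<u8> can be probed directly with the borrowed &[u8] pieces coming out of the parallel section, and the to_vec() copy is only paid the first time a chunk is seen. A standalone sketch of that lookup-then-insert pattern (not part of the commit):

use ahash::AHashMap;

// Bump a chunk count, allocating an owned key only when the chunk is new.
fn bump(counts: &mut AHashMap<Vec<u8>, i32>, piece: &[u8], v: i32) {
    if let Some(cnt) = counts.get_mut(piece) {
        *cnt += v; // existing chunk: borrowed lookup, no allocation
    } else {
        counts.insert(piece.to_vec(), v); // new chunk: copy the bytes once
    }
}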
@@ -381,8 +452,8 @@ impl Tokenizer {
         // Materialize words & counts
         let mut words = Vec::with_capacity(counts.len());
         let mut cvec = Vec::with_capacity(counts.len());
-        for (chunk, c) in counts.into_iter() {
-            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+        for (chunk_bytes, c) in counts.into_iter() {
+            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
             cvec.push(c);
         }
 
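For context on this last hunk: counts now holds raw byte chunks, so materializing a Word simply promotes each byte to its initial u32 token id (0..=255) before the BPE merges run. A trivial illustration of that mapping (helper name hypothetical, not part of the commit):

// Each chunk starts as a sequence of byte-level ids; merges later replace pairs of ids.
fn initial_ids(chunk_bytes: &[u8]) -> Vec<u32> {
    chunk_bytes.iter().map(|&b| b as u32).collect()
}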