mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 04:12:13 +00:00

add into rustbpe

parent 851810c7d5
commit 0a1059d571

rustbpe/Cargo.lock (generated, 18 changes)

@@ -186,6 +186,16 @@ version = "0.2.175"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
 
+[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
 [[package]]
 name = "log"
 version = "0.4.28"
@@ -363,7 +373,9 @@ dependencies = [
  "dary_heap",
  "fancy-regex",
  "indexmap",
+ "libloading",
  "log",
+ "once_cell",
  "pyo3",
  "pyo3-log",
  "rayon",
@@ -431,6 +443,12 @@ dependencies = [
  "wit-bindgen",
 ]
 
+[[package]]
+name = "windows-link"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
+
 [[package]]
 name = "wit-bindgen"
 version = "0.45.1"

rustbpe/Cargo.toml

@@ -13,3 +13,5 @@ pyo3-log = "0.12.4"
 ahash = "0.8.12"
 rayon = "1.11.0"
 compact_str = "0.9.0"
+libloading = "0.8.5"
+once_cell = "1.19.0"

rustbpe/src/lib.rs

@@ -3,11 +3,12 @@ use std::collections::HashMap as StdHashMap;
 
 use dary_heap::OctonaryHeap;
 use fancy_regex::Regex;
+use libloading::Library;
+use once_cell::sync::OnceCell;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 
 use ahash::{AHashMap, AHashSet};
-use compact_str::CompactString;
 use rayon::prelude::*;
 
 // Default GPT-4 style regex pattern for splitting text
@@ -15,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{
 
 type Pair = (u32, u32);
 
+#[allow(non_camel_case_types)]
+type c_char = std::os::raw::c_char;
+
+#[repr(C)]
+struct CTokenPos {
+    start: usize,
+    end: usize,
+}
+
+#[repr(C)]
+struct CTokenList {
+    splits: *mut CTokenPos,
+    count: usize,
+    capacity: usize,
+}
+
+type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList);
+
+struct FregexSymbols {
+    _lib: Library,
+    tokenlist_init: FnTokenListInit,
+    tokenlist_free: FnTokenListFree,
+    tokenize_fast: FnTokenizeFast,
+}
+
+static FREGEX: OnceCell<FregexSymbols> = OnceCell::new();
+
+fn load_fregex() -> &'static FregexSymbols {
+    FREGEX.get_or_init(|| {
+        // NOTE: adjust this path per user if needed.
+        let path = "fregex/libfregex.dylib";
+        let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| {
+            panic!("Failed to load libfregex from {}: {}", path, e);
+        });
+        unsafe {
+            let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init");
+            let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free");
+            let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast");
+            println!("rustbpe: loaded libfregex.dylib from {}", path);
+            FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast }
+        }
+    })
+}
+
+fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F)
+where
+    F: FnMut(&'a [u8]),
+{
+    if input.is_empty() {
+        return;
+    }
+    let syms = load_fregex();
+    let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 };
+    let base_ptr = input.as_ptr() as usize;
+    unsafe {
+        (syms.tokenlist_init)(&mut out as *mut CTokenList);
+        (syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList);
+        if !out.splits.is_null() {
+            let slice = std::slice::from_raw_parts(out.splits, out.count);
+            for pos in slice.iter() {
+                let start = pos.start.saturating_sub(base_ptr);
+                let end = pos.end.saturating_sub(base_ptr);
+                if end <= input.len() && start <= end {
+                    on_piece(&input[start..end]);
+                }
+            }
+        }
+        (syms.tokenlist_free)(&mut out as *mut CTokenList);
+    }
+}
+
 /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
 #[pyclass]
 pub struct Tokenizer {
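
An aside, not part of the commit: to make the FFI contract above concrete, here is a hypothetical Rust stand-in for the C side of libfregex. It reflects what load_fregex and tokenize_with_c_each appear to assume: the loader resolves exactly these three symbols, the library allocates and frees the splits buffer itself, and each CTokenPos carries absolute byte addresses (which is why the caller subtracts base_ptr). Everything below is illustrative, not code from the repository.

    // Hypothetical stand-in for libfregex, written in Rust for illustration only.
    #[no_mangle]
    pub unsafe extern "C" fn tokenlist_init(list: *mut CTokenList) {
        (*list).splits = std::ptr::null_mut();
        (*list).count = 0;
        (*list).capacity = 0;
    }

    #[no_mangle]
    pub unsafe extern "C" fn tokenize_fast(input: *const c_char, input_len: usize, out: *mut CTokenList) {
        // Toy behavior: report the whole input as a single piece. Positions are
        // absolute addresses (pointer cast to usize), matching the base_ptr math above.
        let mut splits = vec![CTokenPos { start: input as usize, end: input as usize + input_len }];
        (*out).splits = splits.as_mut_ptr();
        (*out).count = splits.len();
        (*out).capacity = splits.capacity();
        std::mem::forget(splits); // buffer stays alive until tokenlist_free reclaims it
    }

    #[no_mangle]
    pub unsafe extern "C" fn tokenlist_free(list: *mut CTokenList) {
        if !(*list).splits.is_null() {
            drop(Vec::from_raw_parts((*list).splits, (*list).count, (*list).capacity));
            (*list).splits = std::ptr::null_mut();
        }
    }
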
@@ -296,8 +370,8 @@ impl Tokenizer {
             pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
         };
 
-        // Global chunk counts
-        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
+        // Global chunk counts: own bytes once per unique chunk (no string copies)
+        let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();
 
         // Temporary buffer we refill under the GIL
         let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
@@ -345,31 +419,28 @@
 
             total_sequences += buf.len() as u64;
 
-            let pattern = self.compiled_pattern.clone();
-            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
+            // Build per-string local counts that reference the buffer slices (no allocations)
+            let locals: Vec<Vec<(&[u8], i32)>> = py.allow_threads(|| {
                 buf.par_iter()
                     .map(|s| {
-                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
-                        for mat in pattern.find_iter(s) {
-                            let piece = mat.expect("regex match failed").as_str();
-                            *m.entry(CompactString::from(piece)).or_default() += 1;
-                        }
-                        m
+                        let mut m: AHashMap<&[u8], i32> = AHashMap::new();
+                        let bytes = s.as_bytes();
+                        tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; });
+                        // Materialize as Vec to allow merging after parallel section
+                        m.into_iter().collect::<Vec<(&[u8], i32)>>()
                     })
-                    .reduce(
-                        || AHashMap::new(),
-                        |mut a, b| {
-                            for (k, v) in b {
-                                *a.entry(k).or_default() += v;
-                            }
-                            a
-                        },
-                    )
+                    .collect()
             });
 
-            // Merge local into global (single-threaded)
-            for (k, v) in local {
-                *counts.entry(k).or_default() += v;
+            // Merge locals into global (single-threaded) without copying unless inserting new keys
+            for local in locals {
+                for (piece, v) in local {
+                    if let Some(cnt) = counts.get_mut(piece) {
+                        *cnt += v;
+                    } else {
+                        counts.insert(piece.to_vec(), v);
+                    }
+                }
             }
 
             if exhausted {
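
For reference, also not part of the commit: the counting pattern above, which borrows &[u8] keys inside the parallel section and only allocates an owned Vec<u8> key the first time a piece enters the global map, can be sketched in isolation. The whitespace splitter below stands in for the C tokenizer; the function and its names are hypothetical.

    use ahash::AHashMap;
    use rayon::prelude::*;

    /// Count byte pieces across documents without allocating per occurrence.
    fn count_pieces(docs: &[String]) -> AHashMap<Vec<u8>, i32> {
        // Parallel phase: per-document counts keyed by borrowed slices.
        let locals: Vec<Vec<(&[u8], i32)>> = docs
            .par_iter()
            .map(|s| {
                let mut m: AHashMap<&[u8], i32> = AHashMap::new();
                for piece in s.as_bytes().split(|&b| b == b' ').filter(|p| !p.is_empty()) {
                    *m.entry(piece).or_default() += 1;
                }
                m.into_iter().collect()
            })
            .collect();
        // Serial merge: copy bytes only when a piece is inserted for the first time.
        let mut global: AHashMap<Vec<u8>, i32> = AHashMap::new();
        for local in locals {
            for (piece, v) in local {
                if let Some(cnt) = global.get_mut(piece) {
                    *cnt += v;
                } else {
                    global.insert(piece.to_vec(), v);
                }
            }
        }
        global
    }
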
@@ -381,8 +452,8 @@
         // Materialize words & counts
         let mut words = Vec::with_capacity(counts.len());
        let mut cvec = Vec::with_capacity(counts.len());
-        for (chunk, c) in counts.into_iter() {
-            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+        for (chunk_bytes, c) in counts.into_iter() {
+            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
             cvec.push(c);
         }
 