From 12f418f0a10f1c37981180d2d8427157b494ef63 Mon Sep 17 00:00:00 2001 From: MadMax129 Date: Thu, 23 Oct 2025 16:59:10 -0400 Subject: [PATCH 1/5] faster regex in C --- fregex/bench.py | 157 ++++++++++++ fregex/cload.py | 20 ++ fregex/compare.py | 178 ++++++++++++++ fregex/fregex.c | 524 +++++++++++++++++++++++++++++++++++++++++ fregex/fregex.h | 28 +++ fregex/fuzz.py | 361 ++++++++++++++++++++++++++++ fregex/py_tokenizer.py | 43 ++++ fregex/rust_split.py | 38 +++ rustbpe/src/lib.rs | 15 ++ 9 files changed, 1364 insertions(+) create mode 100755 fregex/bench.py create mode 100644 fregex/cload.py create mode 100644 fregex/compare.py create mode 100644 fregex/fregex.c create mode 100644 fregex/fregex.h create mode 100644 fregex/fuzz.py create mode 100644 fregex/py_tokenizer.py create mode 100644 fregex/rust_split.py diff --git a/fregex/bench.py b/fregex/bench.py new file mode 100755 index 0000000..f916722 --- /dev/null +++ b/fregex/bench.py @@ -0,0 +1,157 @@ +""" +Benchmarker for comparing tok.c tokenize_fast() vs rust split_text() +Measures speed WITHOUT subprocess overhead - direct function calls only. + +Usage: + cd pytok + source ../.venv/bin/activate + python3 bench.py # Run synthetic data benchmarks + python3 bench.py /path/to/file.txt # Benchmark a specific file + python3 bench.py file1.txt file2.txt ... # Benchmark multiple files +""" + +import sys +import ctypes +import random +import time +import statistics +from pathlib import Path + +from nanochat.tokenizer import SPLIT_PATTERN +from rustbpe import split_text as rust_split_text +from fregex.fuzz import gen_valid_unicode_string, compare_pair_text +from fregex.cload import * + +def bench_c_regex(data: bytes, iterations: int) -> list: + times = [] + for _ in range(iterations): + token_list = TokenList() + c_lib.tokenlist_init(ctypes.byref(token_list)) + + start = time.perf_counter() + c_lib.tokenize_fast(data, len(data), ctypes.byref(token_list)) + elapsed = time.perf_counter() - start + + c_lib.tokenlist_free(ctypes.byref(token_list)) + times.append(elapsed * 1000) + + return times + +def bench_rust_regex(text: str, iterations: int) -> list: + times = [] + for _ in range(iterations): + start = time.perf_counter() + rust_split_text(SPLIT_PATTERN, text) + elapsed = time.perf_counter() - start + times.append(elapsed * 1000) + + return times + +def stats_summary(times: list) -> dict: + """Compute statistics from timing list.""" + if not times or len(times) == 0: + return {} + + return { + 'min': min(times), + 'max': max(times), + 'mean': statistics.mean(times), + 'median': statistics.median(times), + 'stdev': statistics.stdev(times) if len(times) > 1 else 0, + } + +def format_stats(name: str, data_size: int, times: list) -> str: + """Format timing statistics for output.""" + if not times or len(times) == 0: + return f"{name:20} {data_size:>10} B --\n" + + stats = stats_summary(times) + + return (f"{name:20} {data_size:>10} B " + f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms " + f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms " + f"stdev={stats['stdev']:.3f}ms\n") + +def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None: + test_text = data_bytes.decode('utf-8', errors='replace') + + print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---") + print() + + c_times = bench_c_regex(data_bytes, iterations) + print(format_stats("C tokenizer", len(data_bytes), c_times), end='') + + rust_times = bench_rust_regex(test_text, iterations) + print(format_stats("Rust split", 
len(data_bytes), rust_times), end='') + + if c_times and rust_times: + c_mean = statistics.mean(c_times) + rust_mean = statistics.mean(rust_times) + ratio = rust_mean / c_mean + speedup = "C is faster" if ratio > 1 else "Rust is faster" + print(f"Speedup: {ratio:.2f}x ({speedup})") + + print() + + # Verify token splits match between C and Python regex tokenizer + cmp_text = data_bytes.decode('utf-8', errors='surrogatepass') + ok, err, out_c, out_py = compare_pair_text(cmp_text) + if ok: + print("Compare: OK (C vs Py splits match)") + else: + print("Compare: MISMATCH (C vs Py)") + if err: + print(err) + if out_c is not None and out_py is not None: + c_lines = out_c.splitlines() + p_lines = out_py.splitlines() + print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}") + print("--- C (head) ---") + print("\n".join(c_lines[:10])) + print("--- Py (head) ---") + print("\n".join(p_lines[:10])) + # Stop the benchmark if mismatch detected + raise SystemExit(1) + +def main(): + # Check if files were provided as arguments + file_args = sys.argv[1:] if len(sys.argv) > 1 else [] + + # If files provided, benchmark them + if file_args: + for file_path in file_args: + path = Path(file_path) + if not path.exists(): + print(f"❌ File not found: {file_path}") + continue + + try: + data = path.read_bytes() + benchmark_dataset(path.name, data, 10) + except Exception as e: + print(f"❌ Error reading {file_path}: {e}") + else: + # Run random generated data + configs = [ + ("tiny", 100, 1000), + ("small", 1024, 500), + ("medium", 10 * 1024, 100), + ("large", 100 * 1024, 30), + ("xlarge", 1024 * 1024, 10), + ] + + for name, size_bytes, iterations in configs: + # Generate test data + test_text = gen_valid_unicode_string( + random.Random(hash(name)), + size_bytes + ) + test_bytes = test_text.encode('utf-8') + + benchmark_dataset(name, test_bytes, iterations) + + print("=" * 140) + + +if __name__ == "__main__": + main() diff --git a/fregex/cload.py b/fregex/cload.py new file mode 100644 index 0000000..2a17c94 --- /dev/null +++ b/fregex/cload.py @@ -0,0 +1,20 @@ +import ctypes + +c_lib = ctypes.CDLL("fregex/libfregex.dylib") + +class TokenList(ctypes.Structure): + pass + +TokenList._fields_ = [ + ("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))), + ("lengths", ctypes.POINTER(ctypes.c_size_t)), + ("count", ctypes.c_size_t), + ("capacity", ctypes.c_size_t), +] + +c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)] +c_lib.tokenlist_init.restype = None +c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)] +c_lib.tokenlist_free.restype = None +c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)] +c_lib.tokenize_fast.restype = None \ No newline at end of file diff --git a/fregex/compare.py b/fregex/compare.py new file mode 100644 index 0000000..7e8a7ff --- /dev/null +++ b/fregex/compare.py @@ -0,0 +1,178 @@ +import sys +import ctypes +from pathlib import Path + +from nanochat.tokenizer import SPLIT_PATTERN +from rustbpe import split_text as rust_split_text +from fregex.cload import * +from fregex.py_tokenizer import tokenize_py as py_tokenize_str + +def escape_bytes(b: bytes) -> str: + buf = [] + for code in b: + if code == 0x5C: + buf.append('\\') + elif code == 0x0A: + buf.append('\\n') + elif code == 0x0D: + buf.append('\\r') + elif code == 0x09: + buf.append('\\t') + elif code == 0x0C: + buf.append('\\f') + elif code == 0x0B: + buf.append('\\v') + elif code == 0x22: + buf.append('\\"') + elif code < 32 or code >= 127: + buf.append(f"\\x{code:02X}") + 
else: + buf.append(chr(code)) + return ''.join(buf) + +def dump_tokens(tokens: list[bytes]) -> str: + return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens) + +def tokenize_c_bytes(data: bytes) -> list[bytes]: + tl = TokenList() + c_lib.tokenlist_init(ctypes.byref(tl)) + try: + c_lib.tokenize_fast(data, len(data), ctypes.byref(tl)) + out: list[bytes] = [] + count = int(tl.count) + for i in range(count): + ptr = tl.tokens[i] + ln = int(tl.lengths[i]) + out.append(ctypes.string_at(ptr, ln)) + return out + finally: + c_lib.tokenlist_free(ctypes.byref(tl)) + +def tokenize_py_bytes(data: bytes) -> list[bytes]: + text = data.decode('utf-8', errors='surrogatepass') + toks = py_tokenize_str(text) + return [t.encode('utf-8', errors='surrogatepass') for t in toks] + +def tokenize_rs_bytes(data: bytes) -> list[bytes]: + text = data.decode('utf-8', errors='surrogatepass') + parts = rust_split_text(SPLIT_PATTERN, text) + return [t.encode('utf-8', errors='surrogatepass') for t in parts] + +def compare_one(path: Path) -> int: + data_bytes = Path(path).read_bytes() + try: + c_toks = tokenize_c_bytes(data_bytes) + except Exception as e: + print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr) + return 1 + try: + py_toks = tokenize_py_bytes(data_bytes) + except Exception as e: + print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr) + return 1 + try: + rs_toks = tokenize_rs_bytes(data_bytes) + except Exception as e: + print(f"Rust split failed on {path}:\n{e}", file=sys.stderr) + return 1 + + out_c = dump_tokens(c_toks) + out_py = dump_tokens(py_toks) + out_rs = dump_tokens(rs_toks) + + if out_c == out_py == out_rs: + print(f"OK {path.name}") + return 0 + else: + print(f"DIFF {path.name}") + # Show a small 3-way diff at first differing line, with byte offsets + c_lines = out_c.splitlines() + p_lines = out_py.splitlines() + r_lines = out_rs.splitlines() + + def parse_lines(lines): + parsed = [] + for ln in lines: + # Format is: "\t" + try: + left, right = ln.split('\t', 1) + blen = int(left) + except Exception: + blen = 0 + right = ln + parsed.append((blen, right)) + return parsed + + c_parsed = parse_lines(c_lines) + p_parsed = parse_lines(p_lines) + r_parsed = parse_lines(r_lines) + + def byte_offsets(parsed): + offs = [] + pos = 0 + for blen, _ in parsed: + offs.append((pos, pos + blen)) + pos += blen + return offs + + c_offs = byte_offsets(c_parsed) + p_offs = byte_offsets(p_parsed) + r_offs = byte_offsets(r_parsed) + + # Load original input bytes so we can show precise substrings and code points + data_bytes = Path(path).read_bytes() + + def print_unicode_debug(label, offs_list, idx): + if idx >= len(offs_list): + print(f" {label} piece: [n/a]") + return + start, end = offs_list[idx] + piece_bytes = data_bytes[start:end] + piece_text = piece_bytes.decode('utf-8', errors='replace') + if not piece_bytes: + print(f" {label} piece: [EMPTY]") + return + cp_parts = [] + for ch in piece_text: + cp_parts.append(f"U+{ord(ch):04X}") + bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes) + print(f" {label} chars: {' | '.join(cp_parts)}") + print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)") + + max_len = max(len(c_lines), len(p_lines), len(r_lines)) + for i in range(max_len): + cl = c_lines[i] if i < len(c_lines) else "" + pl = p_lines[i] if i < len(p_lines) else "" + rl = r_lines[i] if i < len(r_lines) else "" + if not (cl == pl == rl): + # Collect byte positions if available + c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]" 
+ p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]" + r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]" + print( + f" line {i+1}:\n" + f" C: {cl} @ bytes {c_pos}\n" + f" Py: {pl} @ bytes {p_pos}\n" + f" Rs: {rl} @ bytes {r_pos}" + ) + print(" === Unicode split detail ===") + print_unicode_debug("C", c_offs, i) + print_unicode_debug("Py", p_offs, i) + print_unicode_debug("Rs", r_offs, i) + break + return 2 + +def main(): + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ") + sys.exit(2) + paths = sorted(Path(sys.argv[1]).glob('*.txt')) + bad = 0 + for p in paths: + bad += compare_one(p) + print(f"Completed. Failures: {bad}") + +if __name__ == '__main__': + main() + + diff --git a/fregex/fregex.c b/fregex/fregex.c new file mode 100644 index 0000000..538430b --- /dev/null +++ b/fregex/fregex.c @@ -0,0 +1,524 @@ +#include "fregex.h" + +#include +#include +#include +#include +#include + +#include "utf8proc/utf8proc.h" + +/* +Regex pattern we care about from nanochat/tokenizer.py SPLIT_PATTERN + +Break it down: +A) '(?i:[sdmt]|ll|ve|re) +B) [^\r\n\p{L}\p{N}]?+\p{L}+ +C) \p{N}{1,2} +D) ?[^\s\p{L}\p{N}]++[\r\n]* +E) \s*[\r\n] +F) \s+(?!\S) +G) \s+ +*/ + +#define UNICODE_LF 0x000A // Line Feed +#define UNICODE_CR 0x000D // Carriage Return + +static inline size_t utf8_decode_cp( + const char *s, + const char *end, + unsigned int *out_cp +) { + if (s >= end) { + *out_cp = 0; + return 0; + } + utf8proc_int32_t ch = 0; + ssize_t ret = utf8proc_iterate( + (const utf8proc_uint8_t*)s, + (ssize_t)(end - s), + &ch + ); + if (ret < 0) { + // invalid sequence: treat as single byte + *out_cp = (unsigned char)*s; + return 1; + } + *out_cp = (unsigned int)ch; + return (size_t)ret; +} + +static inline bool is_cr_or_lf(unsigned int cp) { + return cp == UNICODE_LF || cp == UNICODE_CR; +} + +static inline bool is_letter(unsigned int cp) { + utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp); + + switch (cat) { + case UTF8PROC_CATEGORY_LU: // Letter, Uppercase + case UTF8PROC_CATEGORY_LL: // Letter, Lowercase + case UTF8PROC_CATEGORY_LT: // Letter, Titlecase + case UTF8PROC_CATEGORY_LM: // Letter, Modifier + case UTF8PROC_CATEGORY_LO: // Letter, Other + return true; + default: + return false; + } +} + +static inline bool is_number(unsigned int cp) { + utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp); + switch (cat) { + case UTF8PROC_CATEGORY_ND: // Number, Decimal Digit + case UTF8PROC_CATEGORY_NL: // Number, Letter + case UTF8PROC_CATEGORY_NO: // Number, Other + return true; + default: + return false; + } +} + +static inline bool is_space(unsigned int cp) { + utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp); + + if ( + cat == UTF8PROC_CATEGORY_ZS || + cat == UTF8PROC_CATEGORY_ZL || + cat == UTF8PROC_CATEGORY_ZP + ) { + return true; + } + + switch (cp) { + case 0x0009: // TAB + case 0x000A: // LF + case 0x000B: // VT + case 0x000C: // FF + case 0x000D: // CR + case 0x0085: // NEL + return true; + default: + return false; + } +} + +static inline bool is_alnum(unsigned int cp) { + return is_letter(cp) || is_number(cp); +} + +static inline bool is_non_space_letter_number(unsigned int cp) { + return !is_space(cp) && !is_alnum(cp); +} + +static void *xrealloc(void *ptr, size_t new_size) { + void *p = realloc(ptr, new_size); + if (!p) { + fprintf(stderr, "Out of memory while reallocating %zu bytes\n", new_size); + exit(1); + } + return p; +} + +static char *xmemdup(const char *src, size_t len) { + char *s = 
(char*)malloc(len + 1); + if (!s) { + fprintf(stderr, "Out of memory while allocating %zu bytes\n", len + 1); + exit(1); + } + memcpy(s, src, len); + s[len] = '\0'; + return s; +} + +void tokenlist_init(TokenList *list) { + list->tokens = NULL; + list->lengths = NULL; + list->count = 0; + list->capacity = 0; +} + +void tokenlist_free(TokenList *list) { + if (!list) + return; + for (size_t i = 0; i < list->count; ++i) + free(list->tokens[i]); + free(list->tokens); + free(list->lengths); + list->tokens = NULL; + list->lengths = NULL; + list->count = 0; + list->capacity = 0; +} + +static void tokenlist_push(TokenList *list, const char *start, size_t len) { + if (list->count == list->capacity) { + const size_t new_cap = list->capacity ? (list->capacity * 2) : 64; + list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*)); + list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t)); + list->capacity = new_cap; + } + list->tokens[list->count] = xmemdup(start, len); + list->lengths[list->count] = len; + list->count++; +} + +static void fput_escaped_char(unsigned char c, FILE *out) { + switch (c) { + case '\\': fputs("\\\\", out); break; + case '\n': fputs("\\n", out); break; + case '\r': fputs("\\r", out); break; + case '\t': fputs("\\t", out); break; + case '\f': fputs("\\f", out); break; + case '\v': fputs("\\v", out); break; + case '\"': fputs("\\\"", out); break; + default: + if (c < 32 || c >= 127) { + fprintf(out, "\\x%02X", c); + } else { + fputc(c, out); + } + } +} + +void print_token_escaped(const char *s, size_t len, FILE *out) { + fprintf(out, "%zu\t", len); + for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out); + fputc('\n', out); +} + +void print_tokens_escaped(const TokenList *list, FILE *out) { + for (size_t i = 0; i < list->count; ++i) { + const char *tok = list->tokens[i]; + size_t len = list->lengths[i]; + print_token_escaped(tok, len, out); + } +} + +/* A) '(?i:[sdmt]|ll|ve|re) */ +static size_t match_contraction(const char *p, const char *end) { + if (p >= end || *p != '\'' || (p + 1) >= end) + return 0; + + unsigned char a = (unsigned char)p[1]; + + // locale-independent lowercase for ASCII letters: + // map A–Z to a–z; leaves others unchanged + if (a >= 'A' && a <= 'Z') + a = (unsigned char)(a + ('a' - 'A')); + + // 's 'd 'm 't + if (a == 's' || a == 'd' || a == 'm' || a == 't') + return 2; + + // Need a second following byte for 'll 've 're + if (p + 2 >= end) + return 0; + + unsigned char b = (unsigned char)p[2]; + if (b >= 'A' && b <= 'Z') + b = (unsigned char)(b + ('a' - 'A')); + + if ((a == 'l' && b == 'l') || + (a == 'v' && b == 'e') || + (a == 'r' && b == 'e')) { + return 3; + } + + return 0; +} + +/* B) [^\r\n\p{L}\p{N}]?+\p{L}+ */ +static size_t match_word_with_optional_prefix(const char *p, const char *end) { + if (p >= end) + return 0; + + const char *q = p; + unsigned int cp0; + size_t n0 = utf8_decode_cp(q, end, &cp0); + if (n0 == 0) + return 0; + + const char *letters_start = q; // will point to first letter of \p{L}+ + size_t prefix_bytes = 0; + + // Consider optional one-codepoint prefix if first cp is NOT (CR/LF or alnum) + if (!is_cr_or_lf(cp0) && !is_alnum(cp0)) { + // Look ahead: must have at least one letter right after to commit the prefix + const char *after_prefix = q + n0; + if (after_prefix >= end) { + return 0; // no room for required \p{L}+ + } + unsigned int cp1; + size_t n1 = utf8_decode_cp(after_prefix, end, &cp1); + if (n1 > 0 && is_letter(cp1)) { + // Commit the prefix (possessive) and 
start letters after it + prefix_bytes = n0; + letters_start = after_prefix; + q = after_prefix; + } else { + // Can't commit the prefix (no letter follows), so this branch fails + return 0; + } + } else if (is_letter(cp0)) { + // No prefix; we already sit on the first letter + letters_start = q; + } else { + // First cp is CR/LF or a number; this branch doesn't match + return 0; + } + + // Consume \p{L}+ (require at least one letter) + size_t letter_count = 0; + const char *scan = letters_start; + while (scan < end) { + unsigned int cp; + size_t n = utf8_decode_cp(scan, end, &cp); + if (n == 0 || !is_letter(cp)) + break; + scan += n; + letter_count++; + } + + if (letter_count == 0) { + // Shouldn't happen given the look-ahead, but keep the guard + return 0; + } + + return (size_t)(scan - p); // includes prefix (if any) + letters +} + +/* C) \p{N}{1,2} */ +static size_t match_short_number(const char *p, const char *end) { + if (p >= end) + return 0; + + /* First number required */ + unsigned int cp1; + size_t n1 = utf8_decode_cp(p, end, &cp1); + if (n1 == 0 || !is_number(cp1)) + return 0; + + const char *q = p + n1; + + // Optional second number cp + if (q < end) { + unsigned int cp2; + size_t n2 = utf8_decode_cp(q, end, &cp2); + if (n2 > 0 && is_number(cp2)) + return (size_t)((q + n2) - p); // 2 numbers + } + return (size_t)(q - p); // 1 number +} + +/* D) ?[^\s\p{L}\p{N}]++[\r\n]* */ +// Optional single ASCII space, then 1+ of (not whitespace, not letter, not number), +static size_t match_punct_run(const char *p, const char *end) { + const char *q = p; + + /* Optional single ASCII space */ + if (q < end && *q == ' ') { + const char *r = q + 1; + if (r >= end) + return 0; + + unsigned int cp; + size_t n = utf8_decode_cp(r, end, &cp); + if (n == 0) + return 0; + + // Eligible b/c not whitespace and not alnum + if (!is_space(cp) && !is_alnum(cp)) { + q = r; // commit the space + } else { + // Not followed by eligible punct + return 0; + } + } + + // Now require at least one eligible (not whitespace, not alnum) + size_t took = 0; + while (q < end) { + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0) break; + if (is_space(cp) || is_alnum(cp)) + break; // stop on any whitespace or letter/number + q += n; + took++; + } + if (took == 0) + return 0; // must have at least one punctuation/symbol + + // Finally, optionally absorb CR/LF sequence(s) + while (q < end) { + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0 || !is_cr_or_lf(cp)) + break; + q += n; + } + + return (size_t)(q - p); +} + +/* E) \s*[\r\n] */ +static size_t match_ws_then_linebreak(const char *p, const char *end) { + const char *q = p; + + // Collect all positions while consuming whitespace + // TODO: ? 
Could we hit the limit + const char *positions[256]; + int pos_count = 0; + + // Store initial position (zero whitespace consumed) + positions[pos_count++] = q; + + while (q < end && pos_count < 255) { + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0 || !is_space(cp)) + break; + q += n; + positions[pos_count++] = q; + } + + // Try positions from longest to shortest (backtracking) + // We need to find a position where the next character is a linebreak + for (int i = pos_count - 1; i >= 0; i--) { + q = positions[i]; + + // Check if next character is a linebreak + if (q < end) { + unsigned int br; + size_t nb = utf8_decode_cp(q, end, &br); + if (nb > 0 && is_cr_or_lf(br)) { + // Found a linebreak, include it and return + return (size_t)(q + nb - p); + } + } else { + // EOF reached, rule requires a linebreak so fail + continue; + } + } + + // No position found where next char is a linebreak + return 0; +} + +/* F) \s+(?!\S) */ +static size_t match_trailing_ws(const char *p, const char *end) { + if (p >= end) return 0; + + /* Must start with at least one whitespace */ + const char *q = p; + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0 || !is_space(cp)) + return 0; + + /* Collect all whitespace positions */ + // TODO: ? Could we hit the limit + const char *positions[256]; + positions[0] = q + n; // Position after first whitespace + int pos_count = 1; + + q += n; + + while (q < end && pos_count < 255) { + size_t m = utf8_decode_cp(q, end, &cp); + if (m == 0 || !is_space(cp)) + break; + q += m; + positions[pos_count++] = q; + } + + /* Try positions from longest to shortest (backtracking) */ + for (int i = pos_count - 1; i >= 0; i--) { + q = positions[i]; + + /* Check negative lookahead: (?!\S) at this position */ + if (q < end) { + size_t k = utf8_decode_cp(q, end, &cp); + if (k > 0 && !is_space(cp)) { + continue; /* Next char is non-space, try shorter match */ + } + } + + /* Lookahead succeeded at this position */ + return (size_t)(q - p); + } + + /* All positions failed lookahead */ + return 0; +} + +/* G) \s+ */ +static size_t match_ws_run(const char *p, const char *end) { + if (p >= end) + return 0; + + const char *q = p; + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0 || !is_space(cp)) + return 0; + + /* We have at least one whitespace, consume the run */ + q += n; + while (q < end) { + size_t m = utf8_decode_cp(q, end, &cp); + if (m == 0 || !is_space(cp)) + break; + q += m; + } + return (size_t)(q - p); +} + +void tokenize_fast(const char *input, size_t input_len, TokenList *out) { + if (!input) { + out->tokens = NULL; + out->count = 0; + out->capacity = 0; + return; + } + const char *p = input; + const char *end = input + input_len; + + while (p < end) { + /* Special tokens take precedence */ + // TODO LATER + int captured = 0; + + /* Evaluate case A */ + captured = match_contraction(p, end); + /* Evaluate case B */ + if (!captured) captured = match_word_with_optional_prefix(p, end); + /* Evaluate case C */ + if (!captured) captured = match_short_number(p, end); + /* Evaluate case D */ + if (!captured) captured = match_punct_run(p, end); + /* Evaluate case E */ + if (!captured) captured = match_ws_then_linebreak(p, end); + /* Evaluate case F */ + if (!captured) captured = match_trailing_ws(p, end); + /* Evaluate case G */ + if (!captured) captured = match_ws_run(p, end); + + if (captured) { + tokenlist_push(out, p, captured); + p += captured; + } + else { + /* Fallback take a single CP */ + unsigned int cp; + 
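+            /* Defensive path: branches A-G above cover every codepoint class the
+               decoder can produce (letters -> B, numbers -> C, whitespace -> G,
+               everything else -> D), so this is rarely, if ever, reached; forcing
+               adv to 1 below guarantees the loop always makes progress. */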
size_t adv = utf8_decode_cp(p, end, &cp); + if (adv == 0) + adv = 1; + tokenlist_push(out, p, adv); + p += adv; + } + } +} + + diff --git a/fregex/fregex.h b/fregex/fregex.h new file mode 100644 index 0000000..1ea0cd9 --- /dev/null +++ b/fregex/fregex.h @@ -0,0 +1,28 @@ +#ifndef FAST_TOKENIZER_H +#define FAST_TOKENIZER_H + +#include +#include + +typedef struct { + char **tokens; + size_t *lengths; // Store length of each token to handle null bytes + size_t count; + size_t capacity; +} TokenList; + +void tokenlist_init(TokenList *list); +void tokenlist_free(TokenList *list); + +// Tokenize input according to the GPT-like regex split semantics +// r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" +void tokenize_fast(const char *input, size_t input_len, TokenList *out); + +// Utility to print tokens with C-like escaping (one per line): +// \t\n +void print_token_escaped(const char *s, size_t len, FILE *out); +void print_tokens_escaped(const TokenList *list, FILE *out); + +#endif // FAST_TOKENIZER_H + + diff --git a/fregex/fuzz.py b/fregex/fuzz.py new file mode 100644 index 0000000..0b22aaf --- /dev/null +++ b/fregex/fuzz.py @@ -0,0 +1,361 @@ +import sys +import time +import random +import argparse +import unicodedata as u +import ctypes +from pathlib import Path + +from fregex.cload import * + +HERE = Path(__file__).resolve().parent +TESTS_DIR = HERE / "tests" + +from fregex.py_tokenizer import tokenize_py as py_tokenize_str + +def escape_bytes(b: bytes) -> str: + buf = [] + for code in b: + if code == 0x5C: + buf.append('\\\\') + elif code == 0x0A: + buf.append('\\n') + elif code == 0x0D: + buf.append('\\r') + elif code == 0x09: + buf.append('\\t') + elif code == 0x0C: + buf.append('\\f') + elif code == 0x0B: + buf.append('\\v') + elif code == 0x22: + buf.append('\\"') + elif code < 32 or code >= 127: + buf.append(f"\\x{code:02X}") + else: + buf.append(chr(code)) + return ''.join(buf) + +def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str: + target_len = rng.randint(0, max_len) + + ws_cps = [ + 0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, # space, \t, \n, \v, \f, \r + 0x00A0, # NO-BREAK SPACE + 0x1680, # OGHAM SPACE MARK + 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006, + 0x2007, 0x2008, 0x2009, 0x200A, # EN/EM/THIN/HAIR SPACES etc. 
+ 0x2028, 0x2029, # LINE SEPARATOR, PARAGRAPH SEPARATOR + 0x202F, # NARROW NO-BREAK SPACE + 0x205F, # MEDIUM MATHEMATICAL SPACE + 0x3000, # IDEOGRAPHIC SPACE + 0x200B, # ZERO WIDTH SPACE (not WS in Python, but hits tokenizer class) + 0xFEFF, # ZERO WIDTH NO-BREAK SPACE + ] + + ascii_punct = [ + ord(c) + for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + ] + + def rand_scalar_excluding_surrogates(lo: int, hi: int) -> int: + while True: + cp = rng.randint(lo, hi) + if 0xD800 <= cp <= 0xDFFF: + continue + return cp + + MAX_WS_RUN = 255 + + def is_ws_char(ch: str) -> bool: + cp = ord(ch) + return ch.isspace() or (cp in ws_cps) + + def gen_ws_segment(max_run: int) -> str: + # Mix of various spaces, often multi-length; sometimes explicit CRLFs + if rng.random() < 0.35: + # Build CR, LF, CRLF, or repeated newlines + seqs = ["\n", "\r", "\r\n"] + unit = rng.choice(seqs) + unit_len = len(unit) + max_reps = max(1, min(max_run // unit_len, MAX_WS_RUN // unit_len)) + seg = unit * rng.randint(1, max_reps) + return seg + run = rng.randint(1, min(MAX_WS_RUN, max(1, max_run))) + buf = [] + for _ in range(run): + cp = rng.choice(ws_cps) + buf.append(chr(cp)) + return ''.join(buf) + + def gen_letter_run(max_run: int) -> str: + run = rng.randint(1, max(1, max_run)) + buf = [] + for _ in range(run): + if rng.random() < 0.6: + # ASCII letters + base = ord('A') if rng.random() < 0.5 else ord('a') + buf.append(chr(base + rng.randint(0, 25))) + else: + # Any Unicode letter + while True: + cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF) + if u.category(chr(cp)).startswith('L'): + buf.append(chr(cp)) + break + # optional prefix of single non-WS, non-letter, non-number to stress + # the leading [^\r\n\p{L}\p{N}]?+ in the regex + if rng.random() < 0.3: + buf.insert(0, gen_punc_run(1, allow_space=False)) + return ''.join(buf) + + def gen_number_run(max_run: int) -> str: + # Bias to lengths 1..2 per \p{N}{1,2}, but sometimes longer + if rng.random() < 0.7: + run = rng.randint(1, min(2, max_run)) + else: + run = rng.randint(3, max(3, max_run)) + buf = [] + for _ in range(run): + if rng.random() < 0.75: + buf.append(chr(ord('0') + rng.randint(0, 9))) + else: + # Other numeric categories (Nd/Nl/No) + while True: + cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF) + if u.category(chr(cp)).startswith('N'): + buf.append(chr(cp)) + break + return ''.join(buf) + + def gen_punc_run(max_run: int, allow_space: bool = True) -> str: + run = rng.randint(1, max(1, max_run)) + buf = [] + # optional leading single space before punc block + if allow_space and rng.random() < 0.5: + buf.append(' ') + for _ in range(run): + if rng.random() < 0.6: + cp = rng.choice(ascii_punct) + else: + while True: + cp = rand_scalar_excluding_surrogates(0, 0x10FFFF) + ch = chr(cp) + if ( + not u.category(ch).startswith('L') and + not u.category(ch).startswith('N') and + cp not in ws_cps and + not ch.isspace() + ): + break + # ensure we don't accidentally add null + buf.append(chr(cp)) + # optional trailing newlines to stress [\r\n]* + if rng.random() < 0.35: + tail = gen_ws_segment(3) + # Keep only CR/LF components in the tail for this case + tail = tail.replace('\t', '').replace('\v', '').replace('\f', '').replace(' ', '') + buf.append(tail) + return ''.join(buf) + + def gen_contraction() -> str: + # e.g., we're, he'll, I'd, I'm, can't, they've + prefixes = [gen_letter_run( rng.randint(1, 6) )] + suffix = rng.choice(["s", "d", "m", "t", "ll", "ve", "re"]) + return prefixes[0] + "'" + suffix + + def gen_random_unicode(max_run: int) -> str: + 
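+        # Catch-all generator: arbitrary non-surrogate scalars, to exercise
+        # codepoint classes the more targeted generators above do not produce.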
run = rng.randint(1, max(1, max_run)) + buf = [] + for _ in range(run): + cp = rand_scalar_excluding_surrogates(0, 0x10FFFF) + try: + buf.append(chr(cp)) + except ValueError: + continue + return ''.join(buf) + + buf: list[str] = [] + curr_len = 0 + curr_ws_run = 0 + # Build by segments until target_len + while curr_len < target_len: + remain = target_len - curr_len + r = rng.random() + if r < 0.40: + seg = gen_ws_segment(remain) + elif r < 0.45: + # Explicit newline-focused segment + seg = ("\r\n" if rng.random() < 0.5 else ("\n" if rng.random() < 0.5 else "\r")) * rng.randint(1, max(1, remain)) + elif r < 0.65: + seg = gen_letter_run(remain) + elif r < 0.75: + seg = gen_number_run(remain) + elif r < 0.90: + seg = gen_punc_run(remain) + elif r < 0.95: + seg = gen_contraction() + else: + seg = gen_random_unicode(remain) + + if not seg: + continue + # Trim if needed + # Append with whitespace-run capping + for ch in seg: + if curr_len >= target_len: + break + if is_ws_char(ch): + if curr_ws_run >= MAX_WS_RUN: + # insert a non-whitespace breaker + breaker = '.' + buf.append(breaker) + curr_len += 1 + curr_ws_run = 0 + if curr_len >= target_len: + break + buf.append(ch) + curr_len += 1 + curr_ws_run += 1 + else: + buf.append(ch) + curr_len += 1 + curr_ws_run = 0 + + # Occasionally end with trailing spaces to stress \s+(?!\S) + if curr_len < max_len and rng.random() < 0.3: + trail = gen_ws_segment(max_len - curr_len) + if rng.random() < 0.7: + trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len)) + # Append trailing with cap as well + for ch in trail: + if curr_len >= max_len: + break + if is_ws_char(ch): + if curr_ws_run >= MAX_WS_RUN: + buf.append('.') + curr_len += 1 + curr_ws_run = 0 + if curr_len >= max_len: + break + buf.append(ch) + curr_len += 1 + curr_ws_run += 1 + else: + buf.append(ch) + curr_len += 1 + curr_ws_run = 0 + + return ''.join(buf) + +def write_temp_case(text: str, tag: str = "RUN") -> Path: + TESTS_DIR.mkdir(parents=True, exist_ok=True) + ts = int(time.time() * 1000) + fname = f"in_fuzz_{tag}_{ts}.txt" + path = TESTS_DIR / fname + with open(path, 'wb') as f: + f.write(text.encode('utf-8', errors='surrogatepass')) + return path + +def _format_tokens_dump(tokens: list[bytes]) -> str: + lines = [] + for b in tokens: + lines.append(f"{len(b)}\t{escape_bytes(b)}") + return "\n".join(lines) + +def tokenize_c_bytes(data: bytes) -> list[bytes]: + tl = TokenList() + c_lib.tokenlist_init(ctypes.byref(tl)) + try: + c_lib.tokenize_fast(data, len(data), ctypes.byref(tl)) + out: list[bytes] = [] + count = int(tl.count) + for i in range(count): + ptr = tl.tokens[i] + ln = int(tl.lengths[i]) + out.append(ctypes.string_at(ptr, ln)) + return out + finally: + c_lib.tokenlist_free(ctypes.byref(tl)) + +def tokenize_py_bytes(data: bytes) -> list[bytes]: + if py_tokenize_str is None: + raise RuntimeError("py_tokenizer not available") + text = data.decode('utf-8', errors='surrogatepass') + toks = py_tokenize_str(text) + return [t.encode('utf-8', errors='surrogatepass') for t in toks] + +def compare_pair_text(text: str): + data = text.encode('utf-8', errors='surrogatepass') + try: + toks_c = tokenize_c_bytes(data) + except Exception as e: + return False, f"C failed: {e}", None, None + try: + toks_py = tokenize_py_bytes(data) + except Exception as e: + return False, f"Py failed: {e}", None, None + ok = toks_c == toks_py + return ok, None, _format_tokens_dump(toks_c), _format_tokens_dump(toks_py) + +def run_fuzz(iters: int, max_len: int, seed: int, 
stop_on_first: bool): + rng = random.Random(seed) + total = 0 + mismatches = 0 + last_save = None + + for i in range(iters if iters > 0 else 1_000_000_000): + s = gen_valid_unicode_string(rng, max_len) + ok, err, out_c, out_py = compare_pair_text(s) + total += 1 + if not ok: + mismatches += 1 + fail_path = write_temp_case(s, tag="FAIL") + last_save = fail_path + print(f"Mismatch at iter {i}, saved to {fail_path}") + print(f"Seed: {seed}") + + cps = [f"U+{ord(ch):04X}" for ch in s] + cats = [u.category(ch) for ch in s] + print(f"Text bytes len: {len(s.encode('utf-8','surrogatepass'))}, chars: {len(s)}") + print(f"Codepoints: {' '.join(cps)}") + print(f"Categories: {' '.join(cats)}") + if err: + print(err) + if out_c is not None: + print("--- C tokens ---") + print(out_c) + if out_py is not None: + print("--- Py tokens ---") + print(out_py) + + if stop_on_first: + break + + if (i + 1) % 100 == 0: + print(f"[fuzz] {i+1} cases, mismatches={mismatches}") + # print(out_c, out_py, sep="\n") + + return total, mismatches, last_save + + +def main(): + ap = argparse.ArgumentParser(description="Fuzz C vs Python tokenizers on random valid UTF-8 inputs") + ap.add_argument("--iters", type=int, default=0, help="Number of iterations (0 = very large run)") + ap.add_argument("--max-len", type=int, default=256, help="Maximum number of Unicode scalars per case") + ap.add_argument("--seed", type=int, default=12345, help="PRNG seed for reproducibility") + ap.add_argument("--stop-on-first", action="store_true", help="Stop at first mismatch (default: run all)") + args = ap.parse_args() + + total, mismatches, last = run_fuzz(args.iters, args.max_len, args.seed, args.stop_on_first) + print(f"Completed {total} cases, mismatches={mismatches}") + if last: + print(f"Last failing case saved at: {last}") + if mismatches: + sys.exit(1) + + +if __name__ == "__main__": + main() + + diff --git a/fregex/py_tokenizer.py b/fregex/py_tokenizer.py new file mode 100644 index 0000000..9bea371 --- /dev/null +++ b/fregex/py_tokenizer.py @@ -0,0 +1,43 @@ +import sys + +from nanochat.tokenizer import SPLIT_PATTERN +from tokenizers import pre_tokenizers, Regex + +_SPLITTER = pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated", invert=False) + +def escape_bytes(b: bytes) -> str: + buf = [] + for code in b: + if code == 0x5C: buf.append('\\\\') + elif code == 0x0A: buf.append('\\n') + elif code == 0x0D: buf.append('\\r') + elif code == 0x09: buf.append('\\t') + elif code == 0x0C: buf.append('\\f') + elif code == 0x0B: buf.append('\\v') + elif code == 0x22: buf.append('\\"') + elif code < 32 or code >= 127: + buf.append(f"\\x{code:02X}") + else: + buf.append(chr(code)) + return ''.join(buf) + +def tokenize_py(text: str): + parts = _SPLITTER.pre_tokenize_str(text) + return [s for (s, _range) in parts] + +def main(): + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(2) + with open(sys.argv[1], 'rb') as f: + data = f.read() + text = data.decode('utf-8', errors='surrogatepass') + for tok in tokenize_py(text): + b = tok.encode('utf-8', errors='surrogatepass') + esc = escape_bytes(b) + print(f"{len(b)}\t{esc}") + +if __name__ == "__main__": + main() + + diff --git a/fregex/rust_split.py b/fregex/rust_split.py new file mode 100644 index 0000000..27b0263 --- /dev/null +++ b/fregex/rust_split.py @@ -0,0 +1,38 @@ +import sys + +from nanochat.tokenizer import SPLIT_PATTERN +from rustbpe import split_text as rust_split_text + +def escape_bytes(b: bytes) -> str: + buf = [] + for code in b: + if code 
== 0x5C: buf.append('\\\\') + elif code == 0x0A: buf.append('\\n') + elif code == 0x0D: buf.append('\\r') + elif code == 0x09: buf.append('\\t') + elif code == 0x0C: buf.append('\\f') + elif code == 0x0B: buf.append('\\v') + elif code == 0x22: buf.append('\\"') + elif code < 32 or code >= 127: + buf.append(f"\\x{code:02X}") + else: + buf.append(chr(code)) + return ''.join(buf) + +def main(): + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(2) + with open(sys.argv[1], 'rb') as f: + data = f.read() + text = data.decode('utf-8', errors='surrogatepass') + parts = rust_split_text(SPLIT_PATTERN, text) + for tok in parts: + b = tok.encode('utf-8', errors='surrogatepass') + esc = escape_bytes(b) + print(f"{len(b)}\t{esc}") + +if __name__ == "__main__": + main() + + diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs index 273d7f2..c092cc8 100644 --- a/rustbpe/src/lib.rs +++ b/rustbpe/src/lib.rs @@ -4,6 +4,7 @@ use std::collections::HashMap as StdHashMap; use dary_heap::OctonaryHeap; use fancy_regex::Regex; use pyo3::prelude::*; +use pyo3::wrap_pyfunction; use ahash::{AHashMap, AHashSet}; use compact_str::CompactString; @@ -467,9 +468,23 @@ impl Tokenizer { } } +#[pyfunction] +pub fn split_text(pattern: String, text: String) -> PyResult> { + let re = Regex::new(&pattern) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?; + let mut out: Vec = Vec::new(); + for m in re.find_iter(&text) { + let m = m + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Regex match failed: {}", e)))?; + out.push(m.as_str().to_string()); + } + Ok(out) +} + #[pymodule] fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); // forwards Rust `log` to Python's `logging` m.add_class::()?; + m.add_function(wrap_pyfunction!(split_text, m)?)?; Ok(()) } From e02938c0aa6608d06b2ea7805652338bb7e19b37 Mon Sep 17 00:00:00 2001 From: MadMax129 Date: Thu, 23 Oct 2025 17:55:33 -0400 Subject: [PATCH 2/5] cleanup --- fregex/bench.py | 12 ------------ fregex/compare.py | 1 - fregex/fregex.h | 13 ++++--------- 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/fregex/bench.py b/fregex/bench.py index f916722..f707839 100755 --- a/fregex/bench.py +++ b/fregex/bench.py @@ -1,15 +1,3 @@ -""" -Benchmarker for comparing tok.c tokenize_fast() vs rust split_text() -Measures speed WITHOUT subprocess overhead - direct function calls only. - -Usage: - cd pytok - source ../.venv/bin/activate - python3 bench.py # Run synthetic data benchmarks - python3 bench.py /path/to/file.txt # Benchmark a specific file - python3 bench.py file1.txt file2.txt ... 
# Benchmark multiple files -""" - import sys import ctypes import random diff --git a/fregex/compare.py b/fregex/compare.py index 7e8a7ff..ffdcb50 100644 --- a/fregex/compare.py +++ b/fregex/compare.py @@ -119,7 +119,6 @@ def compare_one(path: Path) -> int: p_offs = byte_offsets(p_parsed) r_offs = byte_offsets(r_parsed) - # Load original input bytes so we can show precise substrings and code points data_bytes = Path(path).read_bytes() def print_unicode_debug(label, offs_list, idx): diff --git a/fregex/fregex.h b/fregex/fregex.h index 1ea0cd9..f41e3f1 100644 --- a/fregex/fregex.h +++ b/fregex/fregex.h @@ -1,12 +1,12 @@ -#ifndef FAST_TOKENIZER_H -#define FAST_TOKENIZER_H +#ifndef FAST_REGEX_H +#define FAST_REGEX_H #include #include typedef struct { char **tokens; - size_t *lengths; // Store length of each token to handle null bytes + size_t *lengths; size_t count; size_t capacity; } TokenList; @@ -14,15 +14,10 @@ typedef struct { void tokenlist_init(TokenList *list); void tokenlist_free(TokenList *list); -// Tokenize input according to the GPT-like regex split semantics -// r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" void tokenize_fast(const char *input, size_t input_len, TokenList *out); - -// Utility to print tokens with C-like escaping (one per line): -// \t\n void print_token_escaped(const char *s, size_t len, FILE *out); void print_tokens_escaped(const TokenList *list, FILE *out); -#endif // FAST_TOKENIZER_H +#endif // FAST_REGEX_H From 41c8b8dbde74d162f9b93fb376870a185ef59d0b Mon Sep 17 00:00:00 2001 From: MadMax129 Date: Thu, 23 Oct 2025 20:23:59 -0400 Subject: [PATCH 3/5] removed buffer approuch --- fregex/bench.py | 2 +- fregex/fregex.c | 133 +++++++++++++++++++++++------------------------- fregex/fuzz.py | 29 ++--------- 3 files changed, 68 insertions(+), 96 deletions(-) diff --git a/fregex/bench.py b/fregex/bench.py index f707839..64dcda0 100755 --- a/fregex/bench.py +++ b/fregex/bench.py @@ -115,7 +115,7 @@ def main(): try: data = path.read_bytes() - benchmark_dataset(path.name, data, 10) + benchmark_dataset(path.name, data, 10_000) except Exception as e: print(f"❌ Error reading {file_path}: {e}") else: diff --git a/fregex/fregex.c b/fregex/fregex.c index 538430b..c91aa35 100644 --- a/fregex/fregex.c +++ b/fregex/fregex.c @@ -48,6 +48,10 @@ static inline size_t utf8_decode_cp( return (size_t)ret; } +static inline bool is_utf8_cont_byte(unsigned char b) { + return (b & 0xC0) == 0x80; +} + static inline bool is_cr_or_lf(unsigned int cp) { return cp == UNICODE_LF || cp == UNICODE_CR; } @@ -312,7 +316,6 @@ static size_t match_short_number(const char *p, const char *end) { } /* D) ?[^\s\p{L}\p{N}]++[\r\n]* */ -// Optional single ASCII space, then 1+ of (not whitespace, not letter, not number), static size_t match_punct_run(const char *p, const char *end) { const char *q = p; @@ -365,92 +368,82 @@ static size_t match_punct_run(const char *p, const char *end) { /* E) \s*[\r\n] */ static size_t match_ws_then_linebreak(const char *p, const char *end) { const char *q = p; + const char *best = NULL; - // Collect all positions while consuming whitespace - // TODO: ? 
Could we hit the limit - const char *positions[256]; - int pos_count = 0; - - // Store initial position (zero whitespace consumed) - positions[pos_count++] = q; - - while (q < end && pos_count < 255) { - unsigned int cp; - size_t n = utf8_decode_cp(q, end, &cp); - if (n == 0 || !is_space(cp)) - break; - q += n; - positions[pos_count++] = q; - } - - // Try positions from longest to shortest (backtracking) - // We need to find a position where the next character is a linebreak - for (int i = pos_count - 1; i >= 0; i--) { - q = positions[i]; - - // Check if next character is a linebreak - if (q < end) { - unsigned int br; - size_t nb = utf8_decode_cp(q, end, &br); - if (nb > 0 && is_cr_or_lf(br)) { - // Found a linebreak, include it and return - return (size_t)(q + nb - p); - } - } else { - // EOF reached, rule requires a linebreak so fail - continue; + // Check boundary before consuming any whitespace, too (zero-length \s*) + if (q < end) { + unsigned int nx; + size_t nn = utf8_decode_cp(q, end, &nx); + if (nn > 0 && is_cr_or_lf(nx)) { + best = q; // \s* = 0, [\r\n] = this char } } - // No position found where next char is a linebreak - return 0; + // Scan whitespace; at each boundary, test the next cp + while (q < end) { + unsigned int cp; + size_t n = utf8_decode_cp(q, end, &cp); + if (n == 0 || !is_space(cp)) + break; + q += n; // we consumed one whitespace cp; boundary is at q now + + if (q < end) { + unsigned int nx; + size_t nn = utf8_decode_cp(q, end, &nx); + if (nn > 0 && is_cr_or_lf(nx)) { + best = q; // prefer the rightmost usable boundary + } + } + } + + if (!best) return 0; + + // At 'best' the next cp is the CR/LF to include + unsigned int br; + size_t nb = utf8_decode_cp(best, end, &br); + return (size_t)((best + nb) - p); } /* F) \s+(?!\S) */ static size_t match_trailing_ws(const char *p, const char *end) { - if (p >= end) return 0; + if (p >= end) + return 0; - /* Must start with at least one whitespace */ - const char *q = p; - unsigned int cp; - size_t n = utf8_decode_cp(q, end, &cp); + // First cp must be whitespace + unsigned int cp; + size_t n = utf8_decode_cp(p, end, &cp); if (n == 0 || !is_space(cp)) return 0; - /* Collect all whitespace positions */ - // TODO: ? Could we hit the limit - const char *positions[256]; - positions[0] = q + n; // Position after first whitespace - int pos_count = 1; - - q += n; - - while (q < end && pos_count < 255) { - size_t m = utf8_decode_cp(q, end, &cp); + // Consume full whitespace run [p, r) + const char *r = p + n; + while (r < end) { + size_t m = utf8_decode_cp(r, end, &cp); if (m == 0 || !is_space(cp)) break; - q += m; - positions[pos_count++] = q; + r += m; } - /* Try positions from longest to shortest (backtracking) */ - for (int i = pos_count - 1; i >= 0; i--) { - q = positions[i]; - - /* Check negative lookahead: (?!\S) at this position */ - if (q < end) { - size_t k = utf8_decode_cp(q, end, &cp); - if (k > 0 && !is_space(cp)) { - continue; /* Next char is non-space, try shorter match */ - } - } - - /* Lookahead succeeded at this position */ - return (size_t)(q - p); + if (r == end) { + // Only whitespace to EOF -> take all of it + return (size_t)(r - p); } - - /* All positions failed lookahead */ - return 0; + + // Backtrack by exactly one whitespace cp + // If the run length is only 1 cp, F must fail. 
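+    // Worked example: for "ab   cd" the run here is the three spaces and the
+    // next cp is 'c' (non-space), so we back off one cp and return two spaces;
+    // the remaining " cd" is then picked up by branch B's optional prefix.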
+ // Find the start of the last whitespace cp in [p, r) + const char *t = r; + // step back to beginning of previous UTF-8 cp + do { + --t; + } while (t > p && is_utf8_cont_byte(*t)); + + if (t == p) { + // run had length 1 cp -> cannot backtrack to keep \s+ >= 1 + return 0; + } + // Now [p, t) is k-1 whitespace cps + return (size_t)(t - p); } /* G) \s+ */ diff --git a/fregex/fuzz.py b/fregex/fuzz.py index 0b22aaf..b0cd6d5 100644 --- a/fregex/fuzz.py +++ b/fregex/fuzz.py @@ -65,8 +65,6 @@ def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str: continue return cp - MAX_WS_RUN = 255 - def is_ws_char(ch: str) -> bool: cp = ord(ch) return ch.isspace() or (cp in ws_cps) @@ -78,10 +76,10 @@ def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str: seqs = ["\n", "\r", "\r\n"] unit = rng.choice(seqs) unit_len = len(unit) - max_reps = max(1, min(max_run // unit_len, MAX_WS_RUN // unit_len)) + max_reps = max(1, max_run // unit_len) seg = unit * rng.randint(1, max_reps) return seg - run = rng.randint(1, min(MAX_WS_RUN, max(1, max_run))) + run = rng.randint(1, max(1, max_run)) buf = [] for _ in range(run): cp = rng.choice(ws_cps) @@ -177,7 +175,6 @@ def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str: buf: list[str] = [] curr_len = 0 - curr_ws_run = 0 # Build by segments until target_len while curr_len < target_len: remain = target_len - curr_len @@ -201,50 +198,32 @@ def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str: if not seg: continue # Trim if needed - # Append with whitespace-run capping + # Append for ch in seg: if curr_len >= target_len: break if is_ws_char(ch): - if curr_ws_run >= MAX_WS_RUN: - # insert a non-whitespace breaker - breaker = '.' - buf.append(breaker) - curr_len += 1 - curr_ws_run = 0 - if curr_len >= target_len: - break buf.append(ch) curr_len += 1 - curr_ws_run += 1 else: buf.append(ch) curr_len += 1 - curr_ws_run = 0 # Occasionally end with trailing spaces to stress \s+(?!\S) if curr_len < max_len and rng.random() < 0.3: trail = gen_ws_segment(max_len - curr_len) if rng.random() < 0.7: trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len)) - # Append trailing with cap as well + # Append trailing for ch in trail: if curr_len >= max_len: break if is_ws_char(ch): - if curr_ws_run >= MAX_WS_RUN: - buf.append('.') - curr_len += 1 - curr_ws_run = 0 - if curr_len >= max_len: - break buf.append(ch) curr_len += 1 - curr_ws_run += 1 else: buf.append(ch) curr_len += 1 - curr_ws_run = 0 return ''.join(buf) From 851810c7d57b0dd94b97465f9c1ccc1d6cc3cf50 Mon Sep 17 00:00:00 2001 From: MadMax129 Date: Fri, 24 Oct 2025 17:06:06 -0400 Subject: [PATCH 4/5] remove string allocations --- fregex/bench.py | 91 +++++++++++++++++++++++++++++++---------------- fregex/cload.py | 51 ++++++++++++++++++++------ fregex/compare.py | 15 -------- fregex/fregex.c | 43 +++++++--------------- fregex/fregex.h | 12 +++---- fregex/fuzz.py | 15 -------- 6 files changed, 118 insertions(+), 109 deletions(-) diff --git a/fregex/bench.py b/fregex/bench.py index 64dcda0..7231e18 100755 --- a/fregex/bench.py +++ b/fregex/bench.py @@ -3,37 +3,45 @@ import ctypes import random import time import statistics +import os +import gc from pathlib import Path from nanochat.tokenizer import SPLIT_PATTERN + +os.environ.update({ + 'OMP_NUM_THREADS': '1', + 'OPENBLAS_NUM_THREADS': '1', + 'MKL_NUM_THREADS': '1', + 'VECLIB_MAXIMUM_THREADS': '1', + 'NUMEXPR_NUM_THREADS': '1', + 'RAYON_NUM_THREADS': '1', +}) + 
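+# Benchmark hygiene: the env vars above pin common native thread pools to a
+# single thread, and the priority bump below (a negative nice value) reduces
+# scheduler jitter; raising priority typically needs elevated privileges.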
+os.setpriority(os.PRIO_PROCESS, 0, -10) + from rustbpe import split_text as rust_split_text from fregex.fuzz import gen_valid_unicode_string, compare_pair_text from fregex.cload import * -def bench_c_regex(data: bytes, iterations: int) -> list: - times = [] - for _ in range(iterations): - token_list = TokenList() - c_lib.tokenlist_init(ctypes.byref(token_list)) - - start = time.perf_counter() - c_lib.tokenize_fast(data, len(data), ctypes.byref(token_list)) - elapsed = time.perf_counter() - start - - c_lib.tokenlist_free(ctypes.byref(token_list)) - times.append(elapsed * 1000) - - return times +PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString +PyBytes_AsString.restype = ctypes.c_void_p +PyBytes_AsString.argtypes = [ctypes.py_object] -def bench_rust_regex(text: str, iterations: int) -> list: - times = [] - for _ in range(iterations): - start = time.perf_counter() - rust_split_text(SPLIT_PATTERN, text) - elapsed = time.perf_counter() - start - times.append(elapsed * 1000) - - return times +def _run_once_c(data: bytes) -> float: + token_list = TokenList() + c_lib.tokenlist_init(ctypes.byref(token_list)) + base_ptr = PyBytes_AsString(data) + t0 = time.perf_counter_ns() + c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list)) + dt_ms = (time.perf_counter_ns() - t0) / 1e6 + c_lib.tokenlist_free(ctypes.byref(token_list)) + return dt_ms + +def _run_once_rust(text: str) -> float: + t0 = time.perf_counter_ns() + rust_split_text(SPLIT_PATTERN, text) + return (time.perf_counter_ns() - t0) / 1e6 def stats_summary(times: list) -> dict: """Compute statistics from timing list.""" @@ -65,11 +73,33 @@ def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None: print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---") print() - - c_times = bench_c_regex(data_bytes, iterations) + + # Pre-touch data to avoid first-touch/page-fault skew + if data_bytes: + _ = data_bytes[0] + for i in range(0, len(data_bytes), 4096): + _ = data_bytes[i] + + # Warm-up + for _ in range(20): + _run_once_c(data_bytes) + _run_once_rust(test_text) + + # Disable GC during timed section + gc_was_enabled = gc.isenabled() + if gc_was_enabled: + gc.disable() + + c_times = [] + rust_times = [] + for _ in range(iterations): + c_times.append(_run_once_c(data_bytes)) + rust_times.append(_run_once_rust(test_text)) + + if gc_was_enabled: + gc.enable() + print(format_stats("C tokenizer", len(data_bytes), c_times), end='') - - rust_times = bench_rust_regex(test_text, iterations) print(format_stats("Rust split", len(data_bytes), rust_times), end='') if c_times and rust_times: @@ -115,7 +145,7 @@ def main(): try: data = path.read_bytes() - benchmark_dataset(path.name, data, 10_000) + benchmark_dataset(path.name, data, 1_000) except Exception as e: print(f"❌ Error reading {file_path}: {e}") else: @@ -124,8 +154,8 @@ def main(): ("tiny", 100, 1000), ("small", 1024, 500), ("medium", 10 * 1024, 100), - ("large", 100 * 1024, 30), - ("xlarge", 1024 * 1024, 10), + ("large", 100 * 1024, 100), + ("xlarge", 1024 * 1024, 100), ] for name, size_bytes, iterations in configs: @@ -140,6 +170,5 @@ def main(): print("=" * 140) - if __name__ == "__main__": main() diff --git a/fregex/cload.py b/fregex/cload.py index 2a17c94..3def82d 100644 --- a/fregex/cload.py +++ b/fregex/cload.py @@ -2,19 +2,50 @@ import ctypes c_lib = ctypes.CDLL("fregex/libfregex.dylib") -class TokenList(ctypes.Structure): - pass -TokenList._fields_ = [ - ("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))), - ("lengths", 
ctypes.POINTER(ctypes.c_size_t)), - ("count", ctypes.c_size_t), - ("capacity", ctypes.c_size_t), -] +class TokenPos(ctypes.Structure): + _fields_ = [ + ("start", ctypes.c_size_t), + ("end", ctypes.c_size_t), + ] + + +class TokenList(ctypes.Structure): + _fields_ = [ + ("splits", ctypes.POINTER(TokenPos)), + ("count", ctypes.c_size_t), + ("capacity", ctypes.c_size_t), + ] + c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)] c_lib.tokenlist_init.restype = None c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)] c_lib.tokenlist_free.restype = None -c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)] -c_lib.tokenize_fast.restype = None \ No newline at end of file +# Accept a raw pointer to the input buffer rather than a Python bytes object +c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)] +c_lib.tokenize_fast.restype = None + +def tokenize_c_bytes(data: bytes) -> list[bytes]: + # Use a C char* view of the original bytes; offsets computed from this base + c_data = ctypes.c_char_p(data) + tl = TokenList() + c_lib.tokenlist_init(ctypes.byref(tl)) + try: + base_addr = ctypes.cast(c_data, ctypes.c_void_p).value + # Pass the same pointer to C + c_lib.tokenize_fast(ctypes.cast(c_data, ctypes.c_void_p), len(data), ctypes.byref(tl)) + out: list[bytes] = [] + count = int(tl.count) + for i in range(count): + start_addr = int(tl.splits[i].start) + end_addr = int(tl.splits[i].end) + # Compute offsets into our local buffer + off_start = start_addr - base_addr + off_end = end_addr - base_addr + if off_start < 0 or off_end < off_start or off_end > len(data): + raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}") + out.append(data[off_start:off_end]) + return out + finally: + c_lib.tokenlist_free(ctypes.byref(tl)) \ No newline at end of file diff --git a/fregex/compare.py b/fregex/compare.py index ffdcb50..380c2d1 100644 --- a/fregex/compare.py +++ b/fregex/compare.py @@ -33,21 +33,6 @@ def escape_bytes(b: bytes) -> str: def dump_tokens(tokens: list[bytes]) -> str: return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens) -def tokenize_c_bytes(data: bytes) -> list[bytes]: - tl = TokenList() - c_lib.tokenlist_init(ctypes.byref(tl)) - try: - c_lib.tokenize_fast(data, len(data), ctypes.byref(tl)) - out: list[bytes] = [] - count = int(tl.count) - for i in range(count): - ptr = tl.tokens[i] - ln = int(tl.lengths[i]) - out.append(ctypes.string_at(ptr, ln)) - return out - finally: - c_lib.tokenlist_free(ctypes.byref(tl)) - def tokenize_py_bytes(data: bytes) -> list[bytes]: text = data.decode('utf-8', errors='surrogatepass') toks = py_tokenize_str(text) diff --git a/fregex/fregex.c b/fregex/fregex.c index c91aa35..466d41c 100644 --- a/fregex/fregex.c +++ b/fregex/fregex.c @@ -1,4 +1,4 @@ -#include "fregex.h" +#include "fregex-2.h" #include #include @@ -136,34 +136,29 @@ static char *xmemdup(const char *src, size_t len) { } void tokenlist_init(TokenList *list) { - list->tokens = NULL; - list->lengths = NULL; + list->splits = NULL; list->count = 0; list->capacity = 0; } void tokenlist_free(TokenList *list) { - if (!list) + if (!list) return; - for (size_t i = 0; i < list->count; ++i) - free(list->tokens[i]); - free(list->tokens); - free(list->lengths); - list->tokens = NULL; - list->lengths = NULL; - list->count = 0; + free(list->splits); + list->splits = NULL; + list->count = 0; list->capacity = 0; } static void tokenlist_push(TokenList *list, const char *start, size_t len) { 
if (list->count == list->capacity) { - const size_t new_cap = list->capacity ? (list->capacity * 2) : 64; - list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*)); - list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t)); + const size_t new_cap = list->capacity ? (list->capacity * 2) : 128; + list->splits = (TokenPos*)xrealloc(list->splits, new_cap * sizeof(TokenPos)); list->capacity = new_cap; } - list->tokens[list->count] = xmemdup(start, len); - list->lengths[list->count] = len; + /* Write the start / end position of string */ + list->splits[list->count].start = (size_t)start; + list->splits[list->count].end = (size_t)(start + len); // len - 1 ? list->count++; } @@ -185,20 +180,6 @@ static void fput_escaped_char(unsigned char c, FILE *out) { } } -void print_token_escaped(const char *s, size_t len, FILE *out) { - fprintf(out, "%zu\t", len); - for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out); - fputc('\n', out); -} - -void print_tokens_escaped(const TokenList *list, FILE *out) { - for (size_t i = 0; i < list->count; ++i) { - const char *tok = list->tokens[i]; - size_t len = list->lengths[i]; - print_token_escaped(tok, len, out); - } -} - /* A) '(?i:[sdmt]|ll|ve|re) */ static size_t match_contraction(const char *p, const char *end) { if (p >= end || *p != '\'' || (p + 1) >= end) @@ -470,7 +451,7 @@ static size_t match_ws_run(const char *p, const char *end) { void tokenize_fast(const char *input, size_t input_len, TokenList *out) { if (!input) { - out->tokens = NULL; + out->splits = NULL; out->count = 0; out->capacity = 0; return; diff --git a/fregex/fregex.h b/fregex/fregex.h index f41e3f1..e332115 100644 --- a/fregex/fregex.h +++ b/fregex/fregex.h @@ -5,19 +5,17 @@ #include typedef struct { - char **tokens; - size_t *lengths; + size_t start, end; +} TokenPos; + +typedef struct { + TokenPos *splits; size_t count; size_t capacity; } TokenList; void tokenlist_init(TokenList *list); void tokenlist_free(TokenList *list); - void tokenize_fast(const char *input, size_t input_len, TokenList *out); -void print_token_escaped(const char *s, size_t len, FILE *out); -void print_tokens_escaped(const TokenList *list, FILE *out); #endif // FAST_REGEX_H - - diff --git a/fregex/fuzz.py b/fregex/fuzz.py index b0cd6d5..2c29fbf 100644 --- a/fregex/fuzz.py +++ b/fregex/fuzz.py @@ -242,21 +242,6 @@ def _format_tokens_dump(tokens: list[bytes]) -> str: lines.append(f"{len(b)}\t{escape_bytes(b)}") return "\n".join(lines) -def tokenize_c_bytes(data: bytes) -> list[bytes]: - tl = TokenList() - c_lib.tokenlist_init(ctypes.byref(tl)) - try: - c_lib.tokenize_fast(data, len(data), ctypes.byref(tl)) - out: list[bytes] = [] - count = int(tl.count) - for i in range(count): - ptr = tl.tokens[i] - ln = int(tl.lengths[i]) - out.append(ctypes.string_at(ptr, ln)) - return out - finally: - c_lib.tokenlist_free(ctypes.byref(tl)) - def tokenize_py_bytes(data: bytes) -> list[bytes]: if py_tokenize_str is None: raise RuntimeError("py_tokenizer not available") From 0a1059d571e67d127bf2b665a120139ea457adda Mon Sep 17 00:00:00 2001 From: MadMax129 Date: Fri, 24 Oct 2025 18:52:26 -0400 Subject: [PATCH 5/5] add into rustbpe --- rustbpe/Cargo.lock | 18 +++++++ rustbpe/Cargo.toml | 2 + rustbpe/src/lib.rs | 121 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/rustbpe/Cargo.lock b/rustbpe/Cargo.lock index 69f8754..8c873ac 100644 --- a/rustbpe/Cargo.lock +++ b/rustbpe/Cargo.lock @@ -186,6 +186,16 @@ version = "0.2.175" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "log" version = "0.4.28" @@ -363,7 +373,9 @@ dependencies = [ "dary_heap", "fancy-regex", "indexmap", + "libloading", "log", + "once_cell", "pyo3", "pyo3-log", "rayon", @@ -431,6 +443,12 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "wit-bindgen" version = "0.45.1" diff --git a/rustbpe/Cargo.toml b/rustbpe/Cargo.toml index 392a828..fb4084f 100644 --- a/rustbpe/Cargo.toml +++ b/rustbpe/Cargo.toml @@ -13,3 +13,5 @@ pyo3-log = "0.12.4" ahash = "0.8.12" rayon = "1.11.0" compact_str = "0.9.0" +libloading = "0.8.5" +once_cell = "1.19.0" diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs index c092cc8..9da42fc 100644 --- a/rustbpe/src/lib.rs +++ b/rustbpe/src/lib.rs @@ -3,11 +3,12 @@ use std::collections::HashMap as StdHashMap; use dary_heap::OctonaryHeap; use fancy_regex::Regex; +use libloading::Library; +use once_cell::sync::OnceCell; use pyo3::prelude::*; use pyo3::wrap_pyfunction; use ahash::{AHashMap, AHashSet}; -use compact_str::CompactString; use rayon::prelude::*; // Default GPT-4 style regex pattern for splitting text @@ -15,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{ type Pair = (u32, u32); +#[allow(non_camel_case_types)] +type c_char = std::os::raw::c_char; + +#[repr(C)] +struct CTokenPos { + start: usize, + end: usize, +} + +#[repr(C)] +struct CTokenList { + splits: *mut CTokenPos, + count: usize, + capacity: usize, +} + +type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList); +type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList); +type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList); + +struct FregexSymbols { + _lib: Library, + tokenlist_init: FnTokenListInit, + tokenlist_free: FnTokenListFree, + tokenize_fast: FnTokenizeFast, +} + +static FREGEX: OnceCell = OnceCell::new(); + +fn load_fregex() -> &'static FregexSymbols { + FREGEX.get_or_init(|| { + // NOTE: adjust this path per user if needed. 
+ let path = "fregex/libfregex.dylib"; + let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| { + panic!("Failed to load libfregex from {}: {}", path, e); + }); + unsafe { + let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init"); + let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free"); + let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast"); + println!("rustbpe: loaded libfregex.dylib from {}", path); + FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast } + } + }) +} + +fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F) +where + F: FnMut(&'a [u8]), +{ + if input.is_empty() { + return; + } + let syms = load_fregex(); + let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 }; + let base_ptr = input.as_ptr() as usize; + unsafe { + (syms.tokenlist_init)(&mut out as *mut CTokenList); + (syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList); + if !out.splits.is_null() { + let slice = std::slice::from_raw_parts(out.splits, out.count); + for pos in slice.iter() { + let start = pos.start.saturating_sub(base_ptr); + let end = pos.end.saturating_sub(base_ptr); + if end <= input.len() && start <= end { + on_piece(&input[start..end]); + } + } + } + (syms.tokenlist_free)(&mut out as *mut CTokenList); + } +} + /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation #[pyclass] pub struct Tokenizer { @@ -296,8 +370,8 @@ impl Tokenizer { pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))? }; - // Global chunk counts - let mut counts: AHashMap = AHashMap::new(); + // Global chunk counts: own bytes once per unique chunk (no string copies) + let mut counts: AHashMap, i32> = AHashMap::new(); // Temporary buffer we refill under the GIL let mut buf: Vec = Vec::with_capacity(buffer_size); @@ -345,31 +419,28 @@ impl Tokenizer { total_sequences += buf.len() as u64; - let pattern = self.compiled_pattern.clone(); - let local: AHashMap = py.allow_threads(|| { + // Build per-string local counts that reference the buffer slices (no allocations) + let locals: Vec> = py.allow_threads(|| { buf.par_iter() .map(|s| { - let mut m: AHashMap = AHashMap::new(); - for mat in pattern.find_iter(s) { - let piece = mat.expect("regex match failed").as_str(); - *m.entry(CompactString::from(piece)).or_default() += 1; - } - m + let mut m: AHashMap<&[u8], i32> = AHashMap::new(); + let bytes = s.as_bytes(); + tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; }); + // Materialize as Vec to allow merging after parallel section + m.into_iter().collect::>() }) - .reduce( - || AHashMap::new(), - |mut a, b| { - for (k, v) in b { - *a.entry(k).or_default() += v; - } - a - }, - ) + .collect() }); - // Merge local into global (single-threaded) - for (k, v) in local { - *counts.entry(k).or_default() += v; + // Merge locals into global (single-threaded) without copying unless inserting new keys + for local in locals { + for (piece, v) in local { + if let Some(cnt) = counts.get_mut(piece) { + *cnt += v; + } else { + counts.insert(piece.to_vec(), v); + } + } } if exhausted { @@ -381,8 +452,8 @@ impl Tokenizer { // Materialize words & counts let mut words = Vec::with_capacity(counts.len()); let mut cvec = Vec::with_capacity(counts.len()); - for (chunk, c) in counts.into_iter() { - words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as 
+        for (chunk_bytes, c) in counts.into_iter() {
+            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
             cvec.push(c);
         }
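
The merge step above relies on a standard-library detail: a map keyed by owned `Vec<u8>` can be queried with a borrowed `&[u8]` through the `Borrow` trait (this holds for `std::collections::HashMap` and for `AHashMap`, which wraps it), so an owned key is only allocated the first time a chunk is seen. A minimal standalone sketch of that pattern, not part of the patch and using `std::collections::HashMap` so it compiles without extra crates; the byte strings are made up for illustration:

    use std::collections::HashMap;

    fn main() {
        // Owned byte keys, looked up by borrowed slices.
        let mut counts: HashMap<Vec<u8>, i32> = HashMap::new();
        counts.insert(b"hello".to_vec(), 2);

        // A borrowed piece, e.g. one split produced by the tokenizer.
        let piece: &[u8] = b"hello";

        // Vec<u8>: Borrow<[u8]>, so the lookup needs no temporary allocation;
        // a copy is only made when inserting a previously unseen key.
        if let Some(c) = counts.get_mut(piece) {
            *c += 1;
        } else {
            counts.insert(piece.to_vec(), 1);
        }
        assert_eq!(counts[piece], 3);
    }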