Commit cc6830a9f9
Author: Maks S., 2025-11-16 06:10:11 +01:00 (committed by GitHub)
Signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
11 changed files with 1443 additions and 25 deletions

fregex/bench.py (new executable file, 174 additions)

@@ -0,0 +1,174 @@
import sys
import ctypes
import random
import time
import statistics
import os
import gc
from pathlib import Path
from nanochat.tokenizer import SPLIT_PATTERN
os.environ.update({
'OMP_NUM_THREADS': '1',
'OPENBLAS_NUM_THREADS': '1',
'MKL_NUM_THREADS': '1',
'VECLIB_MAXIMUM_THREADS': '1',
'NUMEXPR_NUM_THREADS': '1',
'RAYON_NUM_THREADS': '1',
})
try:
    os.setpriority(os.PRIO_PROCESS, 0, -10)  # raise priority; requires elevated privileges on most systems
except (AttributeError, OSError):
    pass  # fall back to normal priority rather than aborting the benchmark
from rustbpe import split_text as rust_split_text
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
from fregex.cload import *
PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.restype = ctypes.c_void_p
PyBytes_AsString.argtypes = [ctypes.py_object]
def _run_once_c(data: bytes) -> float:
token_list = TokenList()
c_lib.tokenlist_init(ctypes.byref(token_list))
base_ptr = PyBytes_AsString(data)
t0 = time.perf_counter_ns()
c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list))
dt_ms = (time.perf_counter_ns() - t0) / 1e6
c_lib.tokenlist_free(ctypes.byref(token_list))
return dt_ms
def _run_once_rust(text: str) -> float:
t0 = time.perf_counter_ns()
rust_split_text(SPLIT_PATTERN, text)
return (time.perf_counter_ns() - t0) / 1e6
def stats_summary(times: list) -> dict:
"""Compute statistics from timing list."""
    if not times:
return {}
return {
'min': min(times),
'max': max(times),
'mean': statistics.mean(times),
'median': statistics.median(times),
'stdev': statistics.stdev(times) if len(times) > 1 else 0,
}
def format_stats(name: str, data_size: int, times: list) -> str:
"""Format timing statistics for output."""
    if not times:
return f"{name:20} {data_size:>10} B --\n"
stats = stats_summary(times)
return (f"{name:20} {data_size:>10} B "
f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
f"stdev={stats['stdev']:.3f}ms\n")
def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
test_text = data_bytes.decode('utf-8', errors='replace')
print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
print()
# Pre-touch data to avoid first-touch/page-fault skew
if data_bytes:
_ = data_bytes[0]
for i in range(0, len(data_bytes), 4096):
_ = data_bytes[i]
# Warm-up
for _ in range(20):
_run_once_c(data_bytes)
_run_once_rust(test_text)
# Disable GC during timed section
gc_was_enabled = gc.isenabled()
if gc_was_enabled:
gc.disable()
c_times = []
rust_times = []
for _ in range(iterations):
c_times.append(_run_once_c(data_bytes))
rust_times.append(_run_once_rust(test_text))
if gc_was_enabled:
gc.enable()
print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
print(format_stats("Rust split", len(data_bytes), rust_times), end='')
if c_times and rust_times:
c_mean = statistics.mean(c_times)
rust_mean = statistics.mean(rust_times)
ratio = rust_mean / c_mean
speedup = "C is faster" if ratio > 1 else "Rust is faster"
print(f"Speedup: {ratio:.2f}x ({speedup})")
print()
# Verify token splits match between C and Python regex tokenizer
cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
ok, err, out_c, out_py = compare_pair_text(cmp_text)
if ok:
print("Compare: OK (C vs Py splits match)")
else:
print("Compare: MISMATCH (C vs Py)")
if err:
print(err)
if out_c is not None and out_py is not None:
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
print("--- C (head) ---")
print("\n".join(c_lines[:10]))
print("--- Py (head) ---")
print("\n".join(p_lines[:10]))
# Stop the benchmark if mismatch detected
raise SystemExit(1)
def main():
# Check if files were provided as arguments
file_args = sys.argv[1:] if len(sys.argv) > 1 else []
# If files provided, benchmark them
if file_args:
for file_path in file_args:
path = Path(file_path)
if not path.exists():
print(f"❌ File not found: {file_path}")
continue
try:
data = path.read_bytes()
benchmark_dataset(path.name, data, 1_000)
except Exception as e:
print(f"❌ Error reading {file_path}: {e}")
else:
# Run random generated data
configs = [
("tiny", 100, 1000),
("small", 1024, 500),
("medium", 10 * 1024, 100),
("large", 100 * 1024, 100),
("xlarge", 1024 * 1024, 100),
]
for name, size_bytes, iterations in configs:
            # Generate test data; seed with the name itself so runs are reproducible
            # (hash() of a str is randomized per interpreter run)
            test_text = gen_valid_unicode_string(
                random.Random(name),
                size_bytes
            )
test_bytes = test_text.encode('utf-8')
benchmark_dataset(name, test_bytes, iterations)
print("=" * 140)
if __name__ == "__main__":
main()
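
For reference, benchmark_dataset can also be driven programmatically (a minimal sketch, not part of the commit; it assumes the repo root is on sys.path and that fregex/libfregex.dylib is already built, since importing fregex.bench loads the shared library and rustbpe):

from fregex.bench import benchmark_dataset

# hypothetical payload; any UTF-8 bytes will do
payload = ("Hello world!! I'm 42.\n" * 500).encode("utf-8")
benchmark_dataset("custom", payload, iterations=200)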

fregex/cload.py (new file, 51 additions)

@@ -0,0 +1,51 @@
import ctypes
c_lib = ctypes.CDLL("fregex/libfregex.dylib")
class TokenPos(ctypes.Structure):
_fields_ = [
("start", ctypes.c_size_t),
("end", ctypes.c_size_t),
]
class TokenList(ctypes.Structure):
_fields_ = [
("splits", ctypes.POINTER(TokenPos)),
("count", ctypes.c_size_t),
("capacity", ctypes.c_size_t),
]
c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_init.restype = None
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_free.restype = None
# Accept a raw pointer to the input buffer rather than a Python bytes object
c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None
def tokenize_c_bytes(data: bytes) -> list[bytes]:
# Use a C char* view of the original bytes; offsets computed from this base
c_data = ctypes.c_char_p(data)
tl = TokenList()
c_lib.tokenlist_init(ctypes.byref(tl))
try:
base_addr = ctypes.cast(c_data, ctypes.c_void_p).value
# Pass the same pointer to C
c_lib.tokenize_fast(ctypes.cast(c_data, ctypes.c_void_p), len(data), ctypes.byref(tl))
out: list[bytes] = []
count = int(tl.count)
for i in range(count):
start_addr = int(tl.splits[i].start)
end_addr = int(tl.splits[i].end)
# Compute offsets into our local buffer
off_start = start_addr - base_addr
off_end = end_addr - base_addr
if off_start < 0 or off_end < off_start or off_end > len(data):
raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}")
out.append(data[off_start:off_end])
return out
finally:
c_lib.tokenlist_free(ctypes.byref(tl))
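
A quick sanity check for these bindings (a sketch, not part of the commit; it assumes fregex/libfregex.dylib has been built): tokenize_fast assigns every input byte to exactly one span, so concatenating the pieces returned by tokenize_c_bytes must reproduce the original buffer.

from fregex.cload import tokenize_c_bytes

data = "I'm  fine!!\n".encode("utf-8")
pieces = tokenize_c_bytes(data)
assert b"".join(pieces) == data  # spans cover the buffer exactly once, in order
print([p.decode("utf-8") for p in pieces])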

fregex/compare.py (new file, 162 additions)

@@ -0,0 +1,162 @@
import sys
import ctypes
from pathlib import Path
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
from fregex.cload import *
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
        if code == 0x5C:
            buf.append('\\\\')
elif code == 0x0A:
buf.append('\\n')
elif code == 0x0D:
buf.append('\\r')
elif code == 0x09:
buf.append('\\t')
elif code == 0x0C:
buf.append('\\f')
elif code == 0x0B:
buf.append('\\v')
elif code == 0x22:
buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def dump_tokens(tokens: list[bytes]) -> str:
return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)
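# Illustrative example (not part of the commit): dump_tokens([b"Hi", b" there", b"!!\n"])
# yields one "<byte length>\t<escaped bytes>" line per piece:
#   2\tHi
#   6\t there
#   3\t!!\n
# (here \t stands for a real tab; the trailing \n is the two-character escape produced above)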
def tokenize_py_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
def tokenize_rs_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
parts = rust_split_text(SPLIT_PATTERN, text)
return [t.encode('utf-8', errors='surrogatepass') for t in parts]
def compare_one(path: Path) -> int:
data_bytes = Path(path).read_bytes()
try:
c_toks = tokenize_c_bytes(data_bytes)
except Exception as e:
print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
py_toks = tokenize_py_bytes(data_bytes)
except Exception as e:
print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
rs_toks = tokenize_rs_bytes(data_bytes)
except Exception as e:
print(f"Rust split failed on {path}:\n{e}", file=sys.stderr)
return 1
out_c = dump_tokens(c_toks)
out_py = dump_tokens(py_toks)
out_rs = dump_tokens(rs_toks)
if out_c == out_py == out_rs:
print(f"OK {path.name}")
return 0
else:
print(f"DIFF {path.name}")
# Show a small 3-way diff at first differing line, with byte offsets
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
r_lines = out_rs.splitlines()
def parse_lines(lines):
parsed = []
for ln in lines:
# Format is: "<len>\t<escaped>"
try:
left, right = ln.split('\t', 1)
blen = int(left)
except Exception:
blen = 0
right = ln
parsed.append((blen, right))
return parsed
c_parsed = parse_lines(c_lines)
p_parsed = parse_lines(p_lines)
r_parsed = parse_lines(r_lines)
def byte_offsets(parsed):
offs = []
pos = 0
for blen, _ in parsed:
offs.append((pos, pos + blen))
pos += blen
return offs
c_offs = byte_offsets(c_parsed)
p_offs = byte_offsets(p_parsed)
r_offs = byte_offsets(r_parsed)
data_bytes = Path(path).read_bytes()
def print_unicode_debug(label, offs_list, idx):
if idx >= len(offs_list):
print(f" {label} piece: [n/a]")
return
start, end = offs_list[idx]
piece_bytes = data_bytes[start:end]
piece_text = piece_bytes.decode('utf-8', errors='replace')
if not piece_bytes:
print(f" {label} piece: [EMPTY]")
return
cp_parts = []
for ch in piece_text:
cp_parts.append(f"U+{ord(ch):04X}")
bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes)
print(f" {label} chars: {' | '.join(cp_parts)}")
print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)")
max_len = max(len(c_lines), len(p_lines), len(r_lines))
for i in range(max_len):
cl = c_lines[i] if i < len(c_lines) else "<eof>"
pl = p_lines[i] if i < len(p_lines) else "<eof>"
rl = r_lines[i] if i < len(r_lines) else "<eof>"
if not (cl == pl == rl):
# Collect byte positions if available
c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]"
p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]"
r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]"
print(
f" line {i+1}:\n"
f" C: {cl} @ bytes {c_pos}\n"
f" Py: {pl} @ bytes {p_pos}\n"
f" Rs: {rl} @ bytes {r_pos}"
)
print(" === Unicode split detail ===")
print_unicode_debug("C", c_offs, i)
print_unicode_debug("Py", p_offs, i)
print_unicode_debug("Rs", r_offs, i)
break
return 2
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <tests-dir>")
sys.exit(2)
paths = sorted(Path(sys.argv[1]).glob('*.txt'))
bad = 0
for p in paths:
bad += compare_one(p)
print(f"Completed. Failures: {bad}")
if __name__ == '__main__':
main()

fregex/fregex.c (new file, 498 additions)

@@ -0,0 +1,498 @@
#include "fregex-2.h"
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <stdint.h>
#include <stdbool.h>
#include "utf8proc/utf8proc.h"
/*
Regex pattern we care about from nanochat/tokenizer.py SPLIT_PATTERN
Break it down:
A) '(?i:[sdmt]|ll|ve|re)
B) [^\r\n\p{L}\p{N}]?+\p{L}+
C) \p{N}{1,2}
  D) " ?[^\s\p{L}\p{N}]++[\r\n]*"  (note: the leading space inside the quotes is part of the pattern)
E) \s*[\r\n]
F) \s+(?!\S)
G) \s+
*/
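/*
  Worked example (added for illustration; easy to check against py_tokenizer.py):
  the input "I'm  fine!!\n" (two spaces before "fine") splits into
    "I"      via B  (letter run)
    "'m"     via A  (contraction)
    " "      via F  (whitespace run followed by non-space keeps all but the last space)
    " fine"  via B  (optional non-letter prefix + letters)
    "!!\n"   via D  (punctuation run absorbing trailing CR/LF)
*/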
#define UNICODE_LF 0x000A // Line Feed
#define UNICODE_CR 0x000D // Carriage Return
static inline size_t utf8_decode_cp(
const char *s,
const char *end,
unsigned int *out_cp
) {
if (s >= end) {
*out_cp = 0;
return 0;
}
utf8proc_int32_t ch = 0;
ssize_t ret = utf8proc_iterate(
(const utf8proc_uint8_t*)s,
(ssize_t)(end - s),
&ch
);
if (ret < 0) {
// invalid sequence: treat as single byte
*out_cp = (unsigned char)*s;
return 1;
}
*out_cp = (unsigned int)ch;
return (size_t)ret;
}
static inline bool is_utf8_cont_byte(unsigned char b) {
return (b & 0xC0) == 0x80;
}
static inline bool is_cr_or_lf(unsigned int cp) {
return cp == UNICODE_LF || cp == UNICODE_CR;
}
static inline bool is_letter(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
switch (cat) {
case UTF8PROC_CATEGORY_LU: // Letter, Uppercase
case UTF8PROC_CATEGORY_LL: // Letter, Lowercase
case UTF8PROC_CATEGORY_LT: // Letter, Titlecase
case UTF8PROC_CATEGORY_LM: // Letter, Modifier
case UTF8PROC_CATEGORY_LO: // Letter, Other
return true;
default:
return false;
}
}
static inline bool is_number(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
switch (cat) {
case UTF8PROC_CATEGORY_ND: // Number, Decimal Digit
case UTF8PROC_CATEGORY_NL: // Number, Letter
case UTF8PROC_CATEGORY_NO: // Number, Other
return true;
default:
return false;
}
}
static inline bool is_space(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
if (
cat == UTF8PROC_CATEGORY_ZS ||
cat == UTF8PROC_CATEGORY_ZL ||
cat == UTF8PROC_CATEGORY_ZP
) {
return true;
}
switch (cp) {
case 0x0009: // TAB
case 0x000A: // LF
case 0x000B: // VT
case 0x000C: // FF
case 0x000D: // CR
case 0x0085: // NEL
return true;
default:
return false;
}
}
static inline bool is_alnum(unsigned int cp) {
return is_letter(cp) || is_number(cp);
}
static inline bool is_non_space_letter_number(unsigned int cp) {
return !is_space(cp) && !is_alnum(cp);
}
static void *xrealloc(void *ptr, size_t new_size) {
void *p = realloc(ptr, new_size);
if (!p) {
fprintf(stderr, "Out of memory while reallocating %zu bytes\n", new_size);
exit(1);
}
return p;
}
static char *xmemdup(const char *src, size_t len) {
char *s = (char*)malloc(len + 1);
if (!s) {
fprintf(stderr, "Out of memory while allocating %zu bytes\n", len + 1);
exit(1);
}
memcpy(s, src, len);
s[len] = '\0';
return s;
}
void tokenlist_init(TokenList *list) {
list->splits = NULL;
list->count = 0;
list->capacity = 0;
}
void tokenlist_free(TokenList *list) {
if (!list)
return;
free(list->splits);
list->splits = NULL;
list->count = 0;
list->capacity = 0;
}
static void tokenlist_push(TokenList *list, const char *start, size_t len) {
if (list->count == list->capacity) {
const size_t new_cap = list->capacity ? (list->capacity * 2) : 128;
list->splits = (TokenPos*)xrealloc(list->splits, new_cap * sizeof(TokenPos));
list->capacity = new_cap;
}
    /* Store the piece's absolute addresses: start inclusive, end exclusive
       (consumers subtract the input base pointer to recover byte offsets). */
    list->splits[list->count].start = (size_t)start;
    list->splits[list->count].end = (size_t)(start + len);
    list->count++;
}
static void fput_escaped_char(unsigned char c, FILE *out) {
switch (c) {
case '\\': fputs("\\\\", out); break;
case '\n': fputs("\\n", out); break;
case '\r': fputs("\\r", out); break;
case '\t': fputs("\\t", out); break;
case '\f': fputs("\\f", out); break;
case '\v': fputs("\\v", out); break;
case '\"': fputs("\\\"", out); break;
default:
if (c < 32 || c >= 127) {
fprintf(out, "\\x%02X", c);
} else {
fputc(c, out);
}
}
}
/* A) '(?i:[sdmt]|ll|ve|re) */
static size_t match_contraction(const char *p, const char *end) {
if (p >= end || *p != '\'' || (p + 1) >= end)
return 0;
unsigned char a = (unsigned char)p[1];
// locale-independent lowercase for ASCII letters:
// map AZ to az; leaves others unchanged
if (a >= 'A' && a <= 'Z')
a = (unsigned char)(a + ('a' - 'A'));
// 's 'd 'm 't
if (a == 's' || a == 'd' || a == 'm' || a == 't')
return 2;
// Need a second following byte for 'll 've 're
if (p + 2 >= end)
return 0;
unsigned char b = (unsigned char)p[2];
if (b >= 'A' && b <= 'Z')
b = (unsigned char)(b + ('a' - 'A'));
if ((a == 'l' && b == 'l') ||
(a == 'v' && b == 'e') ||
(a == 'r' && b == 'e')) {
return 3;
}
return 0;
}
/* B) [^\r\n\p{L}\p{N}]?+\p{L}+ */
static size_t match_word_with_optional_prefix(const char *p, const char *end) {
if (p >= end)
return 0;
const char *q = p;
unsigned int cp0;
size_t n0 = utf8_decode_cp(q, end, &cp0);
if (n0 == 0)
return 0;
const char *letters_start = q; // will point to first letter of \p{L}+
size_t prefix_bytes = 0;
// Consider optional one-codepoint prefix if first cp is NOT (CR/LF or alnum)
if (!is_cr_or_lf(cp0) && !is_alnum(cp0)) {
// Look ahead: must have at least one letter right after to commit the prefix
const char *after_prefix = q + n0;
if (after_prefix >= end) {
return 0; // no room for required \p{L}+
}
unsigned int cp1;
size_t n1 = utf8_decode_cp(after_prefix, end, &cp1);
if (n1 > 0 && is_letter(cp1)) {
// Commit the prefix (possessive) and start letters after it
prefix_bytes = n0;
letters_start = after_prefix;
q = after_prefix;
} else {
// Can't commit the prefix (no letter follows), so this branch fails
return 0;
}
} else if (is_letter(cp0)) {
// No prefix; we already sit on the first letter
letters_start = q;
} else {
// First cp is CR/LF or a number; this branch doesn't match
return 0;
}
// Consume \p{L}+ (require at least one letter)
size_t letter_count = 0;
const char *scan = letters_start;
while (scan < end) {
unsigned int cp;
size_t n = utf8_decode_cp(scan, end, &cp);
if (n == 0 || !is_letter(cp))
break;
scan += n;
letter_count++;
}
if (letter_count == 0) {
// Shouldn't happen given the look-ahead, but keep the guard
return 0;
}
return (size_t)(scan - p); // includes prefix (if any) + letters
}
/* C) \p{N}{1,2} */
static size_t match_short_number(const char *p, const char *end) {
if (p >= end)
return 0;
/* First number required */
unsigned int cp1;
size_t n1 = utf8_decode_cp(p, end, &cp1);
if (n1 == 0 || !is_number(cp1))
return 0;
const char *q = p + n1;
// Optional second number cp
if (q < end) {
unsigned int cp2;
size_t n2 = utf8_decode_cp(q, end, &cp2);
if (n2 > 0 && is_number(cp2))
return (size_t)((q + n2) - p); // 2 numbers
}
return (size_t)(q - p); // 1 number
}
/* D) ?[^\s\p{L}\p{N}]++[\r\n]* */
static size_t match_punct_run(const char *p, const char *end) {
const char *q = p;
/* Optional single ASCII space */
if (q < end && *q == ' ') {
const char *r = q + 1;
if (r >= end)
return 0;
unsigned int cp;
size_t n = utf8_decode_cp(r, end, &cp);
if (n == 0)
return 0;
// Eligible b/c not whitespace and not alnum
if (!is_space(cp) && !is_alnum(cp)) {
q = r; // commit the space
} else {
// Not followed by eligible punct
return 0;
}
}
// Now require at least one eligible (not whitespace, not alnum)
size_t took = 0;
while (q < end) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0) break;
if (is_space(cp) || is_alnum(cp))
break; // stop on any whitespace or letter/number
q += n;
took++;
}
if (took == 0)
return 0; // must have at least one punctuation/symbol
// Finally, optionally absorb CR/LF sequence(s)
while (q < end) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_cr_or_lf(cp))
break;
q += n;
}
return (size_t)(q - p);
}
/* E) \s*[\r\n] */
static size_t match_ws_then_linebreak(const char *p, const char *end) {
const char *q = p;
const char *best = NULL;
// Check boundary before consuming any whitespace, too (zero-length \s*)
if (q < end) {
unsigned int nx;
size_t nn = utf8_decode_cp(q, end, &nx);
if (nn > 0 && is_cr_or_lf(nx)) {
best = q; // \s* = 0, [\r\n] = this char
}
}
// Scan whitespace; at each boundary, test the next cp
while (q < end) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_space(cp))
break;
q += n; // we consumed one whitespace cp; boundary is at q now
if (q < end) {
unsigned int nx;
size_t nn = utf8_decode_cp(q, end, &nx);
if (nn > 0 && is_cr_or_lf(nx)) {
best = q; // prefer the rightmost usable boundary
}
}
}
if (!best) return 0;
// At 'best' the next cp is the CR/LF to include
unsigned int br;
size_t nb = utf8_decode_cp(best, end, &br);
return (size_t)((best + nb) - p);
}
/* F) \s+(?!\S) */
static size_t match_trailing_ws(const char *p, const char *end) {
if (p >= end)
return 0;
// First cp must be whitespace
unsigned int cp;
size_t n = utf8_decode_cp(p, end, &cp);
if (n == 0 || !is_space(cp))
return 0;
// Consume full whitespace run [p, r)
const char *r = p + n;
while (r < end) {
size_t m = utf8_decode_cp(r, end, &cp);
if (m == 0 || !is_space(cp))
break;
r += m;
}
if (r == end) {
// Only whitespace to EOF -> take all of it
return (size_t)(r - p);
}
// Backtrack by exactly one whitespace cp
// If the run length is only 1 cp, F must fail.
// Find the start of the last whitespace cp in [p, r)
const char *t = r;
// step back to beginning of previous UTF-8 cp
do {
--t;
} while (t > p && is_utf8_cont_byte(*t));
if (t == p) {
// run had length 1 cp -> cannot backtrack to keep \s+ >= 1
return 0;
}
// Now [p, t) is k-1 whitespace cps
return (size_t)(t - p);
}
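/* Example (illustrative): in "a   b" the three interior spaces split as "  "
   (this rule, backing off one space so the next piece can start with it)
   followed by " b" (rule B, space prefix + letters). */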
/* G) \s+ */
static size_t match_ws_run(const char *p, const char *end) {
if (p >= end)
return 0;
const char *q = p;
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_space(cp))
return 0;
/* We have at least one whitespace, consume the run */
q += n;
while (q < end) {
size_t m = utf8_decode_cp(q, end, &cp);
if (m == 0 || !is_space(cp))
break;
q += m;
}
return (size_t)(q - p);
}
void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
if (!input) {
out->splits = NULL;
out->count = 0;
out->capacity = 0;
return;
}
const char *p = input;
const char *end = input + input_len;
while (p < end) {
/* Special tokens take precedence */
// TODO LATER
        size_t captured = 0;  /* size_t: the match_* helpers return byte counts as size_t */
/* Evaluate case A */
captured = match_contraction(p, end);
/* Evaluate case B */
if (!captured) captured = match_word_with_optional_prefix(p, end);
/* Evaluate case C */
if (!captured) captured = match_short_number(p, end);
/* Evaluate case D */
if (!captured) captured = match_punct_run(p, end);
/* Evaluate case E */
if (!captured) captured = match_ws_then_linebreak(p, end);
/* Evaluate case F */
if (!captured) captured = match_trailing_ws(p, end);
/* Evaluate case G */
if (!captured) captured = match_ws_run(p, end);
if (captured) {
tokenlist_push(out, p, captured);
p += captured;
}
else {
/* Fallback take a single CP */
unsigned int cp;
size_t adv = utf8_decode_cp(p, end, &cp);
if (adv == 0)
adv = 1;
tokenlist_push(out, p, adv);
p += adv;
}
}
}

fregex/fregex.h (new file, 21 additions)

@@ -0,0 +1,21 @@
#ifndef FAST_REGEX_H
#define FAST_REGEX_H
#include <stddef.h>
#include <stdio.h>
typedef struct {
size_t start, end;
} TokenPos;
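/* Note: tokenize_fast() stores absolute byte addresses of each piece in start/end
   (end exclusive); callers subtract the input buffer's base pointer to recover
   offsets, as done in fregex/cload.py and the rustbpe loader. */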
typedef struct {
TokenPos *splits;
size_t count;
size_t capacity;
} TokenList;
void tokenlist_init(TokenList *list);
void tokenlist_free(TokenList *list);
void tokenize_fast(const char *input, size_t input_len, TokenList *out);
#endif // FAST_REGEX_H

fregex/fuzz.py (new file, 325 additions)

@@ -0,0 +1,325 @@
import sys
import time
import random
import argparse
import unicodedata as u
import ctypes
from pathlib import Path
from fregex.cload import *
HERE = Path(__file__).resolve().parent
TESTS_DIR = HERE / "tests"
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C:
buf.append('\\\\')
elif code == 0x0A:
buf.append('\\n')
elif code == 0x0D:
buf.append('\\r')
elif code == 0x09:
buf.append('\\t')
elif code == 0x0C:
buf.append('\\f')
elif code == 0x0B:
buf.append('\\v')
elif code == 0x22:
buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str:
target_len = rng.randint(0, max_len)
ws_cps = [
0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, # space, \t, \n, \v, \f, \r
0x00A0, # NO-BREAK SPACE
0x1680, # OGHAM SPACE MARK
0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
0x2007, 0x2008, 0x2009, 0x200A, # EN/EM/THIN/HAIR SPACES etc.
0x2028, 0x2029, # LINE SEPARATOR, PARAGRAPH SEPARATOR
0x202F, # NARROW NO-BREAK SPACE
0x205F, # MEDIUM MATHEMATICAL SPACE
0x3000, # IDEOGRAPHIC SPACE
0x200B, # ZERO WIDTH SPACE (not WS in Python, but hits tokenizer class)
0xFEFF, # ZERO WIDTH NO-BREAK SPACE
]
ascii_punct = [
ord(c)
for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
]
def rand_scalar_excluding_surrogates(lo: int, hi: int) -> int:
while True:
cp = rng.randint(lo, hi)
if 0xD800 <= cp <= 0xDFFF:
continue
return cp
def is_ws_char(ch: str) -> bool:
cp = ord(ch)
return ch.isspace() or (cp in ws_cps)
def gen_ws_segment(max_run: int) -> str:
# Mix of various spaces, often multi-length; sometimes explicit CRLFs
if rng.random() < 0.35:
# Build CR, LF, CRLF, or repeated newlines
seqs = ["\n", "\r", "\r\n"]
unit = rng.choice(seqs)
unit_len = len(unit)
max_reps = max(1, max_run // unit_len)
seg = unit * rng.randint(1, max_reps)
return seg
run = rng.randint(1, max(1, max_run))
buf = []
for _ in range(run):
cp = rng.choice(ws_cps)
buf.append(chr(cp))
return ''.join(buf)
def gen_letter_run(max_run: int) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
for _ in range(run):
if rng.random() < 0.6:
# ASCII letters
base = ord('A') if rng.random() < 0.5 else ord('a')
buf.append(chr(base + rng.randint(0, 25)))
else:
# Any Unicode letter
while True:
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
if u.category(chr(cp)).startswith('L'):
buf.append(chr(cp))
break
# optional prefix of single non-WS, non-letter, non-number to stress
# the leading [^\r\n\p{L}\p{N}]?+ in the regex
if rng.random() < 0.3:
buf.insert(0, gen_punc_run(1, allow_space=False))
return ''.join(buf)
def gen_number_run(max_run: int) -> str:
# Bias to lengths 1..2 per \p{N}{1,2}, but sometimes longer
if rng.random() < 0.7:
run = rng.randint(1, min(2, max_run))
else:
run = rng.randint(3, max(3, max_run))
buf = []
for _ in range(run):
if rng.random() < 0.75:
buf.append(chr(ord('0') + rng.randint(0, 9)))
else:
# Other numeric categories (Nd/Nl/No)
while True:
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
if u.category(chr(cp)).startswith('N'):
buf.append(chr(cp))
break
return ''.join(buf)
def gen_punc_run(max_run: int, allow_space: bool = True) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
# optional leading single space before punc block
if allow_space and rng.random() < 0.5:
buf.append(' ')
for _ in range(run):
if rng.random() < 0.6:
cp = rng.choice(ascii_punct)
else:
while True:
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
ch = chr(cp)
if (
not u.category(ch).startswith('L') and
not u.category(ch).startswith('N') and
cp not in ws_cps and
not ch.isspace()
):
break
            # note: control characters (including U+0000) may be emitted here;
            # all consumers pass explicit byte lengths, so embedded NULs are exercised rather than avoided
buf.append(chr(cp))
# optional trailing newlines to stress [\r\n]*
if rng.random() < 0.35:
tail = gen_ws_segment(3)
# Keep only CR/LF components in the tail for this case
tail = tail.replace('\t', '').replace('\v', '').replace('\f', '').replace(' ', '')
buf.append(tail)
return ''.join(buf)
def gen_contraction() -> str:
# e.g., we're, he'll, I'd, I'm, can't, they've
prefixes = [gen_letter_run( rng.randint(1, 6) )]
suffix = rng.choice(["s", "d", "m", "t", "ll", "ve", "re"])
return prefixes[0] + "'" + suffix
def gen_random_unicode(max_run: int) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
for _ in range(run):
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
try:
buf.append(chr(cp))
except ValueError:
continue
return ''.join(buf)
buf: list[str] = []
curr_len = 0
# Build by segments until target_len
while curr_len < target_len:
remain = target_len - curr_len
r = rng.random()
if r < 0.40:
seg = gen_ws_segment(remain)
elif r < 0.45:
# Explicit newline-focused segment
seg = ("\r\n" if rng.random() < 0.5 else ("\n" if rng.random() < 0.5 else "\r")) * rng.randint(1, max(1, remain))
elif r < 0.65:
seg = gen_letter_run(remain)
elif r < 0.75:
seg = gen_number_run(remain)
elif r < 0.90:
seg = gen_punc_run(remain)
elif r < 0.95:
seg = gen_contraction()
else:
seg = gen_random_unicode(remain)
if not seg:
continue
        # Append, trimming to the target length
        for ch in seg:
            if curr_len >= target_len:
                break
            buf.append(ch)
            curr_len += 1
# Occasionally end with trailing spaces to stress \s+(?!\S)
if curr_len < max_len and rng.random() < 0.3:
trail = gen_ws_segment(max_len - curr_len)
if rng.random() < 0.7:
trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len))
        # Append trailing characters, trimming to max_len
        for ch in trail:
            if curr_len >= max_len:
                break
            buf.append(ch)
            curr_len += 1
return ''.join(buf)
def write_temp_case(text: str, tag: str = "RUN") -> Path:
TESTS_DIR.mkdir(parents=True, exist_ok=True)
ts = int(time.time() * 1000)
fname = f"in_fuzz_{tag}_{ts}.txt"
path = TESTS_DIR / fname
with open(path, 'wb') as f:
f.write(text.encode('utf-8', errors='surrogatepass'))
return path
def _format_tokens_dump(tokens: list[bytes]) -> str:
lines = []
for b in tokens:
lines.append(f"{len(b)}\t{escape_bytes(b)}")
return "\n".join(lines)
def tokenize_py_bytes(data: bytes) -> list[bytes]:
if py_tokenize_str is None:
raise RuntimeError("py_tokenizer not available")
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
def compare_pair_text(text: str):
data = text.encode('utf-8', errors='surrogatepass')
try:
toks_c = tokenize_c_bytes(data)
except Exception as e:
return False, f"C failed: {e}", None, None
try:
toks_py = tokenize_py_bytes(data)
except Exception as e:
return False, f"Py failed: {e}", None, None
ok = toks_c == toks_py
return ok, None, _format_tokens_dump(toks_c), _format_tokens_dump(toks_py)
def run_fuzz(iters: int, max_len: int, seed: int, stop_on_first: bool):
rng = random.Random(seed)
total = 0
mismatches = 0
last_save = None
for i in range(iters if iters > 0 else 1_000_000_000):
s = gen_valid_unicode_string(rng, max_len)
ok, err, out_c, out_py = compare_pair_text(s)
total += 1
if not ok:
mismatches += 1
fail_path = write_temp_case(s, tag="FAIL")
last_save = fail_path
print(f"Mismatch at iter {i}, saved to {fail_path}")
print(f"Seed: {seed}")
cps = [f"U+{ord(ch):04X}" for ch in s]
cats = [u.category(ch) for ch in s]
print(f"Text bytes len: {len(s.encode('utf-8','surrogatepass'))}, chars: {len(s)}")
print(f"Codepoints: {' '.join(cps)}")
print(f"Categories: {' '.join(cats)}")
if err:
print(err)
if out_c is not None:
print("--- C tokens ---")
print(out_c)
if out_py is not None:
print("--- Py tokens ---")
print(out_py)
if stop_on_first:
break
if (i + 1) % 100 == 0:
print(f"[fuzz] {i+1} cases, mismatches={mismatches}")
# print(out_c, out_py, sep="\n")
return total, mismatches, last_save
def main():
ap = argparse.ArgumentParser(description="Fuzz C vs Python tokenizers on random valid UTF-8 inputs")
ap.add_argument("--iters", type=int, default=0, help="Number of iterations (0 = very large run)")
ap.add_argument("--max-len", type=int, default=256, help="Maximum number of Unicode scalars per case")
ap.add_argument("--seed", type=int, default=12345, help="PRNG seed for reproducibility")
ap.add_argument("--stop-on-first", action="store_true", help="Stop at first mismatch (default: run all)")
args = ap.parse_args()
total, mismatches, last = run_fuzz(args.iters, args.max_len, args.seed, args.stop_on_first)
print(f"Completed {total} cases, mismatches={mismatches}")
if last:
print(f"Last failing case saved at: {last}")
if mismatches:
sys.exit(1)
if __name__ == "__main__":
main()

fregex/py_tokenizer.py (new file, 43 additions)

@@ -0,0 +1,43 @@
import sys
from nanochat.tokenizer import SPLIT_PATTERN
from tokenizers import pre_tokenizers, Regex
_SPLITTER = pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated", invert=False)
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C: buf.append('\\\\')
elif code == 0x0A: buf.append('\\n')
elif code == 0x0D: buf.append('\\r')
elif code == 0x09: buf.append('\\t')
elif code == 0x0C: buf.append('\\f')
elif code == 0x0B: buf.append('\\v')
elif code == 0x22: buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def tokenize_py(text: str):
parts = _SPLITTER.pre_tokenize_str(text)
return [s for (s, _range) in parts]
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
sys.exit(2)
with open(sys.argv[1], 'rb') as f:
data = f.read()
text = data.decode('utf-8', errors='surrogatepass')
for tok in tokenize_py(text):
b = tok.encode('utf-8', errors='surrogatepass')
esc = escape_bytes(b)
print(f"{len(b)}\t{esc}")
if __name__ == "__main__":
main()

fregex/rust_split.py (new file, 38 additions)

@@ -0,0 +1,38 @@
import sys
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C: buf.append('\\\\')
elif code == 0x0A: buf.append('\\n')
elif code == 0x0D: buf.append('\\r')
elif code == 0x09: buf.append('\\t')
elif code == 0x0C: buf.append('\\f')
elif code == 0x0B: buf.append('\\v')
elif code == 0x22: buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
sys.exit(2)
with open(sys.argv[1], 'rb') as f:
data = f.read()
text = data.decode('utf-8', errors='surrogatepass')
parts = rust_split_text(SPLIT_PATTERN, text)
for tok in parts:
b = tok.encode('utf-8', errors='surrogatepass')
esc = escape_bytes(b)
print(f"{len(b)}\t{esc}")
if __name__ == "__main__":
main()

rustbpe/Cargo.lock (generated, 18 additions)

@@ -186,6 +186,16 @@ version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"

[[package]]
name = "libloading"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
dependencies = [
 "cfg-if",
 "windows-link",
]

[[package]]
name = "log"
version = "0.4.28"
@@ -363,7 +373,9 @@ dependencies = [
 "dary_heap",
 "fancy-regex",
 "indexmap",
 "libloading",
 "log",
 "once_cell",
 "pyo3",
 "pyo3-log",
 "rayon",
@@ -431,6 +443,12 @@ dependencies = [
 "wit-bindgen",
]

[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"

[[package]]
name = "wit-bindgen"
version = "0.45.1"

rustbpe/Cargo.toml

@@ -13,3 +13,5 @@ pyo3-log = "0.12.4"
ahash = "0.8.12"
rayon = "1.11.0"
compact_str = "0.9.0"
libloading = "0.8.5"
once_cell = "1.19.0"

rustbpe/src/lib.rs

@@ -3,10 +3,12 @@ use std::collections::HashMap as StdHashMap;
use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
use libloading::Library;
use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use ahash::{AHashMap, AHashSet};
-use compact_str::CompactString;
use rayon::prelude::*;
// Default GPT-4 style regex pattern for splitting text
@@ -14,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{
type Pair = (u32, u32);
#[allow(non_camel_case_types)]
type c_char = std::os::raw::c_char;
#[repr(C)]
struct CTokenPos {
start: usize,
end: usize,
}
#[repr(C)]
struct CTokenList {
splits: *mut CTokenPos,
count: usize,
capacity: usize,
}
type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList);
type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList);
type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList);
struct FregexSymbols {
_lib: Library,
tokenlist_init: FnTokenListInit,
tokenlist_free: FnTokenListFree,
tokenize_fast: FnTokenizeFast,
}
static FREGEX: OnceCell<FregexSymbols> = OnceCell::new();
fn load_fregex() -> &'static FregexSymbols {
FREGEX.get_or_init(|| {
// NOTE: adjust this path per user if needed.
let path = "fregex/libfregex.dylib";
let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| {
panic!("Failed to load libfregex from {}: {}", path, e);
});
unsafe {
let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init");
let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free");
let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast");
println!("rustbpe: loaded libfregex.dylib from {}", path);
FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast }
}
})
}
fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F)
where
F: FnMut(&'a [u8]),
{
if input.is_empty() {
return;
}
let syms = load_fregex();
let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 };
let base_ptr = input.as_ptr() as usize;
unsafe {
(syms.tokenlist_init)(&mut out as *mut CTokenList);
(syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList);
if !out.splits.is_null() {
let slice = std::slice::from_raw_parts(out.splits, out.count);
for pos in slice.iter() {
let start = pos.start.saturating_sub(base_ptr);
let end = pos.end.saturating_sub(base_ptr);
if end <= input.len() && start <= end {
on_piece(&input[start..end]);
}
}
}
(syms.tokenlist_free)(&mut out as *mut CTokenList);
}
}
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
#[pyclass]
pub struct Tokenizer {
@@ -295,8 +370,8 @@ impl Tokenizer {
            pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
        };
-        // Global chunk counts
-        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
        // Global chunk counts: own bytes once per unique chunk (no string copies)
        let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();
        // Temporary buffer we refill under the GIL
        let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
@@ -344,31 +419,28 @@
            total_sequences += buf.len() as u64;
-            let pattern = self.compiled_pattern.clone();
-            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
            // Build per-string local counts that reference the buffer slices (no allocations)
            let locals: Vec<Vec<(&[u8], i32)>> = py.allow_threads(|| {
                buf.par_iter()
                    .map(|s| {
-                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
-                        for mat in pattern.find_iter(s) {
-                            let piece = mat.expect("regex match failed").as_str();
-                            *m.entry(CompactString::from(piece)).or_default() += 1;
-                        }
-                        m
                        let mut m: AHashMap<&[u8], i32> = AHashMap::new();
                        let bytes = s.as_bytes();
                        tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; });
                        // Materialize as Vec to allow merging after parallel section
                        m.into_iter().collect::<Vec<(&[u8], i32)>>()
                    })
-                    .reduce(
-                        || AHashMap::new(),
-                        |mut a, b| {
-                            for (k, v) in b {
-                                *a.entry(k).or_default() += v;
-                            }
-                            a
-                        },
-                    )
                    .collect()
            });
-            // Merge local into global (single-threaded)
-            for (k, v) in local {
-                *counts.entry(k).or_default() += v;
            // Merge locals into global (single-threaded) without copying unless inserting new keys
            for local in locals {
                for (piece, v) in local {
                    if let Some(cnt) = counts.get_mut(piece) {
                        *cnt += v;
                    } else {
                        counts.insert(piece.to_vec(), v);
                    }
                }
            }
            if exhausted {
@@ -380,8 +452,8 @@ impl Tokenizer {
        // Materialize words & counts
        let mut words = Vec::with_capacity(counts.len());
        let mut cvec = Vec::with_capacity(counts.len());
-        for (chunk, c) in counts.into_iter() {
-            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
        for (chunk_bytes, c) in counts.into_iter() {
            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
            cvec.push(c);
        }
@@ -467,9 +539,23 @@
    }
}
#[pyfunction]
pub fn split_text(pattern: String, text: String) -> PyResult<Vec<String>> {
let re = Regex::new(&pattern)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
let mut out: Vec<String> = Vec::new();
for m in re.find_iter(&text) {
let m = m
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Regex match failed: {}", e)))?;
out.push(m.as_str().to_string());
}
Ok(out)
}
#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
    pyo3_log::init(); // forwards Rust `log` to Python's `logging`
    m.add_class::<Tokenizer>()?;
    m.add_function(wrap_pyfunction!(split_text, m)?)?;
    Ok(())
}