faster regex in C

This commit is contained in:
MadMax129 2025-10-23 16:59:10 -04:00
parent 2e938530ce
commit 12f418f0a1
9 changed files with 1364 additions and 0 deletions

157
fregex/bench.py Executable file

@@ -0,0 +1,157 @@
"""
Benchmarker comparing the C tokenize_fast() (fregex.c) vs the Rust split_text().
Measures speed WITHOUT subprocess overhead - direct function calls only.
Usage (from the repo root, with the virtualenv active):
python3 -m fregex.bench # Run synthetic data benchmarks
python3 -m fregex.bench /path/to/file.txt # Benchmark a specific file
python3 -m fregex.bench file1.txt file2.txt ... # Benchmark multiple files
"""
import sys
import ctypes
import random
import time
import statistics
from pathlib import Path
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
from fregex.cload import *
def bench_c_regex(data: bytes, iterations: int) -> list:
times = []
for _ in range(iterations):
token_list = TokenList()
c_lib.tokenlist_init(ctypes.byref(token_list))
start = time.perf_counter()
c_lib.tokenize_fast(data, len(data), ctypes.byref(token_list))
elapsed = time.perf_counter() - start
c_lib.tokenlist_free(ctypes.byref(token_list))
times.append(elapsed * 1000)
return times
def bench_rust_regex(text: str, iterations: int) -> list:
times = []
for _ in range(iterations):
start = time.perf_counter()
rust_split_text(SPLIT_PATTERN, text)
elapsed = time.perf_counter() - start
times.append(elapsed * 1000)
return times
def stats_summary(times: list) -> dict:
"""Compute statistics from timing list."""
if not times:
return {}
return {
'min': min(times),
'max': max(times),
'mean': statistics.mean(times),
'median': statistics.median(times),
'stdev': statistics.stdev(times) if len(times) > 1 else 0,
}
def format_stats(name: str, data_size: int, times: list) -> str:
"""Format timing statistics for output."""
if not times:
return f"{name:20} {data_size:>10} B --\n"
stats = stats_summary(times)
return (f"{name:20} {data_size:>10} B "
f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
f"stdev={stats['stdev']:.3f}ms\n")
def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
test_text = data_bytes.decode('utf-8', errors='replace')
print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
print()
c_times = bench_c_regex(data_bytes, iterations)
print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
rust_times = bench_rust_regex(test_text, iterations)
print(format_stats("Rust split", len(data_bytes), rust_times), end='')
if c_times and rust_times:
c_mean = statistics.mean(c_times)
rust_mean = statistics.mean(rust_times)
ratio = rust_mean / c_mean
speedup = "C is faster" if ratio > 1 else "Rust is faster"
print(f"Speedup: {ratio:.2f}x ({speedup})")
print()
# Verify token splits match between C and Python regex tokenizer
cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
ok, err, out_c, out_py = compare_pair_text(cmp_text)
if ok:
print("Compare: OK (C vs Py splits match)")
else:
print("Compare: MISMATCH (C vs Py)")
if err:
print(err)
if out_c is not None and out_py is not None:
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
print("--- C (head) ---")
print("\n".join(c_lines[:10]))
print("--- Py (head) ---")
print("\n".join(p_lines[:10]))
# Stop the benchmark if mismatch detected
raise SystemExit(1)
def main():
# Check if files were provided as arguments
file_args = sys.argv[1:] if len(sys.argv) > 1 else []
# If files provided, benchmark them
if file_args:
for file_path in file_args:
path = Path(file_path)
if not path.exists():
print(f"❌ File not found: {file_path}")
continue
try:
data = path.read_bytes()
benchmark_dataset(path.name, data, 10)
except Exception as e:
print(f"❌ Error reading {file_path}: {e}")
else:
# Run random generated data
configs = [
("tiny", 100, 1000),
("small", 1024, 500),
("medium", 10 * 1024, 100),
("large", 100 * 1024, 30),
("xlarge", 1024 * 1024, 10),
]
for name, size_bytes, iterations in configs:
# Generate test data
test_text = gen_valid_unicode_string(
random.Random(name),  # seed with the name itself; str seeding is deterministic, unlike hash()
size_bytes
)
test_bytes = test_text.encode('utf-8')
benchmark_dataset(name, test_bytes, iterations)
print("=" * 140)
if __name__ == "__main__":
main()

20
fregex/cload.py Normal file

@@ -0,0 +1,20 @@
import ctypes
c_lib = ctypes.CDLL("fregex/libfregex.dylib")
class TokenList(ctypes.Structure):
# Mirrors the TokenList struct in fregex.h
_fields_ = [
("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))),
("lengths", ctypes.POINTER(ctypes.c_size_t)),
("count", ctypes.c_size_t),
("capacity", ctypes.c_size_t),
]
c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_init.restype = None
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_free.restype = None
c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None
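# Minimal usage sketch (illustrative addition, not used by the benchmarks): run the C
# tokenizer over a small byte string and collect the pieces via the ctypes bindings above.
if __name__ == "__main__":
demo = "Hello's world42!!\n".encode("utf-8")
tl = TokenList()
c_lib.tokenlist_init(ctypes.byref(tl))
try:
c_lib.tokenize_fast(demo, len(demo), ctypes.byref(tl))
tokens = [ctypes.string_at(tl.tokens[i], tl.lengths[i]) for i in range(tl.count)]
print(tokens)  # expected per the split pattern: [b"Hello", b"'s", b" world", b"42", b"!!\n"]
finally:
c_lib.tokenlist_free(ctypes.byref(tl))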

178
fregex/compare.py Normal file

@@ -0,0 +1,178 @@
import sys
import ctypes
from pathlib import Path
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
from fregex.cload import *
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C:
buf.append('\\\\')
elif code == 0x0A:
buf.append('\\n')
elif code == 0x0D:
buf.append('\\r')
elif code == 0x09:
buf.append('\\t')
elif code == 0x0C:
buf.append('\\f')
elif code == 0x0B:
buf.append('\\v')
elif code == 0x22:
buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def dump_tokens(tokens: list[bytes]) -> str:
return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)
def tokenize_c_bytes(data: bytes) -> list[bytes]:
tl = TokenList()
c_lib.tokenlist_init(ctypes.byref(tl))
try:
c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
out: list[bytes] = []
count = int(tl.count)
for i in range(count):
ptr = tl.tokens[i]
ln = int(tl.lengths[i])
out.append(ctypes.string_at(ptr, ln))
return out
finally:
c_lib.tokenlist_free(ctypes.byref(tl))
def tokenize_py_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
def tokenize_rs_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
parts = rust_split_text(SPLIT_PATTERN, text)
return [t.encode('utf-8', errors='surrogatepass') for t in parts]
def compare_one(path: Path) -> int:
data_bytes = Path(path).read_bytes()
try:
c_toks = tokenize_c_bytes(data_bytes)
except Exception as e:
print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
py_toks = tokenize_py_bytes(data_bytes)
except Exception as e:
print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr)
return 1
try:
rs_toks = tokenize_rs_bytes(data_bytes)
except Exception as e:
print(f"Rust split failed on {path}:\n{e}", file=sys.stderr)
return 1
out_c = dump_tokens(c_toks)
out_py = dump_tokens(py_toks)
out_rs = dump_tokens(rs_toks)
if out_c == out_py == out_rs:
print(f"OK {path.name}")
return 0
else:
print(f"DIFF {path.name}")
# Show a small 3-way diff at first differing line, with byte offsets
c_lines = out_c.splitlines()
p_lines = out_py.splitlines()
r_lines = out_rs.splitlines()
def parse_lines(lines):
parsed = []
for ln in lines:
# Format is: "<len>\t<escaped>"
try:
left, right = ln.split('\t', 1)
blen = int(left)
except Exception:
blen = 0
right = ln
parsed.append((blen, right))
return parsed
c_parsed = parse_lines(c_lines)
p_parsed = parse_lines(p_lines)
r_parsed = parse_lines(r_lines)
def byte_offsets(parsed):
offs = []
pos = 0
for blen, _ in parsed:
offs.append((pos, pos + blen))
pos += blen
return offs
c_offs = byte_offsets(c_parsed)
p_offs = byte_offsets(p_parsed)
r_offs = byte_offsets(r_parsed)
# Load original input bytes so we can show precise substrings and code points
data_bytes = Path(path).read_bytes()
def print_unicode_debug(label, offs_list, idx):
if idx >= len(offs_list):
print(f" {label} piece: [n/a]")
return
start, end = offs_list[idx]
piece_bytes = data_bytes[start:end]
piece_text = piece_bytes.decode('utf-8', errors='replace')
if not piece_bytes:
print(f" {label} piece: [EMPTY]")
return
cp_parts = []
for ch in piece_text:
cp_parts.append(f"U+{ord(ch):04X}")
bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes)
print(f" {label} chars: {' | '.join(cp_parts)}")
print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)")
max_len = max(len(c_lines), len(p_lines), len(r_lines))
for i in range(max_len):
cl = c_lines[i] if i < len(c_lines) else "<eof>"
pl = p_lines[i] if i < len(p_lines) else "<eof>"
rl = r_lines[i] if i < len(r_lines) else "<eof>"
if not (cl == pl == rl):
# Collect byte positions if available
c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]"
p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]"
r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]"
print(
f" line {i+1}:\n"
f" C: {cl} @ bytes {c_pos}\n"
f" Py: {pl} @ bytes {p_pos}\n"
f" Rs: {rl} @ bytes {r_pos}"
)
print(" === Unicode split detail ===")
print_unicode_debug("C", c_offs, i)
print_unicode_debug("Py", p_offs, i)
print_unicode_debug("Rs", r_offs, i)
break
return 2
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <tests-dir>")
sys.exit(2)
paths = sorted(Path(sys.argv[1]).glob('*.txt'))
bad = 0
for p in paths:
bad += compare_one(p)
print(f"Completed. Failures: {bad}")
if __name__ == '__main__':
main()

524
fregex/fregex.c Normal file

@@ -0,0 +1,524 @@
#include "fregex.h"
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <stdint.h>
#include <stdbool.h>
#include "utf8proc/utf8proc.h"
/*
Regex pattern we care about from nanochat/tokenizer.py SPLIT_PATTERN
Break it down:
A) '(?i:[sdmt]|ll|ve|re)
B) [^\r\n\p{L}\p{N}]?+\p{L}+
C) \p{N}{1,2}
D) ?[^\s\p{L}\p{N}]++[\r\n]*
E) \s*[\r\n]
F) \s+(?!\S)
G) \s+
*/
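/*
Worked example (a sketch; fuzz.py cross-checks the real behavior against the Python
and Rust reference splitters):
"Hello's world42!!\n" is expected to split as
"Hello" (B) | "'s" (A) | " world" (B) | "42" (C) | "!!\n" (D)
*/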
#define UNICODE_LF 0x000A // Line Feed
#define UNICODE_CR 0x000D // Carriage Return
static inline size_t utf8_decode_cp(
const char *s,
const char *end,
unsigned int *out_cp
) {
if (s >= end) {
*out_cp = 0;
return 0;
}
utf8proc_int32_t ch = 0;
utf8proc_ssize_t ret = utf8proc_iterate(
(const utf8proc_uint8_t*)s,
(ssize_t)(end - s),
&ch
);
if (ret < 0) {
// invalid sequence: treat as single byte
*out_cp = (unsigned char)*s;
return 1;
}
*out_cp = (unsigned int)ch;
return (size_t)ret;
}
static inline bool is_cr_or_lf(unsigned int cp) {
return cp == UNICODE_LF || cp == UNICODE_CR;
}
static inline bool is_letter(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
switch (cat) {
case UTF8PROC_CATEGORY_LU: // Letter, Uppercase
case UTF8PROC_CATEGORY_LL: // Letter, Lowercase
case UTF8PROC_CATEGORY_LT: // Letter, Titlecase
case UTF8PROC_CATEGORY_LM: // Letter, Modifier
case UTF8PROC_CATEGORY_LO: // Letter, Other
return true;
default:
return false;
}
}
static inline bool is_number(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
switch (cat) {
case UTF8PROC_CATEGORY_ND: // Number, Decimal Digit
case UTF8PROC_CATEGORY_NL: // Number, Letter
case UTF8PROC_CATEGORY_NO: // Number, Other
return true;
default:
return false;
}
}
static inline bool is_space(unsigned int cp) {
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
if (
cat == UTF8PROC_CATEGORY_ZS ||
cat == UTF8PROC_CATEGORY_ZL ||
cat == UTF8PROC_CATEGORY_ZP
) {
return true;
}
switch (cp) {
case 0x0009: // TAB
case 0x000A: // LF
case 0x000B: // VT
case 0x000C: // FF
case 0x000D: // CR
case 0x0085: // NEL
return true;
default:
return false;
}
}
static inline bool is_alnum(unsigned int cp) {
return is_letter(cp) || is_number(cp);
}
static inline bool is_non_space_letter_number(unsigned int cp) {
return !is_space(cp) && !is_alnum(cp);
}
static void *xrealloc(void *ptr, size_t new_size) {
void *p = realloc(ptr, new_size);
if (!p) {
fprintf(stderr, "Out of memory while reallocating %zu bytes\n", new_size);
exit(1);
}
return p;
}
static char *xmemdup(const char *src, size_t len) {
char *s = (char*)malloc(len + 1);
if (!s) {
fprintf(stderr, "Out of memory while allocating %zu bytes\n", len + 1);
exit(1);
}
memcpy(s, src, len);
s[len] = '\0';
return s;
}
void tokenlist_init(TokenList *list) {
list->tokens = NULL;
list->lengths = NULL;
list->count = 0;
list->capacity = 0;
}
void tokenlist_free(TokenList *list) {
if (!list)
return;
for (size_t i = 0; i < list->count; ++i)
free(list->tokens[i]);
free(list->tokens);
free(list->lengths);
list->tokens = NULL;
list->lengths = NULL;
list->count = 0;
list->capacity = 0;
}
static void tokenlist_push(TokenList *list, const char *start, size_t len) {
if (list->count == list->capacity) {
const size_t new_cap = list->capacity ? (list->capacity * 2) : 64;
list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*));
list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t));
list->capacity = new_cap;
}
list->tokens[list->count] = xmemdup(start, len);
list->lengths[list->count] = len;
list->count++;
}
static void fput_escaped_char(unsigned char c, FILE *out) {
switch (c) {
case '\\': fputs("\\\\", out); break;
case '\n': fputs("\\n", out); break;
case '\r': fputs("\\r", out); break;
case '\t': fputs("\\t", out); break;
case '\f': fputs("\\f", out); break;
case '\v': fputs("\\v", out); break;
case '\"': fputs("\\\"", out); break;
default:
if (c < 32 || c >= 127) {
fprintf(out, "\\x%02X", c);
} else {
fputc(c, out);
}
}
}
void print_token_escaped(const char *s, size_t len, FILE *out) {
fprintf(out, "%zu\t", len);
for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out);
fputc('\n', out);
}
void print_tokens_escaped(const TokenList *list, FILE *out) {
for (size_t i = 0; i < list->count; ++i) {
const char *tok = list->tokens[i];
size_t len = list->lengths[i];
print_token_escaped(tok, len, out);
}
}
/* A) '(?i:[sdmt]|ll|ve|re) */
static size_t match_contraction(const char *p, const char *end) {
if (p >= end || *p != '\'' || (p + 1) >= end)
return 0;
unsigned char a = (unsigned char)p[1];
// locale-independent lowercase for ASCII letters:
// map AZ to az; leaves others unchanged
if (a >= 'A' && a <= 'Z')
a = (unsigned char)(a + ('a' - 'A'));
// 's 'd 'm 't
if (a == 's' || a == 'd' || a == 'm' || a == 't')
return 2;
// Need a second following byte for 'll 've 're
if (p + 2 >= end)
return 0;
unsigned char b = (unsigned char)p[2];
if (b >= 'A' && b <= 'Z')
b = (unsigned char)(b + ('a' - 'A'));
if ((a == 'l' && b == 'l') ||
(a == 'v' && b == 'e') ||
(a == 'r' && b == 'e')) {
return 3;
}
return 0;
}
/* B) [^\r\n\p{L}\p{N}]?+\p{L}+ */
static size_t match_word_with_optional_prefix(const char *p, const char *end) {
if (p >= end)
return 0;
const char *q = p;
unsigned int cp0;
size_t n0 = utf8_decode_cp(q, end, &cp0);
if (n0 == 0)
return 0;
const char *letters_start = q; // will point to first letter of \p{L}+
size_t prefix_bytes = 0;
// Consider optional one-codepoint prefix if first cp is NOT (CR/LF or alnum)
if (!is_cr_or_lf(cp0) && !is_alnum(cp0)) {
// Look ahead: must have at least one letter right after to commit the prefix
const char *after_prefix = q + n0;
if (after_prefix >= end) {
return 0; // no room for required \p{L}+
}
unsigned int cp1;
size_t n1 = utf8_decode_cp(after_prefix, end, &cp1);
if (n1 > 0 && is_letter(cp1)) {
// Commit the prefix (possessive) and start letters after it
prefix_bytes = n0;
letters_start = after_prefix;
q = after_prefix;
} else {
// Can't commit the prefix (no letter follows), so this branch fails
return 0;
}
} else if (is_letter(cp0)) {
// No prefix; we already sit on the first letter
letters_start = q;
} else {
// First cp is CR/LF or a number; this branch doesn't match
return 0;
}
// Consume \p{L}+ (require at least one letter)
size_t letter_count = 0;
const char *scan = letters_start;
while (scan < end) {
unsigned int cp;
size_t n = utf8_decode_cp(scan, end, &cp);
if (n == 0 || !is_letter(cp))
break;
scan += n;
letter_count++;
}
if (letter_count == 0) {
// Shouldn't happen given the look-ahead, but keep the guard
return 0;
}
return (size_t)(scan - p); // includes prefix (if any) + letters
}
/* C) \p{N}{1,2} */
static size_t match_short_number(const char *p, const char *end) {
if (p >= end)
return 0;
/* First number required */
unsigned int cp1;
size_t n1 = utf8_decode_cp(p, end, &cp1);
if (n1 == 0 || !is_number(cp1))
return 0;
const char *q = p + n1;
// Optional second number cp
if (q < end) {
unsigned int cp2;
size_t n2 = utf8_decode_cp(q, end, &cp2);
if (n2 > 0 && is_number(cp2))
return (size_t)((q + n2) - p); // 2 numbers
}
return (size_t)(q - p); // 1 number
}
/* D) ?[^\s\p{L}\p{N}]++[\r\n]* */
// Optional single ASCII space, then 1+ of (not whitespace, not letter, not number), then any trailing CR/LF
static size_t match_punct_run(const char *p, const char *end) {
const char *q = p;
/* Optional single ASCII space */
if (q < end && *q == ' ') {
const char *r = q + 1;
if (r >= end)
return 0;
unsigned int cp;
size_t n = utf8_decode_cp(r, end, &cp);
if (n == 0)
return 0;
// Eligible b/c not whitespace and not alnum
if (!is_space(cp) && !is_alnum(cp)) {
q = r; // commit the space
} else {
// Not followed by eligible punct
return 0;
}
}
// Now require at least one eligible (not whitespace, not alnum)
size_t took = 0;
while (q < end) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0) break;
if (is_space(cp) || is_alnum(cp))
break; // stop on any whitespace or letter/number
q += n;
took++;
}
if (took == 0)
return 0; // must have at least one punctuation/symbol
// Finally, optionally absorb CR/LF sequence(s)
while (q < end) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_cr_or_lf(cp))
break;
q += n;
}
return (size_t)(q - p);
}
/* E) \s*[\r\n] */
static size_t match_ws_then_linebreak(const char *p, const char *end) {
const char *q = p;
// Collect all positions while consuming whitespace
// TODO: positions[] caps backtracking at 255 whitespace code points; runs longer than
// the fuzzer's MAX_WS_RUN cap could diverge from the reference regex here.
const char *positions[256];
int pos_count = 0;
// Store initial position (zero whitespace consumed)
positions[pos_count++] = q;
while (q < end && pos_count < 255) {
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_space(cp))
break;
q += n;
positions[pos_count++] = q;
}
// Try positions from longest to shortest (backtracking)
// We need to find a position where the next character is a linebreak
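// e.g. for "  \t\nX": the longest candidate position sits after the '\n', but the next
// char there is 'X', so we back off one step and accept "  \t" plus the '\n' (4 bytes),
// which is what \s*[\r\n] yields in the reference regex.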
for (int i = pos_count - 1; i >= 0; i--) {
q = positions[i];
// Check if next character is a linebreak
if (q < end) {
unsigned int br;
size_t nb = utf8_decode_cp(q, end, &br);
if (nb > 0 && is_cr_or_lf(br)) {
// Found a linebreak, include it and return
return (size_t)(q + nb - p);
}
} else {
// EOF reached, rule requires a linebreak so fail
continue;
}
}
// No position found where next char is a linebreak
return 0;
}
/* F) \s+(?!\S) */
static size_t match_trailing_ws(const char *p, const char *end) {
if (p >= end) return 0;
/* Must start with at least one whitespace */
const char *q = p;
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_space(cp))
return 0;
/* Collect all whitespace positions */
// TODO: same 255-entry cap as in match_ws_then_linebreak; very long whitespace runs
// could diverge from the reference regex here.
const char *positions[256];
positions[0] = q + n; // Position after first whitespace
int pos_count = 1;
q += n;
while (q < end && pos_count < 255) {
size_t m = utf8_decode_cp(q, end, &cp);
if (m == 0 || !is_space(cp))
break;
q += m;
positions[pos_count++] = q;
}
/* Try positions from longest to shortest (backtracking) */
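/* e.g. "  x": the full run "  " fails the lookahead ('x' is \S), so back off and match
just " ", leaving " x" for rule B (optional prefix + letters), as the reference regex does. */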
for (int i = pos_count - 1; i >= 0; i--) {
q = positions[i];
/* Check negative lookahead: (?!\S) at this position */
if (q < end) {
size_t k = utf8_decode_cp(q, end, &cp);
if (k > 0 && !is_space(cp)) {
continue; /* Next char is non-space, try shorter match */
}
}
/* Lookahead succeeded at this position */
return (size_t)(q - p);
}
/* All positions failed lookahead */
return 0;
}
/* G) \s+ */
static size_t match_ws_run(const char *p, const char *end) {
if (p >= end)
return 0;
const char *q = p;
unsigned int cp;
size_t n = utf8_decode_cp(q, end, &cp);
if (n == 0 || !is_space(cp))
return 0;
/* We have at least one whitespace, consume the run */
q += n;
while (q < end) {
size_t m = utf8_decode_cp(q, end, &cp);
if (m == 0 || !is_space(cp))
break;
q += m;
}
return (size_t)(q - p);
}
void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
if (!input) {
out->tokens = NULL;
out->lengths = NULL;
out->count = 0;
out->capacity = 0;
return;
}
const char *p = input;
const char *end = input + input_len;
while (p < end) {
/* Special tokens take precedence */
// TODO LATER
size_t captured = 0;
/* Evaluate case A */
captured = match_contraction(p, end);
/* Evaluate case B */
if (!captured) captured = match_word_with_optional_prefix(p, end);
/* Evaluate case C */
if (!captured) captured = match_short_number(p, end);
/* Evaluate case D */
if (!captured) captured = match_punct_run(p, end);
/* Evaluate case E */
if (!captured) captured = match_ws_then_linebreak(p, end);
/* Evaluate case F */
if (!captured) captured = match_trailing_ws(p, end);
/* Evaluate case G */
if (!captured) captured = match_ws_run(p, end);
if (captured) {
tokenlist_push(out, p, captured);
p += captured;
}
else {
/* Fallback take a single CP */
unsigned int cp;
size_t adv = utf8_decode_cp(p, end, &cp);
if (adv == 0)
adv = 1;
tokenlist_push(out, p, adv);
p += adv;
}
}
}

28
fregex/fregex.h Normal file

@@ -0,0 +1,28 @@
#ifndef FAST_TOKENIZER_H
#define FAST_TOKENIZER_H
#include <stddef.h>
#include <stdio.h>
typedef struct {
char **tokens;
size_t *lengths; // Store length of each token to handle null bytes
size_t count;
size_t capacity;
} TokenList;
void tokenlist_init(TokenList *list);
void tokenlist_free(TokenList *list);
// Tokenize input according to the GPT-like regex split semantics
// r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
void tokenize_fast(const char *input, size_t input_len, TokenList *out);
// Utility to print tokens with C-like escaping (one per line):
// <length>\t<escaped-bytes>\n
void print_token_escaped(const char *s, size_t len, FILE *out);
void print_tokens_escaped(const TokenList *list, FILE *out);
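//
// Minimal usage sketch (illustrative), given a UTF-8 buffer `input` of `input_len` bytes:
//
//     TokenList list;
//     tokenlist_init(&list);
//     tokenize_fast(input, input_len, &list);
//     print_tokens_escaped(&list, stdout);
//     tokenlist_free(&list);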
#endif // FAST_TOKENIZER_H

361
fregex/fuzz.py Normal file

@@ -0,0 +1,361 @@
import sys
import time
import random
import argparse
import unicodedata as u
import ctypes
from pathlib import Path
from fregex.cload import *
HERE = Path(__file__).resolve().parent
TESTS_DIR = HERE / "tests"
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C:
buf.append('\\\\')
elif code == 0x0A:
buf.append('\\n')
elif code == 0x0D:
buf.append('\\r')
elif code == 0x09:
buf.append('\\t')
elif code == 0x0C:
buf.append('\\f')
elif code == 0x0B:
buf.append('\\v')
elif code == 0x22:
buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str:
target_len = rng.randint(0, max_len)
ws_cps = [
0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, # space, \t, \n, \v, \f, \r
0x00A0, # NO-BREAK SPACE
0x1680, # OGHAM SPACE MARK
0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
0x2007, 0x2008, 0x2009, 0x200A, # EN/EM/THIN/HAIR SPACES etc.
0x2028, 0x2029, # LINE SEPARATOR, PARAGRAPH SEPARATOR
0x202F, # NARROW NO-BREAK SPACE
0x205F, # MEDIUM MATHEMATICAL SPACE
0x3000, # IDEOGRAPHIC SPACE
0x200B, # ZERO WIDTH SPACE (not WS in Python, but hits tokenizer class)
0xFEFF, # ZERO WIDTH NO-BREAK SPACE
]
ascii_punct = [
ord(c)
for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
]
def rand_scalar_excluding_surrogates(lo: int, hi: int) -> int:
while True:
cp = rng.randint(lo, hi)
if 0xD800 <= cp <= 0xDFFF:
continue
return cp
MAX_WS_RUN = 255
def is_ws_char(ch: str) -> bool:
cp = ord(ch)
return ch.isspace() or (cp in ws_cps)
def gen_ws_segment(max_run: int) -> str:
# Mix of various spaces, often multi-length; sometimes explicit CRLFs
if rng.random() < 0.35:
# Build CR, LF, CRLF, or repeated newlines
seqs = ["\n", "\r", "\r\n"]
unit = rng.choice(seqs)
unit_len = len(unit)
max_reps = max(1, min(max_run // unit_len, MAX_WS_RUN // unit_len))
seg = unit * rng.randint(1, max_reps)
return seg
run = rng.randint(1, min(MAX_WS_RUN, max(1, max_run)))
buf = []
for _ in range(run):
cp = rng.choice(ws_cps)
buf.append(chr(cp))
return ''.join(buf)
def gen_letter_run(max_run: int) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
for _ in range(run):
if rng.random() < 0.6:
# ASCII letters
base = ord('A') if rng.random() < 0.5 else ord('a')
buf.append(chr(base + rng.randint(0, 25)))
else:
# Any Unicode letter
while True:
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
if u.category(chr(cp)).startswith('L'):
buf.append(chr(cp))
break
# optional prefix of single non-WS, non-letter, non-number to stress
# the leading [^\r\n\p{L}\p{N}]?+ in the regex
if rng.random() < 0.3:
buf.insert(0, gen_punc_run(1, allow_space=False))
return ''.join(buf)
def gen_number_run(max_run: int) -> str:
# Bias to lengths 1..2 per \p{N}{1,2}, but sometimes longer
if rng.random() < 0.7:
run = rng.randint(1, min(2, max_run))
else:
run = rng.randint(3, max(3, max_run))
buf = []
for _ in range(run):
if rng.random() < 0.75:
buf.append(chr(ord('0') + rng.randint(0, 9)))
else:
# Other numeric categories (Nd/Nl/No)
while True:
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
if u.category(chr(cp)).startswith('N'):
buf.append(chr(cp))
break
return ''.join(buf)
def gen_punc_run(max_run: int, allow_space: bool = True) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
# optional leading single space before punc block
if allow_space and rng.random() < 0.5:
buf.append(' ')
for _ in range(run):
if rng.random() < 0.6:
cp = rng.choice(ascii_punct)
else:
while True:
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
ch = chr(cp)
if (
cp != 0 and # ensure we don't accidentally add a NUL byte
not u.category(ch).startswith('L') and
not u.category(ch).startswith('N') and
cp not in ws_cps and
not ch.isspace()
):
break
buf.append(chr(cp))
# optional trailing newlines to stress [\r\n]*
if rng.random() < 0.35:
tail = gen_ws_segment(3)
# Keep only CR/LF components in the tail for this case
tail = tail.replace('\t', '').replace('\v', '').replace('\f', '').replace(' ', '')
buf.append(tail)
return ''.join(buf)
def gen_contraction() -> str:
# e.g., we're, he'll, I'd, I'm, can't, they've
prefix = gen_letter_run(rng.randint(1, 6))
suffix = rng.choice(["s", "d", "m", "t", "ll", "ve", "re"])
return prefix + "'" + suffix
def gen_random_unicode(max_run: int) -> str:
run = rng.randint(1, max(1, max_run))
buf = []
for _ in range(run):
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
try:
buf.append(chr(cp))
except ValueError:
continue
return ''.join(buf)
buf: list[str] = []
curr_len = 0
curr_ws_run = 0
# Build by segments until target_len
while curr_len < target_len:
remain = target_len - curr_len
r = rng.random()
if r < 0.40:
seg = gen_ws_segment(remain)
elif r < 0.45:
# Explicit newline-focused segment
seg = ("\r\n" if rng.random() < 0.5 else ("\n" if rng.random() < 0.5 else "\r")) * rng.randint(1, max(1, remain))
elif r < 0.65:
seg = gen_letter_run(remain)
elif r < 0.75:
seg = gen_number_run(remain)
elif r < 0.90:
seg = gen_punc_run(remain)
elif r < 0.95:
seg = gen_contraction()
else:
seg = gen_random_unicode(remain)
if not seg:
continue
# Trim if needed
# Append with whitespace-run capping
for ch in seg:
if curr_len >= target_len:
break
if is_ws_char(ch):
if curr_ws_run >= MAX_WS_RUN:
# insert a non-whitespace breaker
breaker = '.'
buf.append(breaker)
curr_len += 1
curr_ws_run = 0
if curr_len >= target_len:
break
buf.append(ch)
curr_len += 1
curr_ws_run += 1
else:
buf.append(ch)
curr_len += 1
curr_ws_run = 0
# Occasionally end with trailing spaces to stress \s+(?!\S)
if curr_len < max_len and rng.random() < 0.3:
trail = gen_ws_segment(max_len - curr_len)
if rng.random() < 0.7:
trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len))
# Append trailing with cap as well
for ch in trail:
if curr_len >= max_len:
break
if is_ws_char(ch):
if curr_ws_run >= MAX_WS_RUN:
buf.append('.')
curr_len += 1
curr_ws_run = 0
if curr_len >= max_len:
break
buf.append(ch)
curr_len += 1
curr_ws_run += 1
else:
buf.append(ch)
curr_len += 1
curr_ws_run = 0
return ''.join(buf)
def write_temp_case(text: str, tag: str = "RUN") -> Path:
TESTS_DIR.mkdir(parents=True, exist_ok=True)
ts = int(time.time() * 1000)
fname = f"in_fuzz_{tag}_{ts}.txt"
path = TESTS_DIR / fname
with open(path, 'wb') as f:
f.write(text.encode('utf-8', errors='surrogatepass'))
return path
def _format_tokens_dump(tokens: list[bytes]) -> str:
lines = []
for b in tokens:
lines.append(f"{len(b)}\t{escape_bytes(b)}")
return "\n".join(lines)
def tokenize_c_bytes(data: bytes) -> list[bytes]:
tl = TokenList()
c_lib.tokenlist_init(ctypes.byref(tl))
try:
c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
out: list[bytes] = []
count = int(tl.count)
for i in range(count):
ptr = tl.tokens[i]
ln = int(tl.lengths[i])
out.append(ctypes.string_at(ptr, ln))
return out
finally:
c_lib.tokenlist_free(ctypes.byref(tl))
def tokenize_py_bytes(data: bytes) -> list[bytes]:
if py_tokenize_str is None:
raise RuntimeError("py_tokenizer not available")
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
def compare_pair_text(text: str):
data = text.encode('utf-8', errors='surrogatepass')
try:
toks_c = tokenize_c_bytes(data)
except Exception as e:
return False, f"C failed: {e}", None, None
try:
toks_py = tokenize_py_bytes(data)
except Exception as e:
return False, f"Py failed: {e}", None, None
ok = toks_c == toks_py
return ok, None, _format_tokens_dump(toks_c), _format_tokens_dump(toks_py)
def run_fuzz(iters: int, max_len: int, seed: int, stop_on_first: bool):
rng = random.Random(seed)
total = 0
mismatches = 0
last_save = None
for i in range(iters if iters > 0 else 1_000_000_000):
s = gen_valid_unicode_string(rng, max_len)
ok, err, out_c, out_py = compare_pair_text(s)
total += 1
if not ok:
mismatches += 1
fail_path = write_temp_case(s, tag="FAIL")
last_save = fail_path
print(f"Mismatch at iter {i}, saved to {fail_path}")
print(f"Seed: {seed}")
cps = [f"U+{ord(ch):04X}" for ch in s]
cats = [u.category(ch) for ch in s]
print(f"Text bytes len: {len(s.encode('utf-8','surrogatepass'))}, chars: {len(s)}")
print(f"Codepoints: {' '.join(cps)}")
print(f"Categories: {' '.join(cats)}")
if err:
print(err)
if out_c is not None:
print("--- C tokens ---")
print(out_c)
if out_py is not None:
print("--- Py tokens ---")
print(out_py)
if stop_on_first:
break
if (i + 1) % 100 == 0:
print(f"[fuzz] {i+1} cases, mismatches={mismatches}")
# print(out_c, out_py, sep="\n")
return total, mismatches, last_save
def main():
ap = argparse.ArgumentParser(description="Fuzz C vs Python tokenizers on random valid UTF-8 inputs")
ap.add_argument("--iters", type=int, default=0, help="Number of iterations (0 = very large run)")
ap.add_argument("--max-len", type=int, default=256, help="Maximum number of Unicode scalars per case")
ap.add_argument("--seed", type=int, default=12345, help="PRNG seed for reproducibility")
ap.add_argument("--stop-on-first", action="store_true", help="Stop at first mismatch (default: run all)")
args = ap.parse_args()
total, mismatches, last = run_fuzz(args.iters, args.max_len, args.seed, args.stop_on_first)
print(f"Completed {total} cases, mismatches={mismatches}")
if last:
print(f"Last failing case saved at: {last}")
if mismatches:
sys.exit(1)
if __name__ == "__main__":
main()

43
fregex/py_tokenizer.py Normal file

@@ -0,0 +1,43 @@
import sys
from nanochat.tokenizer import SPLIT_PATTERN
from tokenizers import pre_tokenizers, Regex
_SPLITTER = pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated", invert=False)
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C: buf.append('\\\\')
elif code == 0x0A: buf.append('\\n')
elif code == 0x0D: buf.append('\\r')
elif code == 0x09: buf.append('\\t')
elif code == 0x0C: buf.append('\\f')
elif code == 0x0B: buf.append('\\v')
elif code == 0x22: buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def tokenize_py(text: str):
parts = _SPLITTER.pre_tokenize_str(text)
return [s for (s, _range) in parts]
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
sys.exit(2)
with open(sys.argv[1], 'rb') as f:
data = f.read()
text = data.decode('utf-8', errors='surrogatepass')
for tok in tokenize_py(text):
b = tok.encode('utf-8', errors='surrogatepass')
esc = escape_bytes(b)
print(f"{len(b)}\t{esc}")
if __name__ == "__main__":
main()

38
fregex/rust_split.py Normal file

@@ -0,0 +1,38 @@
import sys
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
def escape_bytes(b: bytes) -> str:
buf = []
for code in b:
if code == 0x5C: buf.append('\\\\')
elif code == 0x0A: buf.append('\\n')
elif code == 0x0D: buf.append('\\r')
elif code == 0x09: buf.append('\\t')
elif code == 0x0C: buf.append('\\f')
elif code == 0x0B: buf.append('\\v')
elif code == 0x22: buf.append('\\"')
elif code < 32 or code >= 127:
buf.append(f"\\x{code:02X}")
else:
buf.append(chr(code))
return ''.join(buf)
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
sys.exit(2)
with open(sys.argv[1], 'rb') as f:
data = f.read()
text = data.decode('utf-8', errors='surrogatepass')
parts = rust_split_text(SPLIT_PATTERN, text)
for tok in parts:
b = tok.encode('utf-8', errors='surrogatepass')
esc = escape_bytes(b)
print(f"{len(b)}\t{esc}")
if __name__ == "__main__":
main()


@@ -4,6 +4,7 @@ use std::collections::HashMap as StdHashMap;
use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
@@ -467,9 +468,23 @@ impl Tokenizer {
}
}
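/// Split `text` with `pattern` (compiled via fancy_regex) and return the matched pieces.
/// Exposed to Python as `rustbpe.split_text`; bench.py and rust_split.py call it with
/// nanochat's SPLIT_PATTERN as the reference for the C tokenizer.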
#[pyfunction]
pub fn split_text(pattern: String, text: String) -> PyResult<Vec<String>> {
let re = Regex::new(&pattern)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
let mut out: Vec<String> = Vec::new();
for m in re.find_iter(&text) {
let m = m
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Regex match failed: {}", e)))?;
out.push(m.as_str().to_string());
}
Ok(out)
}
#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init(); // forwards Rust `log` to Python's `logging`
m.add_class::<Tokenizer>()?;
m.add_function(wrap_pyfunction!(split_text, m)?)?;
Ok(())
}