mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
faster regex in C
This commit is contained in:
parent
2e938530ce
commit
12f418f0a1
157
fregex/bench.py
Executable file
157
fregex/bench.py
Executable file
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Benchmarker for comparing tok.c tokenize_fast() vs rust split_text()
|
||||
Measures speed WITHOUT subprocess overhead - direct function calls only.
|
||||
|
||||
Usage:
|
||||
cd pytok
|
||||
source ../.venv/bin/activate
|
||||
python3 bench.py # Run synthetic data benchmarks
|
||||
python3 bench.py /path/to/file.txt # Benchmark a specific file
|
||||
python3 bench.py file1.txt file2.txt ... # Benchmark multiple files
|
||||
"""
|
||||
|
||||
import sys
|
||||
import ctypes
|
||||
import random
|
||||
import time
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
|
||||
from nanochat.tokenizer import SPLIT_PATTERN
|
||||
from rustbpe import split_text as rust_split_text
|
||||
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
|
||||
from fregex.cload import *
|
||||
|
||||
def bench_c_regex(data: bytes, iterations: int) -> list:
|
||||
times = []
|
||||
for _ in range(iterations):
|
||||
token_list = TokenList()
|
||||
c_lib.tokenlist_init(ctypes.byref(token_list))
|
||||
|
||||
start = time.perf_counter()
|
||||
c_lib.tokenize_fast(data, len(data), ctypes.byref(token_list))
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
c_lib.tokenlist_free(ctypes.byref(token_list))
|
||||
times.append(elapsed * 1000)
|
||||
|
||||
return times
|
||||
|
||||
def bench_rust_regex(text: str, iterations: int) -> list:
|
||||
times = []
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
rust_split_text(SPLIT_PATTERN, text)
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed * 1000)
|
||||
|
||||
return times
|
||||
|
||||
def stats_summary(times: list) -> dict:
|
||||
"""Compute statistics from timing list."""
|
||||
if not times or len(times) == 0:
|
||||
return {}
|
||||
|
||||
return {
|
||||
'min': min(times),
|
||||
'max': max(times),
|
||||
'mean': statistics.mean(times),
|
||||
'median': statistics.median(times),
|
||||
'stdev': statistics.stdev(times) if len(times) > 1 else 0,
|
||||
}
|
||||
|
||||
def format_stats(name: str, data_size: int, times: list) -> str:
|
||||
"""Format timing statistics for output."""
|
||||
if not times or len(times) == 0:
|
||||
return f"{name:20} {data_size:>10} B --\n"
|
||||
|
||||
stats = stats_summary(times)
|
||||
|
||||
return (f"{name:20} {data_size:>10} B "
|
||||
f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
|
||||
f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
|
||||
f"stdev={stats['stdev']:.3f}ms\n")
|
||||
|
||||
def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
|
||||
test_text = data_bytes.decode('utf-8', errors='replace')
|
||||
|
||||
print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
|
||||
print()
|
||||
|
||||
c_times = bench_c_regex(data_bytes, iterations)
|
||||
print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
|
||||
|
||||
rust_times = bench_rust_regex(test_text, iterations)
|
||||
print(format_stats("Rust split", len(data_bytes), rust_times), end='')
|
||||
|
||||
if c_times and rust_times:
|
||||
c_mean = statistics.mean(c_times)
|
||||
rust_mean = statistics.mean(rust_times)
|
||||
ratio = rust_mean / c_mean
|
||||
speedup = "C is faster" if ratio > 1 else "Rust is faster"
|
||||
print(f"Speedup: {ratio:.2f}x ({speedup})")
|
||||
|
||||
print()
|
||||
|
||||
# Verify token splits match between C and Python regex tokenizer
|
||||
cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
|
||||
ok, err, out_c, out_py = compare_pair_text(cmp_text)
|
||||
if ok:
|
||||
print("Compare: OK (C vs Py splits match)")
|
||||
else:
|
||||
print("Compare: MISMATCH (C vs Py)")
|
||||
if err:
|
||||
print(err)
|
||||
if out_c is not None and out_py is not None:
|
||||
c_lines = out_c.splitlines()
|
||||
p_lines = out_py.splitlines()
|
||||
print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
|
||||
print("--- C (head) ---")
|
||||
print("\n".join(c_lines[:10]))
|
||||
print("--- Py (head) ---")
|
||||
print("\n".join(p_lines[:10]))
|
||||
# Stop the benchmark if mismatch detected
|
||||
raise SystemExit(1)
|
||||
|
||||
def main():
|
||||
# Check if files were provided as arguments
|
||||
file_args = sys.argv[1:] if len(sys.argv) > 1 else []
|
||||
|
||||
# If files provided, benchmark them
|
||||
if file_args:
|
||||
for file_path in file_args:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
print(f"❌ File not found: {file_path}")
|
||||
continue
|
||||
|
||||
try:
|
||||
data = path.read_bytes()
|
||||
benchmark_dataset(path.name, data, 10)
|
||||
except Exception as e:
|
||||
print(f"❌ Error reading {file_path}: {e}")
|
||||
else:
|
||||
# Run random generated data
|
||||
configs = [
|
||||
("tiny", 100, 1000),
|
||||
("small", 1024, 500),
|
||||
("medium", 10 * 1024, 100),
|
||||
("large", 100 * 1024, 30),
|
||||
("xlarge", 1024 * 1024, 10),
|
||||
]
|
||||
|
||||
for name, size_bytes, iterations in configs:
|
||||
# Generate test data
|
||||
test_text = gen_valid_unicode_string(
|
||||
random.Random(hash(name)),
|
||||
size_bytes
|
||||
)
|
||||
test_bytes = test_text.encode('utf-8')
|
||||
|
||||
benchmark_dataset(name, test_bytes, iterations)
|
||||
|
||||
print("=" * 140)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
fregex/cload.py
Normal file
20
fregex/cload.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import ctypes
|
||||
|
||||
c_lib = ctypes.CDLL("fregex/libfregex.dylib")
|
||||
|
||||
class TokenList(ctypes.Structure):
|
||||
pass
|
||||
|
||||
TokenList._fields_ = [
|
||||
("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))),
|
||||
("lengths", ctypes.POINTER(ctypes.c_size_t)),
|
||||
("count", ctypes.c_size_t),
|
||||
("capacity", ctypes.c_size_t),
|
||||
]
|
||||
|
||||
c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
|
||||
c_lib.tokenlist_init.restype = None
|
||||
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
|
||||
c_lib.tokenlist_free.restype = None
|
||||
c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
|
||||
c_lib.tokenize_fast.restype = None
|
||||
178
fregex/compare.py
Normal file
178
fregex/compare.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import sys
|
||||
import ctypes
|
||||
from pathlib import Path
|
||||
|
||||
from nanochat.tokenizer import SPLIT_PATTERN
|
||||
from rustbpe import split_text as rust_split_text
|
||||
from fregex.cload import *
|
||||
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
|
||||
|
||||
def escape_bytes(b: bytes) -> str:
|
||||
buf = []
|
||||
for code in b:
|
||||
if code == 0x5C:
|
||||
buf.append('\\')
|
||||
elif code == 0x0A:
|
||||
buf.append('\\n')
|
||||
elif code == 0x0D:
|
||||
buf.append('\\r')
|
||||
elif code == 0x09:
|
||||
buf.append('\\t')
|
||||
elif code == 0x0C:
|
||||
buf.append('\\f')
|
||||
elif code == 0x0B:
|
||||
buf.append('\\v')
|
||||
elif code == 0x22:
|
||||
buf.append('\\"')
|
||||
elif code < 32 or code >= 127:
|
||||
buf.append(f"\\x{code:02X}")
|
||||
else:
|
||||
buf.append(chr(code))
|
||||
return ''.join(buf)
|
||||
|
||||
def dump_tokens(tokens: list[bytes]) -> str:
|
||||
return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)
|
||||
|
||||
def tokenize_c_bytes(data: bytes) -> list[bytes]:
|
||||
tl = TokenList()
|
||||
c_lib.tokenlist_init(ctypes.byref(tl))
|
||||
try:
|
||||
c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
|
||||
out: list[bytes] = []
|
||||
count = int(tl.count)
|
||||
for i in range(count):
|
||||
ptr = tl.tokens[i]
|
||||
ln = int(tl.lengths[i])
|
||||
out.append(ctypes.string_at(ptr, ln))
|
||||
return out
|
||||
finally:
|
||||
c_lib.tokenlist_free(ctypes.byref(tl))
|
||||
|
||||
def tokenize_py_bytes(data: bytes) -> list[bytes]:
|
||||
text = data.decode('utf-8', errors='surrogatepass')
|
||||
toks = py_tokenize_str(text)
|
||||
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
|
||||
|
||||
def tokenize_rs_bytes(data: bytes) -> list[bytes]:
|
||||
text = data.decode('utf-8', errors='surrogatepass')
|
||||
parts = rust_split_text(SPLIT_PATTERN, text)
|
||||
return [t.encode('utf-8', errors='surrogatepass') for t in parts]
|
||||
|
||||
def compare_one(path: Path) -> int:
|
||||
data_bytes = Path(path).read_bytes()
|
||||
try:
|
||||
c_toks = tokenize_c_bytes(data_bytes)
|
||||
except Exception as e:
|
||||
print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr)
|
||||
return 1
|
||||
try:
|
||||
py_toks = tokenize_py_bytes(data_bytes)
|
||||
except Exception as e:
|
||||
print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr)
|
||||
return 1
|
||||
try:
|
||||
rs_toks = tokenize_rs_bytes(data_bytes)
|
||||
except Exception as e:
|
||||
print(f"Rust split failed on {path}:\n{e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
out_c = dump_tokens(c_toks)
|
||||
out_py = dump_tokens(py_toks)
|
||||
out_rs = dump_tokens(rs_toks)
|
||||
|
||||
if out_c == out_py == out_rs:
|
||||
print(f"OK {path.name}")
|
||||
return 0
|
||||
else:
|
||||
print(f"DIFF {path.name}")
|
||||
# Show a small 3-way diff at first differing line, with byte offsets
|
||||
c_lines = out_c.splitlines()
|
||||
p_lines = out_py.splitlines()
|
||||
r_lines = out_rs.splitlines()
|
||||
|
||||
def parse_lines(lines):
|
||||
parsed = []
|
||||
for ln in lines:
|
||||
# Format is: "<len>\t<escaped>"
|
||||
try:
|
||||
left, right = ln.split('\t', 1)
|
||||
blen = int(left)
|
||||
except Exception:
|
||||
blen = 0
|
||||
right = ln
|
||||
parsed.append((blen, right))
|
||||
return parsed
|
||||
|
||||
c_parsed = parse_lines(c_lines)
|
||||
p_parsed = parse_lines(p_lines)
|
||||
r_parsed = parse_lines(r_lines)
|
||||
|
||||
def byte_offsets(parsed):
|
||||
offs = []
|
||||
pos = 0
|
||||
for blen, _ in parsed:
|
||||
offs.append((pos, pos + blen))
|
||||
pos += blen
|
||||
return offs
|
||||
|
||||
c_offs = byte_offsets(c_parsed)
|
||||
p_offs = byte_offsets(p_parsed)
|
||||
r_offs = byte_offsets(r_parsed)
|
||||
|
||||
# Load original input bytes so we can show precise substrings and code points
|
||||
data_bytes = Path(path).read_bytes()
|
||||
|
||||
def print_unicode_debug(label, offs_list, idx):
|
||||
if idx >= len(offs_list):
|
||||
print(f" {label} piece: [n/a]")
|
||||
return
|
||||
start, end = offs_list[idx]
|
||||
piece_bytes = data_bytes[start:end]
|
||||
piece_text = piece_bytes.decode('utf-8', errors='replace')
|
||||
if not piece_bytes:
|
||||
print(f" {label} piece: [EMPTY]")
|
||||
return
|
||||
cp_parts = []
|
||||
for ch in piece_text:
|
||||
cp_parts.append(f"U+{ord(ch):04X}")
|
||||
bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes)
|
||||
print(f" {label} chars: {' | '.join(cp_parts)}")
|
||||
print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)")
|
||||
|
||||
max_len = max(len(c_lines), len(p_lines), len(r_lines))
|
||||
for i in range(max_len):
|
||||
cl = c_lines[i] if i < len(c_lines) else "<eof>"
|
||||
pl = p_lines[i] if i < len(p_lines) else "<eof>"
|
||||
rl = r_lines[i] if i < len(r_lines) else "<eof>"
|
||||
if not (cl == pl == rl):
|
||||
# Collect byte positions if available
|
||||
c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]"
|
||||
p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]"
|
||||
r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]"
|
||||
print(
|
||||
f" line {i+1}:\n"
|
||||
f" C: {cl} @ bytes {c_pos}\n"
|
||||
f" Py: {pl} @ bytes {p_pos}\n"
|
||||
f" Rs: {rl} @ bytes {r_pos}"
|
||||
)
|
||||
print(" === Unicode split detail ===")
|
||||
print_unicode_debug("C", c_offs, i)
|
||||
print_unicode_debug("Py", p_offs, i)
|
||||
print_unicode_debug("Rs", r_offs, i)
|
||||
break
|
||||
return 2
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <tests-dir>")
|
||||
sys.exit(2)
|
||||
paths = sorted(Path(sys.argv[1]).glob('*.txt'))
|
||||
bad = 0
|
||||
for p in paths:
|
||||
bad += compare_one(p)
|
||||
print(f"Completed. Failures: {bad}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
524
fregex/fregex.c
Normal file
524
fregex/fregex.c
Normal file
|
|
@ -0,0 +1,524 @@
|
|||
#include "fregex.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <ctype.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "utf8proc/utf8proc.h"
|
||||
|
||||
/*
|
||||
Regex pattern we care about from nanochat/tokenizer.py SPLIT_PATTERN
|
||||
|
||||
Break it down:
|
||||
A) '(?i:[sdmt]|ll|ve|re)
|
||||
B) [^\r\n\p{L}\p{N}]?+\p{L}+
|
||||
C) \p{N}{1,2}
|
||||
D) ?[^\s\p{L}\p{N}]++[\r\n]*
|
||||
E) \s*[\r\n]
|
||||
F) \s+(?!\S)
|
||||
G) \s+
|
||||
*/
|
||||
|
||||
#define UNICODE_LF 0x000A // Line Feed
|
||||
#define UNICODE_CR 0x000D // Carriage Return
|
||||
|
||||
static inline size_t utf8_decode_cp(
|
||||
const char *s,
|
||||
const char *end,
|
||||
unsigned int *out_cp
|
||||
) {
|
||||
if (s >= end) {
|
||||
*out_cp = 0;
|
||||
return 0;
|
||||
}
|
||||
utf8proc_int32_t ch = 0;
|
||||
ssize_t ret = utf8proc_iterate(
|
||||
(const utf8proc_uint8_t*)s,
|
||||
(ssize_t)(end - s),
|
||||
&ch
|
||||
);
|
||||
if (ret < 0) {
|
||||
// invalid sequence: treat as single byte
|
||||
*out_cp = (unsigned char)*s;
|
||||
return 1;
|
||||
}
|
||||
*out_cp = (unsigned int)ch;
|
||||
return (size_t)ret;
|
||||
}
|
||||
|
||||
static inline bool is_cr_or_lf(unsigned int cp) {
|
||||
return cp == UNICODE_LF || cp == UNICODE_CR;
|
||||
}
|
||||
|
||||
static inline bool is_letter(unsigned int cp) {
|
||||
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
|
||||
|
||||
switch (cat) {
|
||||
case UTF8PROC_CATEGORY_LU: // Letter, Uppercase
|
||||
case UTF8PROC_CATEGORY_LL: // Letter, Lowercase
|
||||
case UTF8PROC_CATEGORY_LT: // Letter, Titlecase
|
||||
case UTF8PROC_CATEGORY_LM: // Letter, Modifier
|
||||
case UTF8PROC_CATEGORY_LO: // Letter, Other
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool is_number(unsigned int cp) {
|
||||
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
|
||||
switch (cat) {
|
||||
case UTF8PROC_CATEGORY_ND: // Number, Decimal Digit
|
||||
case UTF8PROC_CATEGORY_NL: // Number, Letter
|
||||
case UTF8PROC_CATEGORY_NO: // Number, Other
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool is_space(unsigned int cp) {
|
||||
utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
|
||||
|
||||
if (
|
||||
cat == UTF8PROC_CATEGORY_ZS ||
|
||||
cat == UTF8PROC_CATEGORY_ZL ||
|
||||
cat == UTF8PROC_CATEGORY_ZP
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (cp) {
|
||||
case 0x0009: // TAB
|
||||
case 0x000A: // LF
|
||||
case 0x000B: // VT
|
||||
case 0x000C: // FF
|
||||
case 0x000D: // CR
|
||||
case 0x0085: // NEL
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool is_alnum(unsigned int cp) {
|
||||
return is_letter(cp) || is_number(cp);
|
||||
}
|
||||
|
||||
static inline bool is_non_space_letter_number(unsigned int cp) {
|
||||
return !is_space(cp) && !is_alnum(cp);
|
||||
}
|
||||
|
||||
static void *xrealloc(void *ptr, size_t new_size) {
|
||||
void *p = realloc(ptr, new_size);
|
||||
if (!p) {
|
||||
fprintf(stderr, "Out of memory while reallocating %zu bytes\n", new_size);
|
||||
exit(1);
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
static char *xmemdup(const char *src, size_t len) {
|
||||
char *s = (char*)malloc(len + 1);
|
||||
if (!s) {
|
||||
fprintf(stderr, "Out of memory while allocating %zu bytes\n", len + 1);
|
||||
exit(1);
|
||||
}
|
||||
memcpy(s, src, len);
|
||||
s[len] = '\0';
|
||||
return s;
|
||||
}
|
||||
|
||||
void tokenlist_init(TokenList *list) {
|
||||
list->tokens = NULL;
|
||||
list->lengths = NULL;
|
||||
list->count = 0;
|
||||
list->capacity = 0;
|
||||
}
|
||||
|
||||
void tokenlist_free(TokenList *list) {
|
||||
if (!list)
|
||||
return;
|
||||
for (size_t i = 0; i < list->count; ++i)
|
||||
free(list->tokens[i]);
|
||||
free(list->tokens);
|
||||
free(list->lengths);
|
||||
list->tokens = NULL;
|
||||
list->lengths = NULL;
|
||||
list->count = 0;
|
||||
list->capacity = 0;
|
||||
}
|
||||
|
||||
static void tokenlist_push(TokenList *list, const char *start, size_t len) {
|
||||
if (list->count == list->capacity) {
|
||||
const size_t new_cap = list->capacity ? (list->capacity * 2) : 64;
|
||||
list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*));
|
||||
list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t));
|
||||
list->capacity = new_cap;
|
||||
}
|
||||
list->tokens[list->count] = xmemdup(start, len);
|
||||
list->lengths[list->count] = len;
|
||||
list->count++;
|
||||
}
|
||||
|
||||
static void fput_escaped_char(unsigned char c, FILE *out) {
|
||||
switch (c) {
|
||||
case '\\': fputs("\\\\", out); break;
|
||||
case '\n': fputs("\\n", out); break;
|
||||
case '\r': fputs("\\r", out); break;
|
||||
case '\t': fputs("\\t", out); break;
|
||||
case '\f': fputs("\\f", out); break;
|
||||
case '\v': fputs("\\v", out); break;
|
||||
case '\"': fputs("\\\"", out); break;
|
||||
default:
|
||||
if (c < 32 || c >= 127) {
|
||||
fprintf(out, "\\x%02X", c);
|
||||
} else {
|
||||
fputc(c, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void print_token_escaped(const char *s, size_t len, FILE *out) {
|
||||
fprintf(out, "%zu\t", len);
|
||||
for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out);
|
||||
fputc('\n', out);
|
||||
}
|
||||
|
||||
void print_tokens_escaped(const TokenList *list, FILE *out) {
|
||||
for (size_t i = 0; i < list->count; ++i) {
|
||||
const char *tok = list->tokens[i];
|
||||
size_t len = list->lengths[i];
|
||||
print_token_escaped(tok, len, out);
|
||||
}
|
||||
}
|
||||
|
||||
/* A) '(?i:[sdmt]|ll|ve|re) */
|
||||
static size_t match_contraction(const char *p, const char *end) {
|
||||
if (p >= end || *p != '\'' || (p + 1) >= end)
|
||||
return 0;
|
||||
|
||||
unsigned char a = (unsigned char)p[1];
|
||||
|
||||
// locale-independent lowercase for ASCII letters:
|
||||
// map A–Z to a–z; leaves others unchanged
|
||||
if (a >= 'A' && a <= 'Z')
|
||||
a = (unsigned char)(a + ('a' - 'A'));
|
||||
|
||||
// 's 'd 'm 't
|
||||
if (a == 's' || a == 'd' || a == 'm' || a == 't')
|
||||
return 2;
|
||||
|
||||
// Need a second following byte for 'll 've 're
|
||||
if (p + 2 >= end)
|
||||
return 0;
|
||||
|
||||
unsigned char b = (unsigned char)p[2];
|
||||
if (b >= 'A' && b <= 'Z')
|
||||
b = (unsigned char)(b + ('a' - 'A'));
|
||||
|
||||
if ((a == 'l' && b == 'l') ||
|
||||
(a == 'v' && b == 'e') ||
|
||||
(a == 'r' && b == 'e')) {
|
||||
return 3;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* B) [^\r\n\p{L}\p{N}]?+\p{L}+ */
|
||||
static size_t match_word_with_optional_prefix(const char *p, const char *end) {
|
||||
if (p >= end)
|
||||
return 0;
|
||||
|
||||
const char *q = p;
|
||||
unsigned int cp0;
|
||||
size_t n0 = utf8_decode_cp(q, end, &cp0);
|
||||
if (n0 == 0)
|
||||
return 0;
|
||||
|
||||
const char *letters_start = q; // will point to first letter of \p{L}+
|
||||
size_t prefix_bytes = 0;
|
||||
|
||||
// Consider optional one-codepoint prefix if first cp is NOT (CR/LF or alnum)
|
||||
if (!is_cr_or_lf(cp0) && !is_alnum(cp0)) {
|
||||
// Look ahead: must have at least one letter right after to commit the prefix
|
||||
const char *after_prefix = q + n0;
|
||||
if (after_prefix >= end) {
|
||||
return 0; // no room for required \p{L}+
|
||||
}
|
||||
unsigned int cp1;
|
||||
size_t n1 = utf8_decode_cp(after_prefix, end, &cp1);
|
||||
if (n1 > 0 && is_letter(cp1)) {
|
||||
// Commit the prefix (possessive) and start letters after it
|
||||
prefix_bytes = n0;
|
||||
letters_start = after_prefix;
|
||||
q = after_prefix;
|
||||
} else {
|
||||
// Can't commit the prefix (no letter follows), so this branch fails
|
||||
return 0;
|
||||
}
|
||||
} else if (is_letter(cp0)) {
|
||||
// No prefix; we already sit on the first letter
|
||||
letters_start = q;
|
||||
} else {
|
||||
// First cp is CR/LF or a number; this branch doesn't match
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Consume \p{L}+ (require at least one letter)
|
||||
size_t letter_count = 0;
|
||||
const char *scan = letters_start;
|
||||
while (scan < end) {
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(scan, end, &cp);
|
||||
if (n == 0 || !is_letter(cp))
|
||||
break;
|
||||
scan += n;
|
||||
letter_count++;
|
||||
}
|
||||
|
||||
if (letter_count == 0) {
|
||||
// Shouldn't happen given the look-ahead, but keep the guard
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (size_t)(scan - p); // includes prefix (if any) + letters
|
||||
}
|
||||
|
||||
/* C) \p{N}{1,2} */
|
||||
static size_t match_short_number(const char *p, const char *end) {
|
||||
if (p >= end)
|
||||
return 0;
|
||||
|
||||
/* First number required */
|
||||
unsigned int cp1;
|
||||
size_t n1 = utf8_decode_cp(p, end, &cp1);
|
||||
if (n1 == 0 || !is_number(cp1))
|
||||
return 0;
|
||||
|
||||
const char *q = p + n1;
|
||||
|
||||
// Optional second number cp
|
||||
if (q < end) {
|
||||
unsigned int cp2;
|
||||
size_t n2 = utf8_decode_cp(q, end, &cp2);
|
||||
if (n2 > 0 && is_number(cp2))
|
||||
return (size_t)((q + n2) - p); // 2 numbers
|
||||
}
|
||||
return (size_t)(q - p); // 1 number
|
||||
}
|
||||
|
||||
/* D) ?[^\s\p{L}\p{N}]++[\r\n]* */
|
||||
// Optional single ASCII space, then 1+ of (not whitespace, not letter, not number),
|
||||
static size_t match_punct_run(const char *p, const char *end) {
|
||||
const char *q = p;
|
||||
|
||||
/* Optional single ASCII space */
|
||||
if (q < end && *q == ' ') {
|
||||
const char *r = q + 1;
|
||||
if (r >= end)
|
||||
return 0;
|
||||
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(r, end, &cp);
|
||||
if (n == 0)
|
||||
return 0;
|
||||
|
||||
// Eligible b/c not whitespace and not alnum
|
||||
if (!is_space(cp) && !is_alnum(cp)) {
|
||||
q = r; // commit the space
|
||||
} else {
|
||||
// Not followed by eligible punct
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Now require at least one eligible (not whitespace, not alnum)
|
||||
size_t took = 0;
|
||||
while (q < end) {
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(q, end, &cp);
|
||||
if (n == 0) break;
|
||||
if (is_space(cp) || is_alnum(cp))
|
||||
break; // stop on any whitespace or letter/number
|
||||
q += n;
|
||||
took++;
|
||||
}
|
||||
if (took == 0)
|
||||
return 0; // must have at least one punctuation/symbol
|
||||
|
||||
// Finally, optionally absorb CR/LF sequence(s)
|
||||
while (q < end) {
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(q, end, &cp);
|
||||
if (n == 0 || !is_cr_or_lf(cp))
|
||||
break;
|
||||
q += n;
|
||||
}
|
||||
|
||||
return (size_t)(q - p);
|
||||
}
|
||||
|
||||
/* E) \s*[\r\n] */
|
||||
static size_t match_ws_then_linebreak(const char *p, const char *end) {
|
||||
const char *q = p;
|
||||
|
||||
// Collect all positions while consuming whitespace
|
||||
// TODO: ? Could we hit the limit
|
||||
const char *positions[256];
|
||||
int pos_count = 0;
|
||||
|
||||
// Store initial position (zero whitespace consumed)
|
||||
positions[pos_count++] = q;
|
||||
|
||||
while (q < end && pos_count < 255) {
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(q, end, &cp);
|
||||
if (n == 0 || !is_space(cp))
|
||||
break;
|
||||
q += n;
|
||||
positions[pos_count++] = q;
|
||||
}
|
||||
|
||||
// Try positions from longest to shortest (backtracking)
|
||||
// We need to find a position where the next character is a linebreak
|
||||
for (int i = pos_count - 1; i >= 0; i--) {
|
||||
q = positions[i];
|
||||
|
||||
// Check if next character is a linebreak
|
||||
if (q < end) {
|
||||
unsigned int br;
|
||||
size_t nb = utf8_decode_cp(q, end, &br);
|
||||
if (nb > 0 && is_cr_or_lf(br)) {
|
||||
// Found a linebreak, include it and return
|
||||
return (size_t)(q + nb - p);
|
||||
}
|
||||
} else {
|
||||
// EOF reached, rule requires a linebreak so fail
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// No position found where next char is a linebreak
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* F) \s+(?!\S) */
|
||||
static size_t match_trailing_ws(const char *p, const char *end) {
|
||||
if (p >= end) return 0;
|
||||
|
||||
/* Must start with at least one whitespace */
|
||||
const char *q = p;
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(q, end, &cp);
|
||||
if (n == 0 || !is_space(cp))
|
||||
return 0;
|
||||
|
||||
/* Collect all whitespace positions */
|
||||
// TODO: ? Could we hit the limit
|
||||
const char *positions[256];
|
||||
positions[0] = q + n; // Position after first whitespace
|
||||
int pos_count = 1;
|
||||
|
||||
q += n;
|
||||
|
||||
while (q < end && pos_count < 255) {
|
||||
size_t m = utf8_decode_cp(q, end, &cp);
|
||||
if (m == 0 || !is_space(cp))
|
||||
break;
|
||||
q += m;
|
||||
positions[pos_count++] = q;
|
||||
}
|
||||
|
||||
/* Try positions from longest to shortest (backtracking) */
|
||||
for (int i = pos_count - 1; i >= 0; i--) {
|
||||
q = positions[i];
|
||||
|
||||
/* Check negative lookahead: (?!\S) at this position */
|
||||
if (q < end) {
|
||||
size_t k = utf8_decode_cp(q, end, &cp);
|
||||
if (k > 0 && !is_space(cp)) {
|
||||
continue; /* Next char is non-space, try shorter match */
|
||||
}
|
||||
}
|
||||
|
||||
/* Lookahead succeeded at this position */
|
||||
return (size_t)(q - p);
|
||||
}
|
||||
|
||||
/* All positions failed lookahead */
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* G) \s+ */
|
||||
static size_t match_ws_run(const char *p, const char *end) {
|
||||
if (p >= end)
|
||||
return 0;
|
||||
|
||||
const char *q = p;
|
||||
unsigned int cp;
|
||||
size_t n = utf8_decode_cp(q, end, &cp);
|
||||
if (n == 0 || !is_space(cp))
|
||||
return 0;
|
||||
|
||||
/* We have at least one whitespace, consume the run */
|
||||
q += n;
|
||||
while (q < end) {
|
||||
size_t m = utf8_decode_cp(q, end, &cp);
|
||||
if (m == 0 || !is_space(cp))
|
||||
break;
|
||||
q += m;
|
||||
}
|
||||
return (size_t)(q - p);
|
||||
}
|
||||
|
||||
void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
|
||||
if (!input) {
|
||||
out->tokens = NULL;
|
||||
out->count = 0;
|
||||
out->capacity = 0;
|
||||
return;
|
||||
}
|
||||
const char *p = input;
|
||||
const char *end = input + input_len;
|
||||
|
||||
while (p < end) {
|
||||
/* Special tokens take precedence */
|
||||
// TODO LATER
|
||||
int captured = 0;
|
||||
|
||||
/* Evaluate case A */
|
||||
captured = match_contraction(p, end);
|
||||
/* Evaluate case B */
|
||||
if (!captured) captured = match_word_with_optional_prefix(p, end);
|
||||
/* Evaluate case C */
|
||||
if (!captured) captured = match_short_number(p, end);
|
||||
/* Evaluate case D */
|
||||
if (!captured) captured = match_punct_run(p, end);
|
||||
/* Evaluate case E */
|
||||
if (!captured) captured = match_ws_then_linebreak(p, end);
|
||||
/* Evaluate case F */
|
||||
if (!captured) captured = match_trailing_ws(p, end);
|
||||
/* Evaluate case G */
|
||||
if (!captured) captured = match_ws_run(p, end);
|
||||
|
||||
if (captured) {
|
||||
tokenlist_push(out, p, captured);
|
||||
p += captured;
|
||||
}
|
||||
else {
|
||||
/* Fallback take a single CP */
|
||||
unsigned int cp;
|
||||
size_t adv = utf8_decode_cp(p, end, &cp);
|
||||
if (adv == 0)
|
||||
adv = 1;
|
||||
tokenlist_push(out, p, adv);
|
||||
p += adv;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
fregex/fregex.h
Normal file
28
fregex/fregex.h
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
#ifndef FAST_TOKENIZER_H
|
||||
#define FAST_TOKENIZER_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
typedef struct {
|
||||
char **tokens;
|
||||
size_t *lengths; // Store length of each token to handle null bytes
|
||||
size_t count;
|
||||
size_t capacity;
|
||||
} TokenList;
|
||||
|
||||
void tokenlist_init(TokenList *list);
|
||||
void tokenlist_free(TokenList *list);
|
||||
|
||||
// Tokenize input according to the GPT-like regex split semantics
|
||||
// r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
void tokenize_fast(const char *input, size_t input_len, TokenList *out);
|
||||
|
||||
// Utility to print tokens with C-like escaping (one per line):
|
||||
// <length>\t<escaped-bytes>\n
|
||||
void print_token_escaped(const char *s, size_t len, FILE *out);
|
||||
void print_tokens_escaped(const TokenList *list, FILE *out);
|
||||
|
||||
#endif // FAST_TOKENIZER_H
|
||||
|
||||
|
||||
361
fregex/fuzz.py
Normal file
361
fregex/fuzz.py
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
import sys
|
||||
import time
|
||||
import random
|
||||
import argparse
|
||||
import unicodedata as u
|
||||
import ctypes
|
||||
from pathlib import Path
|
||||
|
||||
from fregex.cload import *
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
TESTS_DIR = HERE / "tests"
|
||||
|
||||
from fregex.py_tokenizer import tokenize_py as py_tokenize_str
|
||||
|
||||
def escape_bytes(b: bytes) -> str:
|
||||
buf = []
|
||||
for code in b:
|
||||
if code == 0x5C:
|
||||
buf.append('\\\\')
|
||||
elif code == 0x0A:
|
||||
buf.append('\\n')
|
||||
elif code == 0x0D:
|
||||
buf.append('\\r')
|
||||
elif code == 0x09:
|
||||
buf.append('\\t')
|
||||
elif code == 0x0C:
|
||||
buf.append('\\f')
|
||||
elif code == 0x0B:
|
||||
buf.append('\\v')
|
||||
elif code == 0x22:
|
||||
buf.append('\\"')
|
||||
elif code < 32 or code >= 127:
|
||||
buf.append(f"\\x{code:02X}")
|
||||
else:
|
||||
buf.append(chr(code))
|
||||
return ''.join(buf)
|
||||
|
||||
def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str:
|
||||
target_len = rng.randint(0, max_len)
|
||||
|
||||
ws_cps = [
|
||||
0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, # space, \t, \n, \v, \f, \r
|
||||
0x00A0, # NO-BREAK SPACE
|
||||
0x1680, # OGHAM SPACE MARK
|
||||
0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
|
||||
0x2007, 0x2008, 0x2009, 0x200A, # EN/EM/THIN/HAIR SPACES etc.
|
||||
0x2028, 0x2029, # LINE SEPARATOR, PARAGRAPH SEPARATOR
|
||||
0x202F, # NARROW NO-BREAK SPACE
|
||||
0x205F, # MEDIUM MATHEMATICAL SPACE
|
||||
0x3000, # IDEOGRAPHIC SPACE
|
||||
0x200B, # ZERO WIDTH SPACE (not WS in Python, but hits tokenizer class)
|
||||
0xFEFF, # ZERO WIDTH NO-BREAK SPACE
|
||||
]
|
||||
|
||||
ascii_punct = [
|
||||
ord(c)
|
||||
for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
||||
]
|
||||
|
||||
def rand_scalar_excluding_surrogates(lo: int, hi: int) -> int:
|
||||
while True:
|
||||
cp = rng.randint(lo, hi)
|
||||
if 0xD800 <= cp <= 0xDFFF:
|
||||
continue
|
||||
return cp
|
||||
|
||||
MAX_WS_RUN = 255
|
||||
|
||||
def is_ws_char(ch: str) -> bool:
|
||||
cp = ord(ch)
|
||||
return ch.isspace() or (cp in ws_cps)
|
||||
|
||||
def gen_ws_segment(max_run: int) -> str:
|
||||
# Mix of various spaces, often multi-length; sometimes explicit CRLFs
|
||||
if rng.random() < 0.35:
|
||||
# Build CR, LF, CRLF, or repeated newlines
|
||||
seqs = ["\n", "\r", "\r\n"]
|
||||
unit = rng.choice(seqs)
|
||||
unit_len = len(unit)
|
||||
max_reps = max(1, min(max_run // unit_len, MAX_WS_RUN // unit_len))
|
||||
seg = unit * rng.randint(1, max_reps)
|
||||
return seg
|
||||
run = rng.randint(1, min(MAX_WS_RUN, max(1, max_run)))
|
||||
buf = []
|
||||
for _ in range(run):
|
||||
cp = rng.choice(ws_cps)
|
||||
buf.append(chr(cp))
|
||||
return ''.join(buf)
|
||||
|
||||
def gen_letter_run(max_run: int) -> str:
|
||||
run = rng.randint(1, max(1, max_run))
|
||||
buf = []
|
||||
for _ in range(run):
|
||||
if rng.random() < 0.6:
|
||||
# ASCII letters
|
||||
base = ord('A') if rng.random() < 0.5 else ord('a')
|
||||
buf.append(chr(base + rng.randint(0, 25)))
|
||||
else:
|
||||
# Any Unicode letter
|
||||
while True:
|
||||
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
|
||||
if u.category(chr(cp)).startswith('L'):
|
||||
buf.append(chr(cp))
|
||||
break
|
||||
# optional prefix of single non-WS, non-letter, non-number to stress
|
||||
# the leading [^\r\n\p{L}\p{N}]?+ in the regex
|
||||
if rng.random() < 0.3:
|
||||
buf.insert(0, gen_punc_run(1, allow_space=False))
|
||||
return ''.join(buf)
|
||||
|
||||
def gen_number_run(max_run: int) -> str:
|
||||
# Bias to lengths 1..2 per \p{N}{1,2}, but sometimes longer
|
||||
if rng.random() < 0.7:
|
||||
run = rng.randint(1, min(2, max_run))
|
||||
else:
|
||||
run = rng.randint(3, max(3, max_run))
|
||||
buf = []
|
||||
for _ in range(run):
|
||||
if rng.random() < 0.75:
|
||||
buf.append(chr(ord('0') + rng.randint(0, 9)))
|
||||
else:
|
||||
# Other numeric categories (Nd/Nl/No)
|
||||
while True:
|
||||
cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
|
||||
if u.category(chr(cp)).startswith('N'):
|
||||
buf.append(chr(cp))
|
||||
break
|
||||
return ''.join(buf)
|
||||
|
||||
def gen_punc_run(max_run: int, allow_space: bool = True) -> str:
|
||||
run = rng.randint(1, max(1, max_run))
|
||||
buf = []
|
||||
# optional leading single space before punc block
|
||||
if allow_space and rng.random() < 0.5:
|
||||
buf.append(' ')
|
||||
for _ in range(run):
|
||||
if rng.random() < 0.6:
|
||||
cp = rng.choice(ascii_punct)
|
||||
else:
|
||||
while True:
|
||||
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
|
||||
ch = chr(cp)
|
||||
if (
|
||||
not u.category(ch).startswith('L') and
|
||||
not u.category(ch).startswith('N') and
|
||||
cp not in ws_cps and
|
||||
not ch.isspace()
|
||||
):
|
||||
break
|
||||
# ensure we don't accidentally add null
|
||||
buf.append(chr(cp))
|
||||
# optional trailing newlines to stress [\r\n]*
|
||||
if rng.random() < 0.35:
|
||||
tail = gen_ws_segment(3)
|
||||
# Keep only CR/LF components in the tail for this case
|
||||
tail = tail.replace('\t', '').replace('\v', '').replace('\f', '').replace(' ', '')
|
||||
buf.append(tail)
|
||||
return ''.join(buf)
|
||||
|
||||
def gen_contraction() -> str:
|
||||
# e.g., we're, he'll, I'd, I'm, can't, they've
|
||||
prefixes = [gen_letter_run( rng.randint(1, 6) )]
|
||||
suffix = rng.choice(["s", "d", "m", "t", "ll", "ve", "re"])
|
||||
return prefixes[0] + "'" + suffix
|
||||
|
||||
def gen_random_unicode(max_run: int) -> str:
|
||||
run = rng.randint(1, max(1, max_run))
|
||||
buf = []
|
||||
for _ in range(run):
|
||||
cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
|
||||
try:
|
||||
buf.append(chr(cp))
|
||||
except ValueError:
|
||||
continue
|
||||
return ''.join(buf)
|
||||
|
||||
buf: list[str] = []
|
||||
curr_len = 0
|
||||
curr_ws_run = 0
|
||||
# Build by segments until target_len
|
||||
while curr_len < target_len:
|
||||
remain = target_len - curr_len
|
||||
r = rng.random()
|
||||
if r < 0.40:
|
||||
seg = gen_ws_segment(remain)
|
||||
elif r < 0.45:
|
||||
# Explicit newline-focused segment
|
||||
seg = ("\r\n" if rng.random() < 0.5 else ("\n" if rng.random() < 0.5 else "\r")) * rng.randint(1, max(1, remain))
|
||||
elif r < 0.65:
|
||||
seg = gen_letter_run(remain)
|
||||
elif r < 0.75:
|
||||
seg = gen_number_run(remain)
|
||||
elif r < 0.90:
|
||||
seg = gen_punc_run(remain)
|
||||
elif r < 0.95:
|
||||
seg = gen_contraction()
|
||||
else:
|
||||
seg = gen_random_unicode(remain)
|
||||
|
||||
if not seg:
|
||||
continue
|
||||
# Trim if needed
|
||||
# Append with whitespace-run capping
|
||||
for ch in seg:
|
||||
if curr_len >= target_len:
|
||||
break
|
||||
if is_ws_char(ch):
|
||||
if curr_ws_run >= MAX_WS_RUN:
|
||||
# insert a non-whitespace breaker
|
||||
breaker = '.'
|
||||
buf.append(breaker)
|
||||
curr_len += 1
|
||||
curr_ws_run = 0
|
||||
if curr_len >= target_len:
|
||||
break
|
||||
buf.append(ch)
|
||||
curr_len += 1
|
||||
curr_ws_run += 1
|
||||
else:
|
||||
buf.append(ch)
|
||||
curr_len += 1
|
||||
curr_ws_run = 0
|
||||
|
||||
# Occasionally end with trailing spaces to stress \s+(?!\S)
|
||||
if curr_len < max_len and rng.random() < 0.3:
|
||||
trail = gen_ws_segment(max_len - curr_len)
|
||||
if rng.random() < 0.7:
|
||||
trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len))
|
||||
# Append trailing with cap as well
|
||||
for ch in trail:
|
||||
if curr_len >= max_len:
|
||||
break
|
||||
if is_ws_char(ch):
|
||||
if curr_ws_run >= MAX_WS_RUN:
|
||||
buf.append('.')
|
||||
curr_len += 1
|
||||
curr_ws_run = 0
|
||||
if curr_len >= max_len:
|
||||
break
|
||||
buf.append(ch)
|
||||
curr_len += 1
|
||||
curr_ws_run += 1
|
||||
else:
|
||||
buf.append(ch)
|
||||
curr_len += 1
|
||||
curr_ws_run = 0
|
||||
|
||||
return ''.join(buf)
|
||||
|
||||
def write_temp_case(text: str, tag: str = "RUN") -> Path:
|
||||
TESTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ts = int(time.time() * 1000)
|
||||
fname = f"in_fuzz_{tag}_{ts}.txt"
|
||||
path = TESTS_DIR / fname
|
||||
with open(path, 'wb') as f:
|
||||
f.write(text.encode('utf-8', errors='surrogatepass'))
|
||||
return path
|
||||
|
||||
def _format_tokens_dump(tokens: list[bytes]) -> str:
|
||||
lines = []
|
||||
for b in tokens:
|
||||
lines.append(f"{len(b)}\t{escape_bytes(b)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def tokenize_c_bytes(data: bytes) -> list[bytes]:
|
||||
tl = TokenList()
|
||||
c_lib.tokenlist_init(ctypes.byref(tl))
|
||||
try:
|
||||
c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
|
||||
out: list[bytes] = []
|
||||
count = int(tl.count)
|
||||
for i in range(count):
|
||||
ptr = tl.tokens[i]
|
||||
ln = int(tl.lengths[i])
|
||||
out.append(ctypes.string_at(ptr, ln))
|
||||
return out
|
||||
finally:
|
||||
c_lib.tokenlist_free(ctypes.byref(tl))
|
||||
|
||||
def tokenize_py_bytes(data: bytes) -> list[bytes]:
|
||||
if py_tokenize_str is None:
|
||||
raise RuntimeError("py_tokenizer not available")
|
||||
text = data.decode('utf-8', errors='surrogatepass')
|
||||
toks = py_tokenize_str(text)
|
||||
return [t.encode('utf-8', errors='surrogatepass') for t in toks]
|
||||
|
||||
def compare_pair_text(text: str):
|
||||
data = text.encode('utf-8', errors='surrogatepass')
|
||||
try:
|
||||
toks_c = tokenize_c_bytes(data)
|
||||
except Exception as e:
|
||||
return False, f"C failed: {e}", None, None
|
||||
try:
|
||||
toks_py = tokenize_py_bytes(data)
|
||||
except Exception as e:
|
||||
return False, f"Py failed: {e}", None, None
|
||||
ok = toks_c == toks_py
|
||||
return ok, None, _format_tokens_dump(toks_c), _format_tokens_dump(toks_py)
|
||||
|
||||
def run_fuzz(iters: int, max_len: int, seed: int, stop_on_first: bool):
|
||||
rng = random.Random(seed)
|
||||
total = 0
|
||||
mismatches = 0
|
||||
last_save = None
|
||||
|
||||
for i in range(iters if iters > 0 else 1_000_000_000):
|
||||
s = gen_valid_unicode_string(rng, max_len)
|
||||
ok, err, out_c, out_py = compare_pair_text(s)
|
||||
total += 1
|
||||
if not ok:
|
||||
mismatches += 1
|
||||
fail_path = write_temp_case(s, tag="FAIL")
|
||||
last_save = fail_path
|
||||
print(f"Mismatch at iter {i}, saved to {fail_path}")
|
||||
print(f"Seed: {seed}")
|
||||
|
||||
cps = [f"U+{ord(ch):04X}" for ch in s]
|
||||
cats = [u.category(ch) for ch in s]
|
||||
print(f"Text bytes len: {len(s.encode('utf-8','surrogatepass'))}, chars: {len(s)}")
|
||||
print(f"Codepoints: {' '.join(cps)}")
|
||||
print(f"Categories: {' '.join(cats)}")
|
||||
if err:
|
||||
print(err)
|
||||
if out_c is not None:
|
||||
print("--- C tokens ---")
|
||||
print(out_c)
|
||||
if out_py is not None:
|
||||
print("--- Py tokens ---")
|
||||
print(out_py)
|
||||
|
||||
if stop_on_first:
|
||||
break
|
||||
|
||||
if (i + 1) % 100 == 0:
|
||||
print(f"[fuzz] {i+1} cases, mismatches={mismatches}")
|
||||
# print(out_c, out_py, sep="\n")
|
||||
|
||||
return total, mismatches, last_save
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Fuzz C vs Python tokenizers on random valid UTF-8 inputs")
|
||||
ap.add_argument("--iters", type=int, default=0, help="Number of iterations (0 = very large run)")
|
||||
ap.add_argument("--max-len", type=int, default=256, help="Maximum number of Unicode scalars per case")
|
||||
ap.add_argument("--seed", type=int, default=12345, help="PRNG seed for reproducibility")
|
||||
ap.add_argument("--stop-on-first", action="store_true", help="Stop at first mismatch (default: run all)")
|
||||
args = ap.parse_args()
|
||||
|
||||
total, mismatches, last = run_fuzz(args.iters, args.max_len, args.seed, args.stop_on_first)
|
||||
print(f"Completed {total} cases, mismatches={mismatches}")
|
||||
if last:
|
||||
print(f"Last failing case saved at: {last}")
|
||||
if mismatches:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
43
fregex/py_tokenizer.py
Normal file
43
fregex/py_tokenizer.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import sys
|
||||
|
||||
from nanochat.tokenizer import SPLIT_PATTERN
|
||||
from tokenizers import pre_tokenizers, Regex
|
||||
|
||||
_SPLITTER = pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated", invert=False)
|
||||
|
||||
def escape_bytes(b: bytes) -> str:
|
||||
buf = []
|
||||
for code in b:
|
||||
if code == 0x5C: buf.append('\\\\')
|
||||
elif code == 0x0A: buf.append('\\n')
|
||||
elif code == 0x0D: buf.append('\\r')
|
||||
elif code == 0x09: buf.append('\\t')
|
||||
elif code == 0x0C: buf.append('\\f')
|
||||
elif code == 0x0B: buf.append('\\v')
|
||||
elif code == 0x22: buf.append('\\"')
|
||||
elif code < 32 or code >= 127:
|
||||
buf.append(f"\\x{code:02X}")
|
||||
else:
|
||||
buf.append(chr(code))
|
||||
return ''.join(buf)
|
||||
|
||||
def tokenize_py(text: str):
|
||||
parts = _SPLITTER.pre_tokenize_str(text)
|
||||
return [s for (s, _range) in parts]
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
with open(sys.argv[1], 'rb') as f:
|
||||
data = f.read()
|
||||
text = data.decode('utf-8', errors='surrogatepass')
|
||||
for tok in tokenize_py(text):
|
||||
b = tok.encode('utf-8', errors='surrogatepass')
|
||||
esc = escape_bytes(b)
|
||||
print(f"{len(b)}\t{esc}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
38
fregex/rust_split.py
Normal file
38
fregex/rust_split.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import sys
|
||||
|
||||
from nanochat.tokenizer import SPLIT_PATTERN
|
||||
from rustbpe import split_text as rust_split_text
|
||||
|
||||
def escape_bytes(b: bytes) -> str:
|
||||
buf = []
|
||||
for code in b:
|
||||
if code == 0x5C: buf.append('\\\\')
|
||||
elif code == 0x0A: buf.append('\\n')
|
||||
elif code == 0x0D: buf.append('\\r')
|
||||
elif code == 0x09: buf.append('\\t')
|
||||
elif code == 0x0C: buf.append('\\f')
|
||||
elif code == 0x0B: buf.append('\\v')
|
||||
elif code == 0x22: buf.append('\\"')
|
||||
elif code < 32 or code >= 127:
|
||||
buf.append(f"\\x{code:02X}")
|
||||
else:
|
||||
buf.append(chr(code))
|
||||
return ''.join(buf)
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
with open(sys.argv[1], 'rb') as f:
|
||||
data = f.read()
|
||||
text = data.decode('utf-8', errors='surrogatepass')
|
||||
parts = rust_split_text(SPLIT_PATTERN, text)
|
||||
for tok in parts:
|
||||
b = tok.encode('utf-8', errors='surrogatepass')
|
||||
esc = escape_bytes(b)
|
||||
print(f"{len(b)}\t{esc}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
@ -4,6 +4,7 @@ use std::collections::HashMap as StdHashMap;
|
|||
use dary_heap::OctonaryHeap;
|
||||
use fancy_regex::Regex;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
|
||||
use ahash::{AHashMap, AHashSet};
|
||||
use compact_str::CompactString;
|
||||
|
|
@ -467,9 +468,23 @@ impl Tokenizer {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn split_text(pattern: String, text: String) -> PyResult<Vec<String>> {
|
||||
let re = Regex::new(&pattern)
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
|
||||
let mut out: Vec<String> = Vec::new();
|
||||
for m in re.find_iter(&text) {
|
||||
let m = m
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Regex match failed: {}", e)))?;
|
||||
out.push(m.as_str().to_string());
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
pyo3_log::init(); // forwards Rust `log` to Python's `logging`
|
||||
m.add_class::<Tokenizer>()?;
|
||||
m.add_function(wrap_pyfunction!(split_text, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user