Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)

Compare commits: 4ecaf67bdb ... cc6830a9f9 (7 commits)

Commits in this range:
- cc6830a9f9
- bc1fca39f3
- 0a1059d571
- 851810c7d5
- 41c8b8dbde
- e02938c0aa
- 12f418f0a1
fregex/bench.py (Executable file, 174 lines)
@@ -0,0 +1,174 @@
import sys
import ctypes
import random
import time
import statistics
import os
import gc
from pathlib import Path

from nanochat.tokenizer import SPLIT_PATTERN

os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'VECLIB_MAXIMUM_THREADS': '1',
    'NUMEXPR_NUM_THREADS': '1',
    'RAYON_NUM_THREADS': '1',
})

os.setpriority(os.PRIO_PROCESS, 0, -10)

from rustbpe import split_text as rust_split_text
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
from fregex.cload import *

PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.restype = ctypes.c_void_p
PyBytes_AsString.argtypes = [ctypes.py_object]

def _run_once_c(data: bytes) -> float:
    token_list = TokenList()
    c_lib.tokenlist_init(ctypes.byref(token_list))
    base_ptr = PyBytes_AsString(data)
    t0 = time.perf_counter_ns()
    c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list))
    dt_ms = (time.perf_counter_ns() - t0) / 1e6
    c_lib.tokenlist_free(ctypes.byref(token_list))
    return dt_ms

def _run_once_rust(text: str) -> float:
    t0 = time.perf_counter_ns()
    rust_split_text(SPLIT_PATTERN, text)
    return (time.perf_counter_ns() - t0) / 1e6

def stats_summary(times: list) -> dict:
    """Compute statistics from timing list."""
    if not times or len(times) == 0:
        return {}

    return {
        'min': min(times),
        'max': max(times),
        'mean': statistics.mean(times),
        'median': statistics.median(times),
        'stdev': statistics.stdev(times) if len(times) > 1 else 0,
    }

def format_stats(name: str, data_size: int, times: list) -> str:
    """Format timing statistics for output."""
    if not times or len(times) == 0:
        return f"{name:20} {data_size:>10} B --\n"

    stats = stats_summary(times)

    return (f"{name:20} {data_size:>10} B "
            f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
            f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
            f"stdev={stats['stdev']:.3f}ms\n")

def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
    test_text = data_bytes.decode('utf-8', errors='replace')

    print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
    print()

    # Pre-touch data to avoid first-touch/page-fault skew
    if data_bytes:
        _ = data_bytes[0]
        for i in range(0, len(data_bytes), 4096):
            _ = data_bytes[i]

    # Warm-up
    for _ in range(20):
        _run_once_c(data_bytes)
        _run_once_rust(test_text)

    # Disable GC during timed section
    gc_was_enabled = gc.isenabled()
    if gc_was_enabled:
        gc.disable()

    c_times = []
    rust_times = []
    for _ in range(iterations):
        c_times.append(_run_once_c(data_bytes))
        rust_times.append(_run_once_rust(test_text))

    if gc_was_enabled:
        gc.enable()

    print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
    print(format_stats("Rust split", len(data_bytes), rust_times), end='')

    if c_times and rust_times:
        c_mean = statistics.mean(c_times)
        rust_mean = statistics.mean(rust_times)
        ratio = rust_mean / c_mean
        speedup = "C is faster" if ratio > 1 else "Rust is faster"
        print(f"Speedup: {ratio:.2f}x ({speedup})")

    print()

    # Verify token splits match between C and Python regex tokenizer
    cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
    ok, err, out_c, out_py = compare_pair_text(cmp_text)
    if ok:
        print("Compare: OK (C vs Py splits match)")
    else:
        print("Compare: MISMATCH (C vs Py)")
        if err:
            print(err)
        if out_c is not None and out_py is not None:
            c_lines = out_c.splitlines()
            p_lines = out_py.splitlines()
            print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
            print("--- C (head) ---")
            print("\n".join(c_lines[:10]))
            print("--- Py (head) ---")
            print("\n".join(p_lines[:10]))
        # Stop the benchmark if mismatch detected
        raise SystemExit(1)

def main():
    # Check if files were provided as arguments
    file_args = sys.argv[1:] if len(sys.argv) > 1 else []

    # If files provided, benchmark them
    if file_args:
        for file_path in file_args:
            path = Path(file_path)
            if not path.exists():
                print(f"❌ File not found: {file_path}")
                continue

            try:
                data = path.read_bytes()
                benchmark_dataset(path.name, data, 1_000)
            except Exception as e:
                print(f"❌ Error reading {file_path}: {e}")
    else:
        # Run random generated data
        configs = [
            ("tiny", 100, 1000),
            ("small", 1024, 500),
            ("medium", 10 * 1024, 100),
            ("large", 100 * 1024, 100),
            ("xlarge", 1024 * 1024, 100),
        ]

        for name, size_bytes, iterations in configs:
            # Generate test data
            test_text = gen_valid_unicode_string(
                random.Random(hash(name)),
                size_bytes
            )
            test_bytes = test_text.encode('utf-8')

            benchmark_dataset(name, test_bytes, iterations)

    print("=" * 140)

if __name__ == "__main__":
    main()
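The harness above times each call with time.perf_counter_ns, warms both tokenizers up, pins thread counts to 1, and disables the garbage collector around the timed loop. Below is a minimal, self-contained sketch of that timing pattern; it is not part of the diff, and time_calls plus the toy workload are invented for illustration.

# --- illustrative sketch, not part of the diff ---
import gc
import statistics
import time

def time_calls(fn, arg, iters=100, warmup=10):
    """Same shape as bench.py's timed loop: warm up, disable GC, collect per-call ms."""
    for _ in range(warmup):
        fn(arg)
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        times = []
        for _ in range(iters):
            t0 = time.perf_counter_ns()
            fn(arg)
            times.append((time.perf_counter_ns() - t0) / 1e6)  # ns -> ms
    finally:
        if was_enabled:
            gc.enable()
    return {"min": min(times), "mean": statistics.mean(times), "median": statistics.median(times)}

print(time_calls(bytes.upper, b"x" * 100_000))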
fregex/cload.py (Normal file, 51 lines)
@@ -0,0 +1,51 @@
import ctypes

c_lib = ctypes.CDLL("fregex/libfregex.dylib")


class TokenPos(ctypes.Structure):
    _fields_ = [
        ("start", ctypes.c_size_t),
        ("end", ctypes.c_size_t),
    ]


class TokenList(ctypes.Structure):
    _fields_ = [
        ("splits", ctypes.POINTER(TokenPos)),
        ("count", ctypes.c_size_t),
        ("capacity", ctypes.c_size_t),
    ]


c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_init.restype = None
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_free.restype = None
# Accept a raw pointer to the input buffer rather than a Python bytes object
c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None

def tokenize_c_bytes(data: bytes) -> list[bytes]:
    # Use a C char* view of the original bytes; offsets computed from this base
    c_data = ctypes.c_char_p(data)
    tl = TokenList()
    c_lib.tokenlist_init(ctypes.byref(tl))
    try:
        base_addr = ctypes.cast(c_data, ctypes.c_void_p).value
        # Pass the same pointer to C
        c_lib.tokenize_fast(ctypes.cast(c_data, ctypes.c_void_p), len(data), ctypes.byref(tl))
        out: list[bytes] = []
        count = int(tl.count)
        for i in range(count):
            start_addr = int(tl.splits[i].start)
            end_addr = int(tl.splits[i].end)
            # Compute offsets into our local buffer
            off_start = start_addr - base_addr
            off_end = end_addr - base_addr
            if off_start < 0 or off_end < off_start or off_end > len(data):
                raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}")
            out.append(data[off_start:off_end])
        return out
    finally:
        c_lib.tokenlist_free(ctypes.byref(tl))
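tokenize_fast writes absolute addresses (base pointer plus offset) into each TokenPos, so the wrapper has to subtract the buffer's base address to recover byte offsets. A small self-contained sketch of that address arithmetic follows; the span values are made up and no compiled library is needed to run it.

# --- illustrative sketch, not part of the diff ---
import ctypes

data = b"hello world"
buf = ctypes.c_char_p(data)
base = ctypes.cast(buf, ctypes.c_void_p).value

# Pretend the C side reported this span as absolute addresses, as TokenPos does.
span_start, span_end = base + 6, base + 11
print(data[span_start - base:span_end - base])  # b'world'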
fregex/compare.py (Normal file, 162 lines)
@@ -0,0 +1,162 @@
import sys
import ctypes
from pathlib import Path

from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text
from fregex.cload import *
from fregex.py_tokenizer import tokenize_py as py_tokenize_str

def escape_bytes(b: bytes) -> str:
    buf = []
    for code in b:
        if code == 0x5C:
            buf.append('\\')
        elif code == 0x0A:
            buf.append('\\n')
        elif code == 0x0D:
            buf.append('\\r')
        elif code == 0x09:
            buf.append('\\t')
        elif code == 0x0C:
            buf.append('\\f')
        elif code == 0x0B:
            buf.append('\\v')
        elif code == 0x22:
            buf.append('\\"')
        elif code < 32 or code >= 127:
            buf.append(f"\\x{code:02X}")
        else:
            buf.append(chr(code))
    return ''.join(buf)

def dump_tokens(tokens: list[bytes]) -> str:
    return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)

def tokenize_py_bytes(data: bytes) -> list[bytes]:
    text = data.decode('utf-8', errors='surrogatepass')
    toks = py_tokenize_str(text)
    return [t.encode('utf-8', errors='surrogatepass') for t in toks]

def tokenize_rs_bytes(data: bytes) -> list[bytes]:
    text = data.decode('utf-8', errors='surrogatepass')
    parts = rust_split_text(SPLIT_PATTERN, text)
    return [t.encode('utf-8', errors='surrogatepass') for t in parts]

def compare_one(path: Path) -> int:
    data_bytes = Path(path).read_bytes()
    try:
        c_toks = tokenize_c_bytes(data_bytes)
    except Exception as e:
        print(f"C tokenizer failed on {path}:\n{e}", file=sys.stderr)
        return 1
    try:
        py_toks = tokenize_py_bytes(data_bytes)
    except Exception as e:
        print(f"Python tokenizer failed on {path}:\n{e}", file=sys.stderr)
        return 1
    try:
        rs_toks = tokenize_rs_bytes(data_bytes)
    except Exception as e:
        print(f"Rust split failed on {path}:\n{e}", file=sys.stderr)
        return 1

    out_c = dump_tokens(c_toks)
    out_py = dump_tokens(py_toks)
    out_rs = dump_tokens(rs_toks)

    if out_c == out_py == out_rs:
        print(f"OK {path.name}")
        return 0
    else:
        print(f"DIFF {path.name}")
        # Show a small 3-way diff at first differing line, with byte offsets
        c_lines = out_c.splitlines()
        p_lines = out_py.splitlines()
        r_lines = out_rs.splitlines()

        def parse_lines(lines):
            parsed = []
            for ln in lines:
                # Format is: "<len>\t<escaped>"
                try:
                    left, right = ln.split('\t', 1)
                    blen = int(left)
                except Exception:
                    blen = 0
                    right = ln
                parsed.append((blen, right))
            return parsed

        c_parsed = parse_lines(c_lines)
        p_parsed = parse_lines(p_lines)
        r_parsed = parse_lines(r_lines)

        def byte_offsets(parsed):
            offs = []
            pos = 0
            for blen, _ in parsed:
                offs.append((pos, pos + blen))
                pos += blen
            return offs

        c_offs = byte_offsets(c_parsed)
        p_offs = byte_offsets(p_parsed)
        r_offs = byte_offsets(r_parsed)

        data_bytes = Path(path).read_bytes()

        def print_unicode_debug(label, offs_list, idx):
            if idx >= len(offs_list):
                print(f" {label} piece: [n/a]")
                return
            start, end = offs_list[idx]
            piece_bytes = data_bytes[start:end]
            piece_text = piece_bytes.decode('utf-8', errors='replace')
            if not piece_bytes:
                print(f" {label} piece: [EMPTY]")
                return
            cp_parts = []
            for ch in piece_text:
                cp_parts.append(f"U+{ord(ch):04X}")
            bytes_hex = ' '.join(f"{b:02X}" for b in piece_bytes)
            print(f" {label} chars: {' | '.join(cp_parts)}")
            print(f" {label} bytes: {bytes_hex} ({len(piece_bytes)}B, {len(piece_text)} chars)")

        max_len = max(len(c_lines), len(p_lines), len(r_lines))
        for i in range(max_len):
            cl = c_lines[i] if i < len(c_lines) else "<eof>"
            pl = p_lines[i] if i < len(p_lines) else "<eof>"
            rl = r_lines[i] if i < len(r_lines) else "<eof>"
            if not (cl == pl == rl):
                # Collect byte positions if available
                c_pos = f"[{c_offs[i][0]}:{c_offs[i][1]}]" if i < len(c_offs) else "[n/a]"
                p_pos = f"[{p_offs[i][0]}:{p_offs[i][1]}]" if i < len(p_offs) else "[n/a]"
                r_pos = f"[{r_offs[i][0]}:{r_offs[i][1]}]" if i < len(r_offs) else "[n/a]"
                print(
                    f" line {i+1}:\n"
                    f" C: {cl} @ bytes {c_pos}\n"
                    f" Py: {pl} @ bytes {p_pos}\n"
                    f" Rs: {rl} @ bytes {r_pos}"
                )
                print(" === Unicode split detail ===")
                print_unicode_debug("C", c_offs, i)
                print_unicode_debug("Py", p_offs, i)
                print_unicode_debug("Rs", r_offs, i)
                break
        return 2

def main():
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <tests-dir>")
        sys.exit(2)
    paths = sorted(Path(sys.argv[1]).glob('*.txt'))
    bad = 0
    for p in paths:
        bad += compare_one(p)
    print(f"Completed. Failures: {bad}")

if __name__ == '__main__':
    main()
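When the three dumps disagree, the mismatch report leans on the "<len>\t<escaped>" dump format: the byte length on each line is parsed back out and summed into running byte offsets so the differing piece can be located in the input file. A tiny sketch of that bookkeeping, using toy dump lines rather than output from a real run:

# --- illustrative sketch, not part of the diff ---
lines = ["5\thello", "1\t,", "6\t world"]          # "<len>\t<escaped>" dump lines
parsed = [(int(ln.split('\t', 1)[0]), ln.split('\t', 1)[1]) for ln in lines]

offs, pos = [], 0
for blen, _ in parsed:
    offs.append((pos, pos + blen))                  # byte span of this token in the input
    pos += blen
print(offs)  # [(0, 5), (5, 6), (6, 12)]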
fregex/fregex.c (Normal file, 498 lines)
@@ -0,0 +1,498 @@
#include "fregex-2.h"

#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <stdint.h>
#include <stdbool.h>

#include "utf8proc/utf8proc.h"

/*
    Regex pattern we care about from nanochat/tokenizer.py SPLIT_PATTERN

    Break it down:
    A) '(?i:[sdmt]|ll|ve|re)
    B) [^\r\n\p{L}\p{N}]?+\p{L}+
    C) \p{N}{1,2}
    D)  ?[^\s\p{L}\p{N}]++[\r\n]*
    E) \s*[\r\n]
    F) \s+(?!\S)
    G) \s+
*/

#define UNICODE_LF 0x000A  // Line Feed
#define UNICODE_CR 0x000D  // Carriage Return

static inline size_t utf8_decode_cp(
    const char *s,
    const char *end,
    unsigned int *out_cp
) {
    if (s >= end) {
        *out_cp = 0;
        return 0;
    }
    utf8proc_int32_t ch = 0;
    ssize_t ret = utf8proc_iterate(
        (const utf8proc_uint8_t*)s,
        (ssize_t)(end - s),
        &ch
    );
    if (ret < 0) {
        // invalid sequence: treat as single byte
        *out_cp = (unsigned char)*s;
        return 1;
    }
    *out_cp = (unsigned int)ch;
    return (size_t)ret;
}

static inline bool is_utf8_cont_byte(unsigned char b) {
    return (b & 0xC0) == 0x80;
}

static inline bool is_cr_or_lf(unsigned int cp) {
    return cp == UNICODE_LF || cp == UNICODE_CR;
}

static inline bool is_letter(unsigned int cp) {
    utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);

    switch (cat) {
        case UTF8PROC_CATEGORY_LU:  // Letter, Uppercase
        case UTF8PROC_CATEGORY_LL:  // Letter, Lowercase
        case UTF8PROC_CATEGORY_LT:  // Letter, Titlecase
        case UTF8PROC_CATEGORY_LM:  // Letter, Modifier
        case UTF8PROC_CATEGORY_LO:  // Letter, Other
            return true;
        default:
            return false;
    }
}

static inline bool is_number(unsigned int cp) {
    utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);
    switch (cat) {
        case UTF8PROC_CATEGORY_ND:  // Number, Decimal Digit
        case UTF8PROC_CATEGORY_NL:  // Number, Letter
        case UTF8PROC_CATEGORY_NO:  // Number, Other
            return true;
        default:
            return false;
    }
}

static inline bool is_space(unsigned int cp) {
    utf8proc_category_t cat = utf8proc_category((utf8proc_int32_t)cp);

    if (
        cat == UTF8PROC_CATEGORY_ZS ||
        cat == UTF8PROC_CATEGORY_ZL ||
        cat == UTF8PROC_CATEGORY_ZP
    ) {
        return true;
    }

    switch (cp) {
        case 0x0009:  // TAB
        case 0x000A:  // LF
        case 0x000B:  // VT
        case 0x000C:  // FF
        case 0x000D:  // CR
        case 0x0085:  // NEL
            return true;
        default:
            return false;
    }
}

static inline bool is_alnum(unsigned int cp) {
    return is_letter(cp) || is_number(cp);
}

static inline bool is_non_space_letter_number(unsigned int cp) {
    return !is_space(cp) && !is_alnum(cp);
}

static void *xrealloc(void *ptr, size_t new_size) {
    void *p = realloc(ptr, new_size);
    if (!p) {
        fprintf(stderr, "Out of memory while reallocating %zu bytes\n", new_size);
        exit(1);
    }
    return p;
}

static char *xmemdup(const char *src, size_t len) {
    char *s = (char*)malloc(len + 1);
    if (!s) {
        fprintf(stderr, "Out of memory while allocating %zu bytes\n", len + 1);
        exit(1);
    }
    memcpy(s, src, len);
    s[len] = '\0';
    return s;
}

void tokenlist_init(TokenList *list) {
    list->splits = NULL;
    list->count = 0;
    list->capacity = 0;
}

void tokenlist_free(TokenList *list) {
    if (!list)
        return;
    free(list->splits);
    list->splits = NULL;
    list->count = 0;
    list->capacity = 0;
}

static void tokenlist_push(TokenList *list, const char *start, size_t len) {
    if (list->count == list->capacity) {
        const size_t new_cap = list->capacity ? (list->capacity * 2) : 128;
        list->splits = (TokenPos*)xrealloc(list->splits, new_cap * sizeof(TokenPos));
        list->capacity = new_cap;
    }
    /* Write the start / end position of string */
    list->splits[list->count].start = (size_t)start;
    list->splits[list->count].end = (size_t)(start + len); // len - 1 ?
    list->count++;
}

static void fput_escaped_char(unsigned char c, FILE *out) {
    switch (c) {
        case '\\': fputs("\\\\", out); break;
        case '\n': fputs("\\n", out); break;
        case '\r': fputs("\\r", out); break;
        case '\t': fputs("\\t", out); break;
        case '\f': fputs("\\f", out); break;
        case '\v': fputs("\\v", out); break;
        case '\"': fputs("\\\"", out); break;
        default:
            if (c < 32 || c >= 127) {
                fprintf(out, "\\x%02X", c);
            } else {
                fputc(c, out);
            }
    }
}

/* A) '(?i:[sdmt]|ll|ve|re) */
static size_t match_contraction(const char *p, const char *end) {
    if (p >= end || *p != '\'' || (p + 1) >= end)
        return 0;

    unsigned char a = (unsigned char)p[1];

    // locale-independent lowercase for ASCII letters:
    // map A–Z to a–z; leaves others unchanged
    if (a >= 'A' && a <= 'Z')
        a = (unsigned char)(a + ('a' - 'A'));

    // 's 'd 'm 't
    if (a == 's' || a == 'd' || a == 'm' || a == 't')
        return 2;

    // Need a second following byte for 'll 've 're
    if (p + 2 >= end)
        return 0;

    unsigned char b = (unsigned char)p[2];
    if (b >= 'A' && b <= 'Z')
        b = (unsigned char)(b + ('a' - 'A'));

    if ((a == 'l' && b == 'l') ||
        (a == 'v' && b == 'e') ||
        (a == 'r' && b == 'e')) {
        return 3;
    }

    return 0;
}

/* B) [^\r\n\p{L}\p{N}]?+\p{L}+ */
static size_t match_word_with_optional_prefix(const char *p, const char *end) {
    if (p >= end)
        return 0;

    const char *q = p;
    unsigned int cp0;
    size_t n0 = utf8_decode_cp(q, end, &cp0);
    if (n0 == 0)
        return 0;

    const char *letters_start = q;  // will point to first letter of \p{L}+
    size_t prefix_bytes = 0;

    // Consider optional one-codepoint prefix if first cp is NOT (CR/LF or alnum)
    if (!is_cr_or_lf(cp0) && !is_alnum(cp0)) {
        // Look ahead: must have at least one letter right after to commit the prefix
        const char *after_prefix = q + n0;
        if (after_prefix >= end) {
            return 0;  // no room for required \p{L}+
        }
        unsigned int cp1;
        size_t n1 = utf8_decode_cp(after_prefix, end, &cp1);
        if (n1 > 0 && is_letter(cp1)) {
            // Commit the prefix (possessive) and start letters after it
            prefix_bytes = n0;
            letters_start = after_prefix;
            q = after_prefix;
        } else {
            // Can't commit the prefix (no letter follows), so this branch fails
            return 0;
        }
    } else if (is_letter(cp0)) {
        // No prefix; we already sit on the first letter
        letters_start = q;
    } else {
        // First cp is CR/LF or a number; this branch doesn't match
        return 0;
    }

    // Consume \p{L}+ (require at least one letter)
    size_t letter_count = 0;
    const char *scan = letters_start;
    while (scan < end) {
        unsigned int cp;
        size_t n = utf8_decode_cp(scan, end, &cp);
        if (n == 0 || !is_letter(cp))
            break;
        scan += n;
        letter_count++;
    }

    if (letter_count == 0) {
        // Shouldn't happen given the look-ahead, but keep the guard
        return 0;
    }

    return (size_t)(scan - p);  // includes prefix (if any) + letters
}

/* C) \p{N}{1,2} */
static size_t match_short_number(const char *p, const char *end) {
    if (p >= end)
        return 0;

    /* First number required */
    unsigned int cp1;
    size_t n1 = utf8_decode_cp(p, end, &cp1);
    if (n1 == 0 || !is_number(cp1))
        return 0;

    const char *q = p + n1;

    // Optional second number cp
    if (q < end) {
        unsigned int cp2;
        size_t n2 = utf8_decode_cp(q, end, &cp2);
        if (n2 > 0 && is_number(cp2))
            return (size_t)((q + n2) - p);  // 2 numbers
    }
    return (size_t)(q - p);  // 1 number
}

/* D)  ?[^\s\p{L}\p{N}]++[\r\n]* */
static size_t match_punct_run(const char *p, const char *end) {
    const char *q = p;

    /* Optional single ASCII space */
    if (q < end && *q == ' ') {
        const char *r = q + 1;
        if (r >= end)
            return 0;

        unsigned int cp;
        size_t n = utf8_decode_cp(r, end, &cp);
        if (n == 0)
            return 0;

        // Eligible b/c not whitespace and not alnum
        if (!is_space(cp) && !is_alnum(cp)) {
            q = r;  // commit the space
        } else {
            // Not followed by eligible punct
            return 0;
        }
    }

    // Now require at least one eligible (not whitespace, not alnum)
    size_t took = 0;
    while (q < end) {
        unsigned int cp;
        size_t n = utf8_decode_cp(q, end, &cp);
        if (n == 0) break;
        if (is_space(cp) || is_alnum(cp))
            break;  // stop on any whitespace or letter/number
        q += n;
        took++;
    }
    if (took == 0)
        return 0;  // must have at least one punctuation/symbol

    // Finally, optionally absorb CR/LF sequence(s)
    while (q < end) {
        unsigned int cp;
        size_t n = utf8_decode_cp(q, end, &cp);
        if (n == 0 || !is_cr_or_lf(cp))
            break;
        q += n;
    }

    return (size_t)(q - p);
}

/* E) \s*[\r\n] */
static size_t match_ws_then_linebreak(const char *p, const char *end) {
    const char *q = p;
    const char *best = NULL;

    // Check boundary before consuming any whitespace, too (zero-length \s*)
    if (q < end) {
        unsigned int nx;
        size_t nn = utf8_decode_cp(q, end, &nx);
        if (nn > 0 && is_cr_or_lf(nx)) {
            best = q;  // \s* = 0, [\r\n] = this char
        }
    }

    // Scan whitespace; at each boundary, test the next cp
    while (q < end) {
        unsigned int cp;
        size_t n = utf8_decode_cp(q, end, &cp);
        if (n == 0 || !is_space(cp))
            break;
        q += n;  // we consumed one whitespace cp; boundary is at q now

        if (q < end) {
            unsigned int nx;
            size_t nn = utf8_decode_cp(q, end, &nx);
            if (nn > 0 && is_cr_or_lf(nx)) {
                best = q;  // prefer the rightmost usable boundary
            }
        }
    }

    if (!best) return 0;

    // At 'best' the next cp is the CR/LF to include
    unsigned int br;
    size_t nb = utf8_decode_cp(best, end, &br);
    return (size_t)((best + nb) - p);
}

/* F) \s+(?!\S) */
static size_t match_trailing_ws(const char *p, const char *end) {
    if (p >= end)
        return 0;

    // First cp must be whitespace
    unsigned int cp;
    size_t n = utf8_decode_cp(p, end, &cp);
    if (n == 0 || !is_space(cp))
        return 0;

    // Consume full whitespace run [p, r)
    const char *r = p + n;
    while (r < end) {
        size_t m = utf8_decode_cp(r, end, &cp);
        if (m == 0 || !is_space(cp))
            break;
        r += m;
    }

    if (r == end) {
        // Only whitespace to EOF -> take all of it
        return (size_t)(r - p);
    }

    // Backtrack by exactly one whitespace cp
    // If the run length is only 1 cp, F must fail.
    // Find the start of the last whitespace cp in [p, r)
    const char *t = r;
    // step back to beginning of previous UTF-8 cp
    do {
        --t;
    } while (t > p && is_utf8_cont_byte(*t));

    if (t == p) {
        // run had length 1 cp -> cannot backtrack to keep \s+ >= 1
        return 0;
    }
    // Now [p, t) is k-1 whitespace cps
    return (size_t)(t - p);
}

/* G) \s+ */
static size_t match_ws_run(const char *p, const char *end) {
    if (p >= end)
        return 0;

    const char *q = p;
    unsigned int cp;
    size_t n = utf8_decode_cp(q, end, &cp);
    if (n == 0 || !is_space(cp))
        return 0;

    /* We have at least one whitespace, consume the run */
    q += n;
    while (q < end) {
        size_t m = utf8_decode_cp(q, end, &cp);
        if (m == 0 || !is_space(cp))
            break;
        q += m;
    }
    return (size_t)(q - p);
}

void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
    if (!input) {
        out->splits = NULL;
        out->count = 0;
        out->capacity = 0;
        return;
    }
    const char *p = input;
    const char *end = input + input_len;

    while (p < end) {
        /* Special tokens take precedence */
        // TODO LATER
        int captured = 0;

        /* Evaluate case A */
        captured = match_contraction(p, end);
        /* Evaluate case B */
        if (!captured) captured = match_word_with_optional_prefix(p, end);
        /* Evaluate case C */
        if (!captured) captured = match_short_number(p, end);
        /* Evaluate case D */
        if (!captured) captured = match_punct_run(p, end);
        /* Evaluate case E */
        if (!captured) captured = match_ws_then_linebreak(p, end);
        /* Evaluate case F */
        if (!captured) captured = match_trailing_ws(p, end);
        /* Evaluate case G */
        if (!captured) captured = match_ws_run(p, end);

        if (captured) {
            tokenlist_push(out, p, captured);
            p += captured;
        }
        else {
            /* Fallback take a single CP */
            unsigned int cp;
            size_t adv = utf8_decode_cp(p, end, &cp);
            if (adv == 0)
                adv = 1;
            tokenlist_push(out, p, adv);
            p += adv;
        }
    }
}
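Branch F is the subtle one: \s+(?!\S) must give back exactly one whitespace code point when something non-space follows, and fail outright when the run is a single character. Below is a pure-Python restatement of match_trailing_ws for intuition only (it uses str.isspace() as a stand-in for the C is_space classifier, so edge cases can differ):

# --- illustrative sketch, not part of the diff ---
def match_trailing_ws(text: str, i: int) -> int:
    """Return the number of chars branch F would take starting at i (0 = no match)."""
    j = i
    while j < len(text) and text[j].isspace():
        j += 1
    if j == i:            # first char is not whitespace -> branch does not apply
        return 0
    if j == len(text):    # whitespace runs to EOF -> take the whole run
        return j - i
    return (j - 1) - i    # otherwise give one char back; 0 means the branch fails

assert match_trailing_ws("ab   ", 2) == 3   # trailing run kept whole
assert match_trailing_ws("a   b", 1) == 2   # one space handed back to the next token
assert match_trailing_ws("a b", 1) == 0     # single space before \S: branch fails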
fregex/fregex.h (Normal file, 21 lines)
@@ -0,0 +1,21 @@
#ifndef FAST_REGEX_H
#define FAST_REGEX_H

#include <stddef.h>
#include <stdio.h>

typedef struct {
    size_t start, end;
} TokenPos;

typedef struct {
    TokenPos *splits;
    size_t count;
    size_t capacity;
} TokenList;

void tokenlist_init(TokenList *list);
void tokenlist_free(TokenList *list);
void tokenize_fast(const char *input, size_t input_len, TokenList *out);

#endif // FAST_REGEX_H
fregex/fuzz.py (Normal file, 325 lines)
@@ -0,0 +1,325 @@
import sys
import time
import random
import argparse
import unicodedata as u
import ctypes
from pathlib import Path

from fregex.cload import *

HERE = Path(__file__).resolve().parent
TESTS_DIR = HERE / "tests"

from fregex.py_tokenizer import tokenize_py as py_tokenize_str

def escape_bytes(b: bytes) -> str:
    buf = []
    for code in b:
        if code == 0x5C:
            buf.append('\\\\')
        elif code == 0x0A:
            buf.append('\\n')
        elif code == 0x0D:
            buf.append('\\r')
        elif code == 0x09:
            buf.append('\\t')
        elif code == 0x0C:
            buf.append('\\f')
        elif code == 0x0B:
            buf.append('\\v')
        elif code == 0x22:
            buf.append('\\"')
        elif code < 32 or code >= 127:
            buf.append(f"\\x{code:02X}")
        else:
            buf.append(chr(code))
    return ''.join(buf)

def gen_valid_unicode_string(rng: random.Random, max_len: int) -> str:
    target_len = rng.randint(0, max_len)

    ws_cps = [
        0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,  # space, \t, \n, \v, \f, \r
        0x00A0,  # NO-BREAK SPACE
        0x1680,  # OGHAM SPACE MARK
        0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
        0x2007, 0x2008, 0x2009, 0x200A,  # EN/EM/THIN/HAIR SPACES etc.
        0x2028, 0x2029,  # LINE SEPARATOR, PARAGRAPH SEPARATOR
        0x202F,  # NARROW NO-BREAK SPACE
        0x205F,  # MEDIUM MATHEMATICAL SPACE
        0x3000,  # IDEOGRAPHIC SPACE
        0x200B,  # ZERO WIDTH SPACE (not WS in Python, but hits tokenizer class)
        0xFEFF,  # ZERO WIDTH NO-BREAK SPACE
    ]

    ascii_punct = [
        ord(c)
        for c in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    ]

    def rand_scalar_excluding_surrogates(lo: int, hi: int) -> int:
        while True:
            cp = rng.randint(lo, hi)
            if 0xD800 <= cp <= 0xDFFF:
                continue
            return cp

    def is_ws_char(ch: str) -> bool:
        cp = ord(ch)
        return ch.isspace() or (cp in ws_cps)

    def gen_ws_segment(max_run: int) -> str:
        # Mix of various spaces, often multi-length; sometimes explicit CRLFs
        if rng.random() < 0.35:
            # Build CR, LF, CRLF, or repeated newlines
            seqs = ["\n", "\r", "\r\n"]
            unit = rng.choice(seqs)
            unit_len = len(unit)
            max_reps = max(1, max_run // unit_len)
            seg = unit * rng.randint(1, max_reps)
            return seg
        run = rng.randint(1, max(1, max_run))
        buf = []
        for _ in range(run):
            cp = rng.choice(ws_cps)
            buf.append(chr(cp))
        return ''.join(buf)

    def gen_letter_run(max_run: int) -> str:
        run = rng.randint(1, max(1, max_run))
        buf = []
        for _ in range(run):
            if rng.random() < 0.6:
                # ASCII letters
                base = ord('A') if rng.random() < 0.5 else ord('a')
                buf.append(chr(base + rng.randint(0, 25)))
            else:
                # Any Unicode letter
                while True:
                    cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
                    if u.category(chr(cp)).startswith('L'):
                        buf.append(chr(cp))
                        break
        # optional prefix of single non-WS, non-letter, non-number to stress
        # the leading [^\r\n\p{L}\p{N}]?+ in the regex
        if rng.random() < 0.3:
            buf.insert(0, gen_punc_run(1, allow_space=False))
        return ''.join(buf)

    def gen_number_run(max_run: int) -> str:
        # Bias to lengths 1..2 per \p{N}{1,2}, but sometimes longer
        if rng.random() < 0.7:
            run = rng.randint(1, min(2, max_run))
        else:
            run = rng.randint(3, max(3, max_run))
        buf = []
        for _ in range(run):
            if rng.random() < 0.75:
                buf.append(chr(ord('0') + rng.randint(0, 9)))
            else:
                # Other numeric categories (Nd/Nl/No)
                while True:
                    cp = rand_scalar_excluding_surrogates(0x00A0, 0x10FFFF)
                    if u.category(chr(cp)).startswith('N'):
                        buf.append(chr(cp))
                        break
        return ''.join(buf)

    def gen_punc_run(max_run: int, allow_space: bool = True) -> str:
        run = rng.randint(1, max(1, max_run))
        buf = []
        # optional leading single space before punc block
        if allow_space and rng.random() < 0.5:
            buf.append(' ')
        for _ in range(run):
            if rng.random() < 0.6:
                cp = rng.choice(ascii_punct)
            else:
                while True:
                    cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
                    ch = chr(cp)
                    if (
                        not u.category(ch).startswith('L') and
                        not u.category(ch).startswith('N') and
                        cp not in ws_cps and
                        not ch.isspace()
                    ):
                        break
            # ensure we don't accidentally add null
            buf.append(chr(cp))
        # optional trailing newlines to stress [\r\n]*
        if rng.random() < 0.35:
            tail = gen_ws_segment(3)
            # Keep only CR/LF components in the tail for this case
            tail = tail.replace('\t', '').replace('\v', '').replace('\f', '').replace(' ', '')
            buf.append(tail)
        return ''.join(buf)

    def gen_contraction() -> str:
        # e.g., we're, he'll, I'd, I'm, can't, they've
        prefixes = [gen_letter_run( rng.randint(1, 6) )]
        suffix = rng.choice(["s", "d", "m", "t", "ll", "ve", "re"])
        return prefixes[0] + "'" + suffix

    def gen_random_unicode(max_run: int) -> str:
        run = rng.randint(1, max(1, max_run))
        buf = []
        for _ in range(run):
            cp = rand_scalar_excluding_surrogates(0, 0x10FFFF)
            try:
                buf.append(chr(cp))
            except ValueError:
                continue
        return ''.join(buf)

    buf: list[str] = []
    curr_len = 0
    # Build by segments until target_len
    while curr_len < target_len:
        remain = target_len - curr_len
        r = rng.random()
        if r < 0.40:
            seg = gen_ws_segment(remain)
        elif r < 0.45:
            # Explicit newline-focused segment
            seg = ("\r\n" if rng.random() < 0.5 else ("\n" if rng.random() < 0.5 else "\r")) * rng.randint(1, max(1, remain))
        elif r < 0.65:
            seg = gen_letter_run(remain)
        elif r < 0.75:
            seg = gen_number_run(remain)
        elif r < 0.90:
            seg = gen_punc_run(remain)
        elif r < 0.95:
            seg = gen_contraction()
        else:
            seg = gen_random_unicode(remain)

        if not seg:
            continue
        # Trim if needed
        # Append
        for ch in seg:
            if curr_len >= target_len:
                break
            if is_ws_char(ch):
                buf.append(ch)
                curr_len += 1
            else:
                buf.append(ch)
                curr_len += 1

    # Occasionally end with trailing spaces to stress \s+(?!\S)
    if curr_len < max_len and rng.random() < 0.3:
        trail = gen_ws_segment(max_len - curr_len)
        if rng.random() < 0.7:
            trail = (' ' if rng.random() < 0.6 else '\t') * rng.randint(1, min(8, max_len - curr_len))
        # Append trailing
        for ch in trail:
            if curr_len >= max_len:
                break
            if is_ws_char(ch):
                buf.append(ch)
                curr_len += 1
            else:
                buf.append(ch)
                curr_len += 1

    return ''.join(buf)

def write_temp_case(text: str, tag: str = "RUN") -> Path:
    TESTS_DIR.mkdir(parents=True, exist_ok=True)
    ts = int(time.time() * 1000)
    fname = f"in_fuzz_{tag}_{ts}.txt"
    path = TESTS_DIR / fname
    with open(path, 'wb') as f:
        f.write(text.encode('utf-8', errors='surrogatepass'))
    return path

def _format_tokens_dump(tokens: list[bytes]) -> str:
    lines = []
    for b in tokens:
        lines.append(f"{len(b)}\t{escape_bytes(b)}")
    return "\n".join(lines)

def tokenize_py_bytes(data: bytes) -> list[bytes]:
    if py_tokenize_str is None:
        raise RuntimeError("py_tokenizer not available")
    text = data.decode('utf-8', errors='surrogatepass')
    toks = py_tokenize_str(text)
    return [t.encode('utf-8', errors='surrogatepass') for t in toks]

def compare_pair_text(text: str):
    data = text.encode('utf-8', errors='surrogatepass')
    try:
        toks_c = tokenize_c_bytes(data)
    except Exception as e:
        return False, f"C failed: {e}", None, None
    try:
        toks_py = tokenize_py_bytes(data)
    except Exception as e:
        return False, f"Py failed: {e}", None, None
    ok = toks_c == toks_py
    return ok, None, _format_tokens_dump(toks_c), _format_tokens_dump(toks_py)

def run_fuzz(iters: int, max_len: int, seed: int, stop_on_first: bool):
    rng = random.Random(seed)
    total = 0
    mismatches = 0
    last_save = None

    for i in range(iters if iters > 0 else 1_000_000_000):
        s = gen_valid_unicode_string(rng, max_len)
        ok, err, out_c, out_py = compare_pair_text(s)
        total += 1
        if not ok:
            mismatches += 1
            fail_path = write_temp_case(s, tag="FAIL")
            last_save = fail_path
            print(f"Mismatch at iter {i}, saved to {fail_path}")
            print(f"Seed: {seed}")

            cps = [f"U+{ord(ch):04X}" for ch in s]
            cats = [u.category(ch) for ch in s]
            print(f"Text bytes len: {len(s.encode('utf-8','surrogatepass'))}, chars: {len(s)}")
            print(f"Codepoints: {' '.join(cps)}")
            print(f"Categories: {' '.join(cats)}")
            if err:
                print(err)
            if out_c is not None:
                print("--- C tokens ---")
                print(out_c)
            if out_py is not None:
                print("--- Py tokens ---")
                print(out_py)

            if stop_on_first:
                break

        if (i + 1) % 100 == 0:
            print(f"[fuzz] {i+1} cases, mismatches={mismatches}")
            # print(out_c, out_py, sep="\n")

    return total, mismatches, last_save


def main():
    ap = argparse.ArgumentParser(description="Fuzz C vs Python tokenizers on random valid UTF-8 inputs")
    ap.add_argument("--iters", type=int, default=0, help="Number of iterations (0 = very large run)")
    ap.add_argument("--max-len", type=int, default=256, help="Maximum number of Unicode scalars per case")
    ap.add_argument("--seed", type=int, default=12345, help="PRNG seed for reproducibility")
    ap.add_argument("--stop-on-first", action="store_true", help="Stop at first mismatch (default: run all)")
    args = ap.parse_args()

    total, mismatches, last = run_fuzz(args.iters, args.max_len, args.seed, args.stop_on_first)
    print(f"Completed {total} cases, mismatches={mismatches}")
    if last:
        print(f"Last failing case saved at: {last}")
    if mismatches:
        sys.exit(1)


if __name__ == "__main__":
    main()
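Every fuzz case is derived from random.Random(seed), so a reported mismatch can be replayed deterministically with the same --seed and --max-len, or reproduced in-process. A hypothetical replay sketch (assumes libfregex.dylib has been built and the fregex/nanochat packages are importable):

# --- illustrative sketch, not part of the diff ---
import random
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text

rng = random.Random(12345)                 # same value as the default --seed
case = gen_valid_unicode_string(rng, 256)  # same value as the default --max-len
ok, err, out_c, out_py = compare_pair_text(case)
print(ok, err)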
fregex/py_tokenizer.py (Normal file, 43 lines)
@@ -0,0 +1,43 @@
import sys

from nanochat.tokenizer import SPLIT_PATTERN
from tokenizers import pre_tokenizers, Regex

_SPLITTER = pre_tokenizers.Split(pattern=Regex(SPLIT_PATTERN), behavior="isolated", invert=False)

def escape_bytes(b: bytes) -> str:
    buf = []
    for code in b:
        if code == 0x5C: buf.append('\\\\')
        elif code == 0x0A: buf.append('\\n')
        elif code == 0x0D: buf.append('\\r')
        elif code == 0x09: buf.append('\\t')
        elif code == 0x0C: buf.append('\\f')
        elif code == 0x0B: buf.append('\\v')
        elif code == 0x22: buf.append('\\"')
        elif code < 32 or code >= 127:
            buf.append(f"\\x{code:02X}")
        else:
            buf.append(chr(code))
    return ''.join(buf)

def tokenize_py(text: str):
    parts = _SPLITTER.pre_tokenize_str(text)
    return [s for (s, _range) in parts]

def main():
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
        sys.exit(2)
    with open(sys.argv[1], 'rb') as f:
        data = f.read()
    text = data.decode('utf-8', errors='surrogatepass')
    for tok in tokenize_py(text):
        b = tok.encode('utf-8', errors='surrogatepass')
        esc = escape_bytes(b)
        print(f"{len(b)}\t{esc}")

if __name__ == "__main__":
    main()
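pre_tokenize_str returns (piece, (start, end)) pairs, which is why tokenize_py keeps only the first element of each pair. A small standalone illustration with a simpler pattern (assumes the HuggingFace tokenizers package is installed; the printed offsets are indicative):

# --- illustrative sketch, not part of the diff ---
from tokenizers import pre_tokenizers, Regex

splitter = pre_tokenizers.Split(pattern=Regex(r"\s+|\S+"), behavior="isolated", invert=False)
print(splitter.pre_tokenize_str("hello  world"))
# e.g. [('hello', (0, 5)), ('  ', (5, 7)), ('world', (7, 12))]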
fregex/rust_split.py (Normal file, 38 lines)
@@ -0,0 +1,38 @@
import sys

from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text as rust_split_text

def escape_bytes(b: bytes) -> str:
    buf = []
    for code in b:
        if code == 0x5C: buf.append('\\\\')
        elif code == 0x0A: buf.append('\\n')
        elif code == 0x0D: buf.append('\\r')
        elif code == 0x09: buf.append('\\t')
        elif code == 0x0C: buf.append('\\f')
        elif code == 0x0B: buf.append('\\v')
        elif code == 0x22: buf.append('\\"')
        elif code < 32 or code >= 127:
            buf.append(f"\\x{code:02X}")
        else:
            buf.append(chr(code))
    return ''.join(buf)

def main():
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <input-file>", file=sys.stderr)
        sys.exit(2)
    with open(sys.argv[1], 'rb') as f:
        data = f.read()
    text = data.decode('utf-8', errors='surrogatepass')
    parts = rust_split_text(SPLIT_PATTERN, text)
    for tok in parts:
        b = tok.encode('utf-8', errors='surrogatepass')
        esc = escape_bytes(b)
        print(f"{len(b)}\t{esc}")

if __name__ == "__main__":
    main()
@@ -8,7 +8,7 @@ Notable features:
 - norm after token embedding
 - no learnable params in rmsnorm
 - no bias in linear layers
-- Multi-Query Attention (MQA) support for more efficient inference
+- Group-Query Attention (GQA) support for more efficient inference
 """

 import math

@@ -29,7 +29,7 @@ class GPTConfig:
     vocab_size: int = 50304
     n_layer: int = 12
     n_head: int = 6 # number of query heads
-    n_kv_head: int = 6 # number of key/value heads (MQA)
+    n_kv_head: int = 6 # number of key/value heads (GQA)
     n_embd: int = 768
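The two hunks above only rename the comments from MQA to GQA: with n_kv_head < n_head, several query heads share one key/value head (grouped-query attention), and n_kv_head == 1 is the MQA special case. A rough shape sketch, not taken from the model code; the names and shapes are illustrative only:

# --- illustrative sketch, not part of the diff ---
n_head, n_kv_head, head_dim = 6, 2, 64
group_size = n_head // n_kv_head   # here 3 query heads share each K/V head
# q: (B, T, n_head,    head_dim)
# k: (B, T, n_kv_head, head_dim)  -> repeated group_size times at attention time
# v: (B, T, n_kv_head, head_dim)  -> repeated group_size times at attention time
print(group_size)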
rustbpe/Cargo.lock (generated, 18 lines changed)
@@ -186,6 +186,16 @@ version = "0.2.175"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"

+[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
 [[package]]
 name = "log"
 version = "0.4.28"

@@ -363,7 +373,9 @@ dependencies = [
 "dary_heap",
 "fancy-regex",
 "indexmap",
+"libloading",
 "log",
+"once_cell",
 "pyo3",
 "pyo3-log",
 "rayon",

@@ -431,6 +443,12 @@
 "wit-bindgen",
 ]

+[[package]]
+name = "windows-link"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
+
 [[package]]
 name = "wit-bindgen"
 version = "0.45.1"
@@ -13,3 +13,5 @@ pyo3-log = "0.12.4"
 ahash = "0.8.12"
 rayon = "1.11.0"
 compact_str = "0.9.0"
+libloading = "0.8.5"
+once_cell = "1.19.0"
@@ -3,10 +3,12 @@ use std::collections::HashMap as StdHashMap;

 use dary_heap::OctonaryHeap;
 use fancy_regex::Regex;
+use libloading::Library;
+use once_cell::sync::OnceCell;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;

 use ahash::{AHashMap, AHashSet};
 use compact_str::CompactString;
 use rayon::prelude::*;

 // Default GPT-4 style regex pattern for splitting text

@@ -14,6 +16,79 @@ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{

 type Pair = (u32, u32);

+#[allow(non_camel_case_types)]
+type c_char = std::os::raw::c_char;
+
+#[repr(C)]
+struct CTokenPos {
+    start: usize,
+    end: usize,
+}
+
+#[repr(C)]
+struct CTokenList {
+    splits: *mut CTokenPos,
+    count: usize,
+    capacity: usize,
+}
+
+type FnTokenListInit = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenListFree = unsafe extern "C" fn(list: *mut CTokenList);
+type FnTokenizeFast = unsafe extern "C" fn(input: *const c_char, input_len: usize, out: *mut CTokenList);
+
+struct FregexSymbols {
+    _lib: Library,
+    tokenlist_init: FnTokenListInit,
+    tokenlist_free: FnTokenListFree,
+    tokenize_fast: FnTokenizeFast,
+}
+
+static FREGEX: OnceCell<FregexSymbols> = OnceCell::new();
+
+fn load_fregex() -> &'static FregexSymbols {
+    FREGEX.get_or_init(|| {
+        // NOTE: adjust this path per user if needed.
+        let path = "fregex/libfregex.dylib";
+        let lib = unsafe { Library::new(path) }.unwrap_or_else(|e| {
+            panic!("Failed to load libfregex from {}: {}", path, e);
+        });
+        unsafe {
+            let tokenlist_init: FnTokenListInit = *lib.get(b"tokenlist_init\0").expect("symbol tokenlist_init");
+            let tokenlist_free: FnTokenListFree = *lib.get(b"tokenlist_free\0").expect("symbol tokenlist_free");
+            let tokenize_fast: FnTokenizeFast = *lib.get(b"tokenize_fast\0").expect("symbol tokenize_fast");
+            println!("rustbpe: loaded libfregex.dylib from {}", path);
+            FregexSymbols { _lib: lib, tokenlist_init, tokenlist_free, tokenize_fast }
+        }
+    })
+}
+
+fn tokenize_with_c_each<'a, F>(input: &'a [u8], mut on_piece: F)
+where
+    F: FnMut(&'a [u8]),
+{
+    if input.is_empty() {
+        return;
+    }
+    let syms = load_fregex();
+    let mut out = CTokenList { splits: std::ptr::null_mut(), count: 0, capacity: 0 };
+    let base_ptr = input.as_ptr() as usize;
+    unsafe {
+        (syms.tokenlist_init)(&mut out as *mut CTokenList);
+        (syms.tokenize_fast)(input.as_ptr() as *const c_char, input.len(), &mut out as *mut CTokenList);
+        if !out.splits.is_null() {
+            let slice = std::slice::from_raw_parts(out.splits, out.count);
+            for pos in slice.iter() {
+                let start = pos.start.saturating_sub(base_ptr);
+                let end = pos.end.saturating_sub(base_ptr);
+                if end <= input.len() && start <= end {
+                    on_piece(&input[start..end]);
+                }
+            }
+        }
+        (syms.tokenlist_free)(&mut out as *mut CTokenList);
+    }
+}
+
 /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
 #[pyclass]
 pub struct Tokenizer {

@@ -295,8 +370,8 @@ impl Tokenizer {
             pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
         };

-        // Global chunk counts
-        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
+        // Global chunk counts: own bytes once per unique chunk (no string copies)
+        let mut counts: AHashMap<Vec<u8>, i32> = AHashMap::new();

         // Temporary buffer we refill under the GIL
         let mut buf: Vec<String> = Vec::with_capacity(buffer_size);

@@ -344,31 +419,28 @@

             total_sequences += buf.len() as u64;

-            let pattern = self.compiled_pattern.clone();
-            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
+            // Build per-string local counts that reference the buffer slices (no allocations)
+            let locals: Vec<Vec<(&[u8], i32)>> = py.allow_threads(|| {
                 buf.par_iter()
                     .map(|s| {
-                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
-                        for mat in pattern.find_iter(s) {
-                            let piece = mat.expect("regex match failed").as_str();
-                            *m.entry(CompactString::from(piece)).or_default() += 1;
-                        }
-                        m
+                        let mut m: AHashMap<&[u8], i32> = AHashMap::new();
+                        let bytes = s.as_bytes();
+                        tokenize_with_c_each(bytes, |piece| { *m.entry(piece).or_default() += 1; });
+                        // Materialize as Vec to allow merging after parallel section
+                        m.into_iter().collect::<Vec<(&[u8], i32)>>()
                     })
-                    .reduce(
-                        || AHashMap::new(),
-                        |mut a, b| {
-                            for (k, v) in b {
-                                *a.entry(k).or_default() += v;
-                            }
-                            a
-                        },
-                    )
+                    .collect()
             });

-            // Merge local into global (single-threaded)
-            for (k, v) in local {
-                *counts.entry(k).or_default() += v;
+            // Merge locals into global (single-threaded) without copying unless inserting new keys
+            for local in locals {
+                for (piece, v) in local {
+                    if let Some(cnt) = counts.get_mut(piece) {
+                        *cnt += v;
+                    } else {
+                        counts.insert(piece.to_vec(), v);
+                    }
+                }
             }

             if exhausted {
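The hunk above swaps the regex-based chunk counting for the C tokenizer and switches the map keys from CompactString to borrowed &[u8] slices, copying a chunk's bytes only when it is first inserted into the global map. Restated in Python for clarity (sequential, no rayon; str.split stands in for the tokenizer):

# --- illustrative sketch, not part of the diff ---
from collections import Counter

def merge_counts(chunks, tokenize):
    counts = Counter()
    for chunk in chunks:                    # a parallel .map() in the Rust version
        local = Counter(tokenize(chunk))    # per-chunk piece counts
        for piece, n in local.items():      # merged single-threaded afterwards
            counts[piece] += n
    return counts

print(merge_counts(["aa bb", "bb cc"], str.split))  # Counter({'bb': 2, 'aa': 1, 'cc': 1})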
@@ -380,8 +452,8 @@
         // Materialize words & counts
         let mut words = Vec::with_capacity(counts.len());
         let mut cvec = Vec::with_capacity(counts.len());
-        for (chunk, c) in counts.into_iter() {
-            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
+        for (chunk_bytes, c) in counts.into_iter() {
+            words.push(Word::new(chunk_bytes.iter().map(|&b| b as u32).collect()));
             cvec.push(c);
         }

@@ -467,9 +539,23 @@
     }
 }

+#[pyfunction]
+pub fn split_text(pattern: String, text: String) -> PyResult<Vec<String>> {
+    let re = Regex::new(&pattern)
+        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
+    let mut out: Vec<String> = Vec::new();
+    for m in re.find_iter(&text) {
+        let m = m
+            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Regex match failed: {}", e)))?;
+        out.push(m.as_str().to_string());
+    }
+    Ok(out)
+}
+
 #[pymodule]
 fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
     pyo3_log::init(); // forwards Rust `log` to Python's `logging`
     m.add_class::<Tokenizer>()?;
+    m.add_function(wrap_pyfunction!(split_text, m)?)?;
     Ok(())
 }
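The new split_text binding exposes the fancy-regex splitter directly to Python, which is what fregex/bench.py and fregex/rust_split.py import. A hypothetical smoke test (assumes the rustbpe extension has been rebuilt and nanochat is importable; the printed split is indicative of the GPT-4 style pattern, not an asserted result):

# --- illustrative sketch, not part of the diff ---
from nanochat.tokenizer import SPLIT_PATTERN
from rustbpe import split_text

pieces = split_text(SPLIT_PATTERN, "it's 2025")
print(pieces)  # expected along the lines of ['it', "'s", ' ', '20', '25']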