remove string allocations

This commit is contained in:
MadMax129 2025-10-24 17:06:06 -04:00
parent 41c8b8dbde
commit 851810c7d5
6 changed files with 118 additions and 109 deletions

View File

@ -3,37 +3,45 @@ import ctypes
import random
import time
import statistics
import os
import gc
from pathlib import Path
from nanochat.tokenizer import SPLIT_PATTERN
# Pin every threaded math/runtime backend to a single thread so the
# benchmark measures single-core tokenizer speed, not thread-pool noise.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'VECLIB_MAXIMUM_THREADS': '1',
    'NUMEXPR_NUM_THREADS': '1',
    'RAYON_NUM_THREADS': '1',
})
# Raise scheduling priority (nice -10) to reduce timing jitter.
# NOTE(review): this typically requires elevated privileges — confirm it
# doesn't raise PermissionError on the target machine.
os.setpriority(os.PRIO_PROCESS, 0, -10)
from rustbpe import split_text as rust_split_text
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
from fregex.cload import *
def bench_c_regex(data: bytes, iterations: int) -> list:
    """Time `iterations` runs of the C tokenizer over `data`.

    Only the tokenize_fast call is timed; TokenList setup/teardown happens
    outside the timed window. Returns per-run wall times in milliseconds.
    """
    samples = []
    for _ in range(iterations):
        tl = TokenList()
        c_lib.tokenlist_init(ctypes.byref(tl))
        t0 = time.perf_counter()
        c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
        dt = time.perf_counter() - t0
        c_lib.tokenlist_free(ctypes.byref(tl))
        samples.append(dt * 1000)
    return samples
# Bind the CPython C-API function PyBytes_AsString so we can obtain the raw
# buffer address of a `bytes` object without copying it. The returned
# pointer is only valid while the bytes object is alive.
PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.restype = ctypes.c_void_p
PyBytes_AsString.argtypes = [ctypes.py_object]
def bench_rust_regex(text: str, iterations: int) -> list:
    """Time `iterations` runs of the Rust splitter over `text`.

    Returns per-run wall times in milliseconds.
    """
    samples = []
    for _ in range(iterations):
        begin = time.perf_counter()
        rust_split_text(SPLIT_PATTERN, text)
        samples.append((time.perf_counter() - begin) * 1000)
    return samples
def _run_once_c(data: bytes) -> float:
    """One timed pass of the C tokenizer over `data`; returns milliseconds.

    The input's raw buffer pointer is fetched via PyBytes_AsString before
    the clock starts, so pointer extraction is excluded from the timing.
    """
    tl = TokenList()
    c_lib.tokenlist_init(ctypes.byref(tl))
    buf_ptr = PyBytes_AsString(data)
    begin_ns = time.perf_counter_ns()
    c_lib.tokenize_fast(buf_ptr, len(data), ctypes.byref(tl))
    elapsed_ms = (time.perf_counter_ns() - begin_ns) / 1e6
    c_lib.tokenlist_free(ctypes.byref(tl))
    return elapsed_ms
def _run_once_rust(text: str) -> float:
    """One timed pass of the Rust splitter over `text`; returns milliseconds."""
    begin_ns = time.perf_counter_ns()
    rust_split_text(SPLIT_PATTERN, text)
    elapsed_ns = time.perf_counter_ns() - begin_ns
    return elapsed_ns / 1e6
def stats_summary(times: list) -> dict:
"""Compute statistics from timing list."""
@ -65,11 +73,33 @@ def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
print()
c_times = bench_c_regex(data_bytes, iterations)
# Pre-touch data to avoid first-touch/page-fault skew
if data_bytes:
_ = data_bytes[0]
for i in range(0, len(data_bytes), 4096):
_ = data_bytes[i]
# Warm-up
for _ in range(20):
_run_once_c(data_bytes)
_run_once_rust(test_text)
# Disable GC during timed section
gc_was_enabled = gc.isenabled()
if gc_was_enabled:
gc.disable()
c_times = []
rust_times = []
for _ in range(iterations):
c_times.append(_run_once_c(data_bytes))
rust_times.append(_run_once_rust(test_text))
if gc_was_enabled:
gc.enable()
print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
rust_times = bench_rust_regex(test_text, iterations)
print(format_stats("Rust split", len(data_bytes), rust_times), end='')
if c_times and rust_times:
@ -115,7 +145,7 @@ def main():
try:
data = path.read_bytes()
benchmark_dataset(path.name, data, 10_000)
benchmark_dataset(path.name, data, 1_000)
except Exception as e:
print(f"❌ Error reading {file_path}: {e}")
else:
@ -124,8 +154,8 @@ def main():
("tiny", 100, 1000),
("small", 1024, 500),
("medium", 10 * 1024, 100),
("large", 100 * 1024, 30),
("xlarge", 1024 * 1024, 10),
("large", 100 * 1024, 100),
("xlarge", 1024 * 1024, 100),
]
for name, size_bytes, iterations in configs:
@ -140,6 +170,5 @@ def main():
print("=" * 140)
if __name__ == "__main__":
main()

View File

@ -2,19 +2,50 @@ import ctypes
# Load the compiled tokenizer shared library (macOS dylib path).
c_lib = ctypes.CDLL("fregex/libfregex.dylib")

class TokenList(ctypes.Structure):
    # Forward declaration: `_fields_` is attached after the class statement,
    # a standard ctypes pattern for structs referenced before completion.
    pass

# Legacy layout: one heap-allocated copy per token plus a parallel lengths
# array. Superseded later in this module by a span-based TokenList.
TokenList._fields_ = [
    ("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))),
    ("lengths", ctypes.POINTER(ctypes.c_size_t)),
    ("count", ctypes.c_size_t),
    ("capacity", ctypes.c_size_t),
]
class TokenPos(ctypes.Structure):
    # One token span. Based on how tokenize_c_bytes consumes these values
    # (subtracting the buffer base address), `start` and `end` hold raw
    # addresses into the input buffer, forming a half-open [start, end) span.
    _fields_ = [
        ("start", ctypes.c_size_t),
        ("end", ctypes.c_size_t),
    ]
class TokenList(ctypes.Structure):
    # Span-based result list: a growable array of TokenPos spans instead of
    # per-token string copies. Supersedes the earlier TokenList definition
    # in this module. Field order must match the C struct exactly.
    _fields_ = [
        ("splits", ctypes.POINTER(TokenPos)),
        ("count", ctypes.c_size_t),
        ("capacity", ctypes.c_size_t),
    ]
# Declare C function prototypes so ctypes marshals arguments correctly.
c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_init.restype = None
c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
c_lib.tokenlist_free.restype = None
# NOTE(review): tokenize_fast.argtypes is assigned twice; the second
# (c_void_p) assignment below is the one in effect. The c_char_p version
# looks like a leftover — consider deleting it.
c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None
# Accept a raw pointer to the input buffer rather than a Python bytes object
c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
c_lib.tokenize_fast.restype = None
def tokenize_c_bytes(data: bytes) -> list[bytes]:
    """Tokenize `data` with the C library and return the tokens as bytes.

    The C side records each token as a [start, end) pair of raw addresses
    into our buffer; we translate them back to offsets relative to the
    buffer base and slice the original `data`, so no C-side copies are made.

    Raises RuntimeError if the C library reports a span outside the buffer.
    """
    buf = ctypes.c_char_p(data)
    base_ptr = ctypes.cast(buf, ctypes.c_void_p)
    base_addr = base_ptr.value
    tl = TokenList()
    c_lib.tokenlist_init(ctypes.byref(tl))
    try:
        # Hand C the same pointer we derived offsets from.
        c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(tl))
        n = len(data)
        tokens: list[bytes] = []
        for i in range(int(tl.count)):
            start_addr = int(tl.splits[i].start)
            end_addr = int(tl.splits[i].end)
            # Convert absolute addresses to offsets into `data`.
            off_start = start_addr - base_addr
            off_end = end_addr - base_addr
            if off_start < 0 or off_end < off_start or off_end > n:
                raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}")
            tokens.append(data[off_start:off_end])
        return tokens
    finally:
        c_lib.tokenlist_free(ctypes.byref(tl))

View File

@ -33,21 +33,6 @@ def escape_bytes(b: bytes) -> str:
def dump_tokens(tokens: list[bytes]) -> str:
    """Render one `<length>\\t<escaped-bytes>` line per token."""
    lines = [f"{len(b)}\t{escape_bytes(b)}" for b in tokens]
    return "\n".join(lines)
def tokenize_c_bytes(data: bytes) -> list[bytes]:
    """Run the C tokenizer on `data` and copy each token out as bytes.

    The TokenList is always released, even if span extraction raises.
    """
    tl = TokenList()
    c_lib.tokenlist_init(ctypes.byref(tl))
    try:
        c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
        # string_at copies `lengths[i]` bytes starting at tokens[i].
        return [
            ctypes.string_at(tl.tokens[i], int(tl.lengths[i]))
            for i in range(int(tl.count))
        ]
    finally:
        c_lib.tokenlist_free(ctypes.byref(tl))
def tokenize_py_bytes(data: bytes) -> list[bytes]:
text = data.decode('utf-8', errors='surrogatepass')
toks = py_tokenize_str(text)

View File

@ -1,4 +1,4 @@
#include "fregex.h"
#include "fregex-2.h"
#include <stdlib.h>
#include <string.h>
@ -136,34 +136,29 @@ static char *xmemdup(const char *src, size_t len) {
}
void tokenlist_init(TokenList *list) {
list->tokens = NULL;
list->lengths = NULL;
list->splits = NULL;
list->count = 0;
list->capacity = 0;
}
void tokenlist_free(TokenList *list) {
if (!list)
if (!list)
return;
for (size_t i = 0; i < list->count; ++i)
free(list->tokens[i]);
free(list->tokens);
free(list->lengths);
list->tokens = NULL;
list->lengths = NULL;
list->count = 0;
free(list->splits);
list->splits = NULL;
list->count = 0;
list->capacity = 0;
}
static void tokenlist_push(TokenList *list, const char *start, size_t len) {
if (list->count == list->capacity) {
const size_t new_cap = list->capacity ? (list->capacity * 2) : 64;
list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*));
list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t));
const size_t new_cap = list->capacity ? (list->capacity * 2) : 128;
list->splits = (TokenPos*)xrealloc(list->splits, new_cap * sizeof(TokenPos));
list->capacity = new_cap;
}
list->tokens[list->count] = xmemdup(start, len);
list->lengths[list->count] = len;
/* Write the start / end position of string */
list->splits[list->count].start = (size_t)start;
list->splits[list->count].end = (size_t)(start + len); // len - 1 ?
list->count++;
}
@ -185,20 +180,6 @@ static void fput_escaped_char(unsigned char c, FILE *out) {
}
}
void print_token_escaped(const char *s, size_t len, FILE *out) {
fprintf(out, "%zu\t", len);
for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out);
fputc('\n', out);
}
void print_tokens_escaped(const TokenList *list, FILE *out) {
for (size_t i = 0; i < list->count; ++i) {
const char *tok = list->tokens[i];
size_t len = list->lengths[i];
print_token_escaped(tok, len, out);
}
}
/* A) '(?i:[sdmt]|ll|ve|re) */
static size_t match_contraction(const char *p, const char *end) {
if (p >= end || *p != '\'' || (p + 1) >= end)
@ -470,7 +451,7 @@ static size_t match_ws_run(const char *p, const char *end) {
void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
if (!input) {
out->tokens = NULL;
out->splits = NULL;
out->count = 0;
out->capacity = 0;
return;

View File

@ -5,19 +5,17 @@
#include <stdio.h>
typedef struct {
char **tokens;
size_t *lengths;
size_t start, end;
} TokenPos;
typedef struct {
TokenPos *splits;
size_t count;
size_t capacity;
} TokenList;
void tokenlist_init(TokenList *list);
void tokenlist_free(TokenList *list);
void tokenize_fast(const char *input, size_t input_len, TokenList *out);
void print_token_escaped(const char *s, size_t len, FILE *out);
void print_tokens_escaped(const TokenList *list, FILE *out);
#endif // FAST_REGEX_H

View File

@ -242,21 +242,6 @@ def _format_tokens_dump(tokens: list[bytes]) -> str:
lines.append(f"{len(b)}\t{escape_bytes(b)}")
return "\n".join(lines)
def tokenize_c_bytes(data: bytes) -> list[bytes]:
    """Tokenize `data` via the C library, returning copied byte strings."""
    tl = TokenList()
    c_lib.tokenlist_init(ctypes.byref(tl))
    try:
        c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
        result: list[bytes] = []
        total = int(tl.count)
        idx = 0
        while idx < total:
            # Copy the idx-th token out of C-owned memory.
            result.append(ctypes.string_at(tl.tokens[idx], int(tl.lengths[idx])))
            idx += 1
        return result
    finally:
        c_lib.tokenlist_free(ctypes.byref(tl))
def tokenize_py_bytes(data: bytes) -> list[bytes]:
if py_tokenize_str is None:
raise RuntimeError("py_tokenizer not available")