mirror of https://github.com/karpathy/nanochat.git
synced 2026-04-04 14:45:25 +00:00

commit 851810c7d5 (parent 41c8b8dbde)

    remove string allocations
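In short: the C tokenizer previously `xmemdup`'d every token into its own heap allocation (a `char **tokens` / `size_t *lengths` pair of arrays); after this commit `TokenList` only records `(start, end)` spans into the caller's input buffer, and the Python side materializes tokens by slicing the original `bytes` object. A minimal sketch of that contract (illustrative only, not repo code):

    # Illustrative sketch of the span contract.
    data = b"Hello world"

    # Before: the C side copied every token (one malloc + memcpy each,
    # via xmemdup), so each token owned its own buffer.
    copied = [data[0:5], data[5:11]]

    # After: the C side only records (start, end) spans into data's buffer;
    # the caller slices the original bytes lazily, when a token is needed.
    spans = [(0, 5), (5, 11)]
    tokens = [data[s:e] for s, e in spans]
    assert tokens == copied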
@@ -3,37 +3,45 @@ import ctypes
 import random
 import time
 import statistics
 import os
+import gc
 from pathlib import Path
 
 from nanochat.tokenizer import SPLIT_PATTERN
 
+os.environ.update({
+    'OMP_NUM_THREADS': '1',
+    'OPENBLAS_NUM_THREADS': '1',
+    'MKL_NUM_THREADS': '1',
+    'VECLIB_MAXIMUM_THREADS': '1',
+    'NUMEXPR_NUM_THREADS': '1',
+    'RAYON_NUM_THREADS': '1',
+})
+
+os.setpriority(os.PRIO_PROCESS, 0, -10)
+
 from rustbpe import split_text as rust_split_text
 from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
 from fregex.cload import *
 
-def bench_c_regex(data: bytes, iterations: int) -> list:
-    times = []
-    for _ in range(iterations):
-        token_list = TokenList()
-        c_lib.tokenlist_init(ctypes.byref(token_list))
-
-        start = time.perf_counter()
-        c_lib.tokenize_fast(data, len(data), ctypes.byref(token_list))
-        elapsed = time.perf_counter() - start
-
-        c_lib.tokenlist_free(ctypes.byref(token_list))
-        times.append(elapsed * 1000)
-
-    return times
+PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
+PyBytes_AsString.restype = ctypes.c_void_p
+PyBytes_AsString.argtypes = [ctypes.py_object]
 
-def bench_rust_regex(text: str, iterations: int) -> list:
-    times = []
-    for _ in range(iterations):
-        start = time.perf_counter()
-        rust_split_text(SPLIT_PATTERN, text)
-        elapsed = time.perf_counter() - start
-        times.append(elapsed * 1000)
-
-    return times
+def _run_once_c(data: bytes) -> float:
+    token_list = TokenList()
+    c_lib.tokenlist_init(ctypes.byref(token_list))
+    base_ptr = PyBytes_AsString(data)
+    t0 = time.perf_counter_ns()
+    c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list))
+    dt_ms = (time.perf_counter_ns() - t0) / 1e6
+    c_lib.tokenlist_free(ctypes.byref(token_list))
+    return dt_ms
+
+def _run_once_rust(text: str) -> float:
+    t0 = time.perf_counter_ns()
+    rust_split_text(SPLIT_PATTERN, text)
+    return (time.perf_counter_ns() - t0) / 1e6
+
 def stats_summary(times: list) -> dict:
     """Compute statistics from timing list."""
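Note on `_run_once_c`: it hoists the bytes-to-pointer conversion out of the timed region via `ctypes.pythonapi.PyBytes_AsString`, so the stopwatch brackets only the C call. An equivalent address can be obtained without touching `ctypes.pythonapi` (a sketch; the variable names are mine, not the repo's):

    import ctypes

    data = b"example input"
    # c_char_p built from bytes points at the bytes object's internal
    # buffer and keeps `data` referenced for its own lifetime.
    buf = ctypes.c_char_p(data)
    addr = ctypes.cast(buf, ctypes.c_void_p).value  # integer address,
    # same thing PyBytes_AsString returns with restype=c_void_p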
@@ -65,11 +73,33 @@ def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
 
     print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
     print()
 
-    c_times = bench_c_regex(data_bytes, iterations)
+    # Pre-touch data to avoid first-touch/page-fault skew
+    if data_bytes:
+        _ = data_bytes[0]
+    for i in range(0, len(data_bytes), 4096):
+        _ = data_bytes[i]
+
+    # Warm-up
+    for _ in range(20):
+        _run_once_c(data_bytes)
+        _run_once_rust(test_text)
+
+    # Disable GC during timed section
+    gc_was_enabled = gc.isenabled()
+    if gc_was_enabled:
+        gc.disable()
+
+    c_times = []
+    rust_times = []
+    for _ in range(iterations):
+        c_times.append(_run_once_c(data_bytes))
+        rust_times.append(_run_once_rust(test_text))
+
+    if gc_was_enabled:
+        gc.enable()
+
     print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
-
-    rust_times = bench_rust_regex(test_text, iterations)
     print(format_stats("Rust split", len(data_bytes), rust_times), end='')
 
     if c_times and rust_times:
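The timing discipline added here (pre-touch pages, warm up, pause GC around the hot loop, restore prior state) generalizes beyond this file. A reusable guard for the GC part, as a sketch — `gc_paused` is a hypothetical helper, not in the repo:

    import gc
    from contextlib import contextmanager

    @contextmanager
    def gc_paused():
        # Disable the collector only for the timed section,
        # restoring whatever state the caller had.
        was_enabled = gc.isenabled()
        gc.disable()
        try:
            yield
        finally:
            if was_enabled:
                gc.enable()

    # Usage mirroring benchmark_dataset's timed loop:
    # with gc_paused():
    #     for _ in range(iterations):
    #         c_times.append(_run_once_c(data_bytes))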
@@ -115,7 +145,7 @@ def main():
 
     try:
         data = path.read_bytes()
-        benchmark_dataset(path.name, data, 10_000)
+        benchmark_dataset(path.name, data, 1_000)
     except Exception as e:
         print(f"❌ Error reading {file_path}: {e}")
     else:
@@ -124,8 +154,8 @@ def main():
         ("tiny", 100, 1000),
         ("small", 1024, 500),
         ("medium", 10 * 1024, 100),
-        ("large", 100 * 1024, 30),
-        ("xlarge", 1024 * 1024, 10),
+        ("large", 100 * 1024, 100),
+        ("xlarge", 1024 * 1024, 100),
     ]
 
     for name, size_bytes, iterations in configs:
@@ -140,6 +170,5 @@ def main():
 
     print("=" * 140)
 
-
 if __name__ == "__main__":
     main()
@@ -2,19 +2,50 @@ import ctypes
 
 c_lib = ctypes.CDLL("fregex/libfregex.dylib")
 
-class TokenList(ctypes.Structure):
-    pass
-
-TokenList._fields_ = [
-    ("tokens", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))),
-    ("lengths", ctypes.POINTER(ctypes.c_size_t)),
-    ("count", ctypes.c_size_t),
-    ("capacity", ctypes.c_size_t),
-]
+class TokenPos(ctypes.Structure):
+    _fields_ = [
+        ("start", ctypes.c_size_t),
+        ("end", ctypes.c_size_t),
+    ]
+
+
+class TokenList(ctypes.Structure):
+    _fields_ = [
+        ("splits", ctypes.POINTER(TokenPos)),
+        ("count", ctypes.c_size_t),
+        ("capacity", ctypes.c_size_t),
+    ]
+
 
 c_lib.tokenlist_init.argtypes = [ctypes.POINTER(TokenList)]
 c_lib.tokenlist_init.restype = None
 c_lib.tokenlist_free.argtypes = [ctypes.POINTER(TokenList)]
 c_lib.tokenlist_free.restype = None
-c_lib.tokenize_fast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
-c_lib.tokenize_fast.restype = None
+# Accept a raw pointer to the input buffer rather than a Python bytes object
+c_lib.tokenize_fast.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(TokenList)]
+c_lib.tokenize_fast.restype = None
 
+def tokenize_c_bytes(data: bytes) -> list[bytes]:
+    # Use a C char* view of the original bytes; offsets computed from this base
+    c_data = ctypes.c_char_p(data)
+    tl = TokenList()
+    c_lib.tokenlist_init(ctypes.byref(tl))
+    try:
+        base_addr = ctypes.cast(c_data, ctypes.c_void_p).value
+        # Pass the same pointer to C
+        c_lib.tokenize_fast(ctypes.cast(c_data, ctypes.c_void_p), len(data), ctypes.byref(tl))
+        out: list[bytes] = []
+        count = int(tl.count)
+        for i in range(count):
+            start_addr = int(tl.splits[i].start)
+            end_addr = int(tl.splits[i].end)
+            # Compute offsets into our local buffer
+            off_start = start_addr - base_addr
+            off_end = end_addr - base_addr
+            if off_start < 0 or off_end < off_start or off_end > len(data):
+                raise RuntimeError(f"Invalid span [{start_addr}:{end_addr}] for buffer base {base_addr}")
+            out.append(data[off_start:off_end])
+        return out
+    finally:
+        c_lib.tokenlist_free(ctypes.byref(tl))
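Worth noting for the new `tokenize_c_bytes`: the spans reference the internal buffer of `data`, so they are only meaningful while `data` is alive, and the bounds check rejects any pointer the C side might hand back outside that buffer. Since the split pattern should partition its input into consecutive pieces, a round-trip property ought to hold — a hypothetical smoke test, assuming the library is built at `fregex/libfregex.dylib` as loaded above:

    sample = b"Hello, world! 123"
    parts = tokenize_c_bytes(sample)
    # Pre-tokenization splits should tile the input exactly.
    assert b"".join(parts) == sample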
@@ -33,21 +33,6 @@ def escape_bytes(b: bytes) -> str:
 def dump_tokens(tokens: list[bytes]) -> str:
     return "\n".join(f"{len(b)}\t{escape_bytes(b)}" for b in tokens)
 
-def tokenize_c_bytes(data: bytes) -> list[bytes]:
-    tl = TokenList()
-    c_lib.tokenlist_init(ctypes.byref(tl))
-    try:
-        c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
-        out: list[bytes] = []
-        count = int(tl.count)
-        for i in range(count):
-            ptr = tl.tokens[i]
-            ln = int(tl.lengths[i])
-            out.append(ctypes.string_at(ptr, ln))
-        return out
-    finally:
-        c_lib.tokenlist_free(ctypes.byref(tl))
-
 def tokenize_py_bytes(data: bytes) -> list[bytes]:
     text = data.decode('utf-8', errors='surrogatepass')
     toks = py_tokenize_str(text)
@@ -1,4 +1,4 @@
-#include "fregex.h"
+#include "fregex-2.h"
 
 #include <stdlib.h>
 #include <string.h>
@@ -136,34 +136,29 @@ static char *xmemdup(const char *src, size_t len) {
 }
 
 void tokenlist_init(TokenList *list) {
-    list->tokens = NULL;
-    list->lengths = NULL;
+    list->splits = NULL;
     list->count = 0;
     list->capacity = 0;
 }
 
 void tokenlist_free(TokenList *list) {
-    if (!list)
+    if (!list)
         return;
-    for (size_t i = 0; i < list->count; ++i)
-        free(list->tokens[i]);
-    free(list->tokens);
-    free(list->lengths);
-    list->tokens = NULL;
-    list->lengths = NULL;
-    list->count = 0;
+    free(list->splits);
+    list->splits = NULL;
+    list->count = 0;
     list->capacity = 0;
 }
 
 static void tokenlist_push(TokenList *list, const char *start, size_t len) {
     if (list->count == list->capacity) {
-        const size_t new_cap = list->capacity ? (list->capacity * 2) : 64;
-        list->tokens = (char**)xrealloc(list->tokens, new_cap * sizeof(char*));
-        list->lengths = (size_t*)xrealloc(list->lengths, new_cap * sizeof(size_t));
+        const size_t new_cap = list->capacity ? (list->capacity * 2) : 128;
+        list->splits = (TokenPos*)xrealloc(list->splits, new_cap * sizeof(TokenPos));
         list->capacity = new_cap;
     }
-    list->tokens[list->count] = xmemdup(start, len);
-    list->lengths[list->count] = len;
+    /* Write the start / end position of string */
+    list->splits[list->count].start = (size_t)start;
+    list->splits[list->count].end = (size_t)(start + len); // len - 1 ?
    list->count++;
 }
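On the `// len - 1 ?` question left in `tokenlist_push`: `start + len` (an exclusive end) is the right choice. The Python side recovers each token as `data[off_start:off_end]`, and Python slices are exclusive at the right edge, so `end - start` equals the token length and an empty span stays representable. In Python terms:

    s = b"abcdef"
    start, length = 2, 3
    end = start + length            # exclusive end, matching (size_t)(start + len)
    assert s[start:end] == b"cde"
    assert end - start == length    # length falls straight out of the arithmetic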
@@ -185,20 +180,6 @@ static void fput_escaped_char(unsigned char c, FILE *out) {
     }
 }
 
-void print_token_escaped(const char *s, size_t len, FILE *out) {
-    fprintf(out, "%zu\t", len);
-    for (size_t i = 0; i < len; ++i) fput_escaped_char((unsigned char)s[i], out);
-    fputc('\n', out);
-}
-
-void print_tokens_escaped(const TokenList *list, FILE *out) {
-    for (size_t i = 0; i < list->count; ++i) {
-        const char *tok = list->tokens[i];
-        size_t len = list->lengths[i];
-        print_token_escaped(tok, len, out);
-    }
-}
-
 /* A) '(?i:[sdmt]|ll|ve|re) */
 static size_t match_contraction(const char *p, const char *end) {
     if (p >= end || *p != '\'' || (p + 1) >= end)
@@ -470,7 +451,7 @@ static size_t match_ws_run(const char *p, const char *end) {
 
 void tokenize_fast(const char *input, size_t input_len, TokenList *out) {
     if (!input) {
-        out->tokens = NULL;
+        out->splits = NULL;
         out->count = 0;
         out->capacity = 0;
         return;
@@ -5,19 +5,17 @@
 #include <stdio.h>
 
 typedef struct {
-    char **tokens;
-    size_t *lengths;
+    size_t start, end;
+} TokenPos;
+
+typedef struct {
+    TokenPos *splits;
     size_t count;
     size_t capacity;
 } TokenList;
 
 void tokenlist_init(TokenList *list);
 void tokenlist_free(TokenList *list);
 
 void tokenize_fast(const char *input, size_t input_len, TokenList *out);
-void print_token_escaped(const char *s, size_t len, FILE *out);
-void print_tokens_escaped(const TokenList *list, FILE *out);
 
 #endif // FAST_REGEX_H
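The ctypes structures in the Python loader have to stay layout-compatible with this header. A cheap guard, as a sketch (the asserts are mine, not in the repo; two `size_t` fields pack without padding on mainstream ABIs):

    import ctypes

    class TokenPos(ctypes.Structure):
        _fields_ = [("start", ctypes.c_size_t), ("end", ctypes.c_size_t)]

    class TokenList(ctypes.Structure):
        _fields_ = [("splits", ctypes.POINTER(TokenPos)),
                    ("count", ctypes.c_size_t),
                    ("capacity", ctypes.c_size_t)]

    # Mirror must match the C struct byte-for-byte for byref() calls to work.
    assert ctypes.sizeof(TokenPos) == 2 * ctypes.sizeof(ctypes.c_size_t)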
@@ -242,21 +242,6 @@ def _format_tokens_dump(tokens: list[bytes]) -> str:
         lines.append(f"{len(b)}\t{escape_bytes(b)}")
     return "\n".join(lines)
 
-def tokenize_c_bytes(data: bytes) -> list[bytes]:
-    tl = TokenList()
-    c_lib.tokenlist_init(ctypes.byref(tl))
-    try:
-        c_lib.tokenize_fast(data, len(data), ctypes.byref(tl))
-        out: list[bytes] = []
-        count = int(tl.count)
-        for i in range(count):
-            ptr = tl.tokens[i]
-            ln = int(tl.lengths[i])
-            out.append(ctypes.string_at(ptr, ln))
-        return out
-    finally:
-        c_lib.tokenlist_free(ctypes.byref(tl))
-
 def tokenize_py_bytes(data: bytes) -> list[bytes]:
     if py_tokenize_str is None:
         raise RuntimeError("py_tokenizer not available")