mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-07 21:02:15 +00:00
175 lines · 5.4 KiB · Python · Executable File
import sys
|
|
import ctypes
|
|
import random
|
|
import time
|
|
import statistics
|
|
import os
|
|
import gc
|
|
from pathlib import Path
|
|
|
|
from nanochat.tokenizer import SPLIT_PATTERN
|
|
|
|
# Pin every math/runtime thread pool to a single thread so benchmark
# timings are not skewed by library-level parallelism. Must run before
# the native extensions below are imported.
for _thread_var in (
    'OMP_NUM_THREADS',
    'OPENBLAS_NUM_THREADS',
    'MKL_NUM_THREADS',
    'VECLIB_MAXIMUM_THREADS',
    'NUMEXPR_NUM_THREADS',
    'RAYON_NUM_THREADS',
):
    os.environ[_thread_var] = '1'
|
|
|
|
# Raise scheduling priority (niceness -10) to reduce timing jitter.
# Negative niceness usually requires elevated privileges, and PRIO_PROCESS
# does not exist on all platforms — treat this as best-effort rather than
# letting the benchmark crash at import time.
try:
    os.setpriority(os.PRIO_PROCESS, 0, -10)
except (AttributeError, OSError):
    pass
|
|
|
|
from rustbpe import split_text as rust_split_text
|
|
from fregex.fuzz import gen_valid_unicode_string, compare_pair_text
|
|
from fregex.cload import *
|
|
|
|
# ctypes binding for CPython's PyBytes_AsString: yields the raw char*
# backing a bytes object, so the C tokenizer reads the data without a copy.
PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.argtypes = [ctypes.py_object]
PyBytes_AsString.restype = ctypes.c_void_p
|
|
|
|
def _run_once_c(data: bytes) -> float:
    """Tokenize `data` once with the C tokenizer; return elapsed time in ms.

    Only the tokenize_fast() call is timed — TokenList init/free is excluded
    from the measurement.
    """
    token_list = TokenList()
    c_lib.tokenlist_init(ctypes.byref(token_list))
    base_ptr = PyBytes_AsString(data)  # zero-copy pointer into the bytes object
    try:
        t0 = time.perf_counter_ns()
        c_lib.tokenize_fast(base_ptr, len(data), ctypes.byref(token_list))
        dt_ms = (time.perf_counter_ns() - t0) / 1e6
    finally:
        # Free the C-side buffer even if tokenization raises, so thousands of
        # benchmark iterations cannot leak native memory.
        c_lib.tokenlist_free(ctypes.byref(token_list))
    return dt_ms
|
|
|
|
def _run_once_rust(text: str) -> float:
    """Run one rust_split_text() call over `text`; return elapsed time in ms."""
    start_ns = time.perf_counter_ns()
    rust_split_text(SPLIT_PATTERN, text)
    elapsed_ns = time.perf_counter_ns() - start_ns
    return elapsed_ns / 1e6
|
|
|
|
def stats_summary(times: list) -> dict:
    """Compute summary statistics for a list of timings (ms).

    Returns an empty dict for empty input. stdev is reported as 0 for a
    single sample, since statistics.stdev needs at least two data points.
    """
    # `not times` already covers the empty case; the original
    # `len(times) == 0` clause was redundant.
    if not times:
        return {}

    return {
        'min': min(times),
        'max': max(times),
        'mean': statistics.mean(times),
        'median': statistics.median(times),
        'stdev': statistics.stdev(times) if len(times) > 1 else 0,
    }
|
|
|
|
def format_stats(name: str, data_size: int, times: list) -> str:
    """Format timing statistics as one aligned, newline-terminated row.

    Returns a placeholder row ("--") when no timings were collected.
    """
    # `not times` already covers the empty case; the original
    # `len(times) == 0` clause was redundant.
    if not times:
        return f"{name:20} {data_size:>10} B --\n"

    stats = stats_summary(times)

    return (f"{name:20} {data_size:>10} B "
            f"min={stats['min']:.3f}ms max={stats['max']:.3f}ms "
            f"mean={stats['mean']:.3f}ms median={stats['median']:.3f}ms "
            f"stdev={stats['stdev']:.3f}ms\n")
|
|
|
|
def benchmark_dataset(name: str, data_bytes: bytes, iterations: int) -> None:
    """Benchmark the C tokenizer against the Rust splitter on one dataset,
    print per-implementation timing stats, then cross-check that the C and
    Python regex splits agree.

    Raises SystemExit(1) if the C and Python splits mismatch.
    """
    # The Rust splitter takes str; replace undecodable bytes rather than raise.
    test_text = data_bytes.decode('utf-8', errors='replace')

    print(f"\n--- Dataset: {name} ({len(data_bytes)} bytes, {iterations} iterations) ---")
    print()

    # Pre-touch data to avoid first-touch/page-fault skew
    if data_bytes:
        _ = data_bytes[0]
        # step of 4096 touches one byte per (typical) memory page
        for i in range(0, len(data_bytes), 4096):
            _ = data_bytes[i]

    # Warm-up: untimed iterations so caches/JIT-like effects settle
    for _ in range(20):
        _run_once_c(data_bytes)
        _run_once_rust(test_text)

    # Disable GC during timed section; restore original state afterwards
    gc_was_enabled = gc.isenabled()
    if gc_was_enabled:
        gc.disable()

    c_times = []
    rust_times = []
    for _ in range(iterations):
        c_times.append(_run_once_c(data_bytes))
        rust_times.append(_run_once_rust(test_text))

    if gc_was_enabled:
        gc.enable()

    print(format_stats("C tokenizer", len(data_bytes), c_times), end='')
    print(format_stats("Rust split", len(data_bytes), rust_times), end='')

    if c_times and rust_times:
        c_mean = statistics.mean(c_times)
        rust_mean = statistics.mean(rust_times)
        # ratio > 1 means the Rust call averaged slower, i.e. C is faster
        ratio = rust_mean / c_mean
        speedup = "C is faster" if ratio > 1 else "Rust is faster"
        print(f"Speedup: {ratio:.2f}x ({speedup})")

    print()

    # Verify token splits match between C and Python regex tokenizer
    # (surrogatepass so both sides receive the same text even for
    # surrogate byte sequences — NOTE(review): assumes compare_pair_text
    # tolerates surrogates; confirm against fregex.fuzz)
    cmp_text = data_bytes.decode('utf-8', errors='surrogatepass')
    ok, err, out_c, out_py = compare_pair_text(cmp_text)
    if ok:
        print("Compare: OK (C vs Py splits match)")
    else:
        print("Compare: MISMATCH (C vs Py)")
        if err:
            print(err)
        if out_c is not None and out_py is not None:
            c_lines = out_c.splitlines()
            p_lines = out_py.splitlines()
            print(f"C tokens: {len(c_lines)} | Py tokens: {len(p_lines)}")
            print("--- C (head) ---")
            print("\n".join(c_lines[:10]))
            print("--- Py (head) ---")
            print("\n".join(p_lines[:10]))
        # Stop the benchmark if mismatch detected
        raise SystemExit(1)
|
|
|
|
def main():
    """Benchmark files given on the command line, or built-in randomly
    generated datasets when no arguments are provided."""
    # Check if files were provided as arguments
    file_args = sys.argv[1:] if len(sys.argv) > 1 else []

    # If files provided, benchmark them
    if file_args:
        for file_path in file_args:
            path = Path(file_path)
            if not path.exists():
                print(f"❌ File not found: {file_path}")
                continue

            try:
                data = path.read_bytes()
                benchmark_dataset(path.name, data, 1_000)
            except Exception as e:
                print(f"❌ Error reading {file_path}: {e}")
    else:
        # Run random generated data
        # (name, target size in bytes, timed iterations)
        configs = [
            ("tiny", 100, 1000),
            ("small", 1024, 500),
            ("medium", 10 * 1024, 100),
            ("large", 100 * 1024, 100),
            ("xlarge", 1024 * 1024, 100),
        ]

        for name, size_bytes, iterations in configs:
            # Seed deterministically from the dataset name. The original used
            # hash(name), but str hashing is randomized per process
            # (PYTHONHASHSEED), so runs were not reproducible; derive a
            # stable seed from the name's bytes instead.
            seed = int.from_bytes(name.encode('utf-8'), 'little')
            test_text = gen_valid_unicode_string(
                random.Random(seed),
                size_bytes
            )
            test_bytes = test_text.encode('utf-8')

            benchmark_dataset(name, test_bytes, iterations)

    print("=" * 140)
|
|
|
|
# Entry point guard: run the benchmarks only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|