mirror of https://github.com/karpathy/nanochat.git
synced 2026-02-09 03:59:52 +00:00
Merge branch 'karpathy:master' into master
commit 7cdb05856d
@@ -113,12 +113,24 @@ def print_banner():
    """
    print0(banner)

def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1
def is_ddp_requested() -> bool:
    """
    True if launched by torchrun (env present), even before init.
    Used to decide whether we *should* initialize a PG.
    """
    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))

def is_ddp_initialized() -> bool:
    """
    True if torch.distributed is available and the process group is initialized.
    Used at cleanup to avoid destroying a non-existent PG.
    """
    return dist.is_available() and dist.is_initialized()

def get_dist_info():
    if is_ddp():
    if is_ddp_requested():
        # We rely on torchrun's env to decide if we SHOULD init.
        # (Initialization itself happens in compute init.)
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])

@@ -158,11 +170,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps

    # Precision
    if device_type == "cuda":
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
        torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and device_type == "cuda":
    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if is_ddp_requested and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)

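The hunk above swaps torch.set_float32_matmul_precision("high") for the newer torch.backends.cuda.matmul.fp32_precision = "tf32" knob. A minimal sketch of a version-tolerant way to set this, assuming a script has to run on both older and newer PyTorch builds (the attribute check is my assumption, not something this commit does):

import torch

if hasattr(torch.backends.cuda.matmul, "fp32_precision"):
    # newer PyTorch: per-backend precision attribute
    torch.backends.cuda.matmul.fp32_precision = "tf32"
else:
    # older PyTorch: global matmul precision hint
    torch.set_float32_matmul_precision("high")
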
@@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
    if is_ddp_initialized():
        dist.destroy_process_group()

class DummyWandb:

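Taken together, the preceding common-module hunks split the old is_ddp() check into is_ddp_requested() (torchrun env check before init) and is_ddp_initialized() (process-group check at cleanup). A minimal sketch of how a script would typically drive this pair; the try/finally structure, print, and local names are illustrative, only compute_init and compute_cleanup come from the diff:

from nanochat.common import compute_init, compute_cleanup

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
try:
    if ddp_rank == 0:
        print(f"running on {ddp_world_size} rank(s), device={device}")
    # ... training / evaluation work goes here ...
finally:
    compute_cleanup()  # safe even if no process group was ever initialized
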
@@ -19,7 +19,7 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
from contextlib import nullcontext

# -----------------------------------------------------------------------------
# Calculator tool helpers

@@ -107,23 +107,23 @@ class KVCache:
        # 1) validate the shapes
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"

        # Extract dimensions explicitly
        self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
        other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape

        # Validate dimensions
        assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
        assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
        assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
        assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"

        # Batch size can be expanded (other can be 1, self can be larger)
        assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"

        # Sequence length: self must be longer than other
        assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"

        # 2) initialize the cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)

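The asserts above fix the cache layout as (layers, K/V, batch, heads, seq, head_dim) and explicitly allow a batch-1 prefill cache to be reused for a larger batch. A self-contained toy illustration of the batch expansion those checks permit (the shapes and the copy below are made up for the demo, not the method's actual copy code, which this hunk does not show):

import torch

# toy shapes: (layers, 2 for K/V, batch, heads, seq, head_dim)
layers, heads, head_dim = 2, 4, 16
prefill = torch.randn(layers, 2, 1, heads, 10, head_dim)            # batch-1 prefill cache
full = torch.empty(layers, 2, 8, heads, 64, head_dim)               # larger cache: 8 rows, longer seq
full[:, :, :, :, :10, :] = prefill.expand(-1, -1, 8, -1, -1, -1)    # broadcast batch, copy the prefix
assert (full[:, :, 3, :, :10, :] == prefill[:, :, 0]).all()         # every row received the same prefix
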
@@ -223,9 +223,7 @@ class Engine:
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]
        next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
        sampled_tokens = next_ids[:, 0].tolist()
        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len

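The key change in this hunk: instead of sampling a single token from the prefill logits up front, the last-position logits are expanded to one row per sample so that each row is sampled independently later in the loop. A small standalone illustration of why expand plus per-row sampling restores diversity (uniform logits over a vocab of 262 mirror the test added later in this commit; plain multinomial sampling stands in for sample_next_token, which is not shown here):

import torch

logits = torch.zeros(1, 262)                         # prefill logits for the last position
row_logits = logits.expand(16, -1)                   # one identical view per sample, no copy made
probs = torch.softmax(row_logits, dim=-1)
next_ids = torch.multinomial(probs, num_samples=1)   # (16, 1): each row is sampled independently
assert len(set(next_ids[:, 0].tolist())) > 1         # overwhelmingly likely to differ across rows
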
@@ -242,7 +240,6 @@ class Engine:

        # 4) Main generation loop
        num_generated = 0
        first_iteration = True
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:

@@ -251,18 +248,9 @@ class Engine:
            if all(state.completed for state in row_states):
                break

            # Get sampled tokens - either from prefill or from forward pass
            if first_iteration:
                # Use the tokens we already sampled from prefill
                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
                # TODO: we should sample a token for each row instead of broadcasting
                first_iteration = False
            else:
                # Forward the model and get the next token for each row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
                logits = logits[:, -1, :] # (B, vocab_size) at last time step
                next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
                sampled_tokens = next_ids[:, 0].tolist()
            # Sample the next token for each row
            next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row

@@ -299,8 +287,10 @@ class Engine:
            # Yield the token column
            yield token_column, token_masks
            num_generated += 1
            # Prepare ids for next iteration

            # Prepare logits for next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """

@@ -41,12 +41,10 @@ def norm(x):
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
    x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out
    return torch.cat([y1, y2], 3)

class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):

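The rewritten apply_rotary_emb fixes the comment typo ("last time" → "last dim") and returns the concatenation directly, dropping the explicit dtype cast. A tiny self-contained check of the rotation identity it implements (shapes below are arbitrary demo values): rotating each (x1, x2) pair by an angle preserves its norm.

import torch

B, H, T, D = 2, 3, 5, 8                      # toy sizes, D must be even
x = torch.randn(B, H, T, D)
theta = torch.rand(1, 1, T, D // 2)          # one angle per rotated pair, broadcast over B and H
cos, sin = torch.cos(theta), torch.sin(theta)
x1, x2 = x[..., :D // 2], x[..., D // 2:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
# a rotation preserves the norm of every (x1, x2) pair
assert torch.allclose(x1**2 + x2**2, y1**2 + y2**2, atol=1e-5)
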
@@ -465,6 +465,22 @@ impl Tokenizer {

        all_ids
    }

    /// Encode multiple texts in parallel using rayon.
    /// Returns a list of token ID vectors, one per input text.
    #[pyo3(signature = (texts))]
    #[pyo3(text_signature = "(self, texts)")]
    pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
        // Release Python GIL and encode in parallel using rayon
        let results = py.allow_threads(|| {
            texts
                .par_iter()
                .map(|text| self.encode(text))
                .collect::<Vec<Vec<u32>>>()
        });

        Ok(results)
    }
}

#[pymodule]

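From the Python side, batch_encode is a drop-in, parallel version of calling encode in a loop. A hedged usage sketch (the training corpus and the vocab size of 300 are arbitrary, chosen only so the sketch runs; the call pattern mirrors the tests added further down in this commit):

import rustbpe

corpus = "the quick brown fox jumps over the lazy dog " * 200
tok = rustbpe.Tokenizer()
tok.train_from_iterator([corpus], 300)   # tiny vocab, enough for a demo

texts = ["Hello world", "", "a"]
assert tok.batch_encode(texts) == [tok.encode(t) for t in texts]
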
@@ -149,6 +149,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
    parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
    parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
    args = parser.parse_args()

    # distributed / precision setup

@@ -166,7 +168,7 @@ def main():
        model_slug = hf_path.replace("/", "-") # for the output csv file
    else:
        # load a local model from the file system
        model, tokenizer, meta = load_model("base", device, phase="eval")
        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
        model_name = f"base_model (step {meta['step']})" # just for logging
        model_slug = f"base_model_{meta['step']:06d}" # for the output csv file

@@ -12,7 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext

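This hunk and the two similar ones below rename PYTORCH_CUDA_ALLOC_CONF to PYTORCH_ALLOC_CONF, the newer spelling of the allocator knob. If a script has to work across PyTorch versions, one defensive option is to set both names before torch is imported (an illustration of mine, not what these scripts do):

import os
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")       # newer name
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")  # older name, kept for older builds
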
@@ -31,6 +31,8 @@ from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)

@@ -64,7 +66,7 @@ use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)

# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval")
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer) # for sampling rollouts

# -----------------------------------------------------------------------------

@@ -307,8 +309,8 @@ for step in range(num_steps):
    if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
        base_dir = get_base_dir()
        depth = model.config.n_layer
        model_tag = f"d{depth}" # base the model tag on the depth of the base model
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
        output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
        model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
        save_checkpoint(
            checkpoint_dir,

@@ -10,7 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import wandb
import torch

@@ -250,8 +250,8 @@ for step in range(num_iterations):
    if master_process:
        base_dir = get_base_dir()
        depth = model.config.n_layer
        model_tag = f"d{depth}" # base the model tag on the depth of the base model
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
        model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
        save_checkpoint(
            checkpoint_dir,

@@ -11,7 +11,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_

from collections import deque
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch

@@ -207,7 +207,7 @@ while True:

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not dry_run:
        output_dirname = f"d{depth}" # e.g. d12
        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,

@@ -5,7 +5,85 @@ python -m pytest tests/test_engine.py -v
"""

import torch
from nanochat.engine import KVCache
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass

# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model

@dataclass
class MockConfig:
    """Minimal config for Engine tests."""
    n_kv_head: int = 4
    n_head: int = 4
    n_embd: int = 64
    n_layer: int = 2
    sequence_len: int = 128

class MockModel:
    """
    Mock model that returns uniform logits over the vocab.
    This ensures that with temperature > 0, different samples should
    (with very high probability) produce different tokens.
    """
    def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = "cpu"

    def get_device(self):
        return self._device

    def forward(self, ids, kv_cache=None):
        """Return uniform logits so sampling is spread across vocab."""
        B, T = ids.shape
        # Simulate what a real transformer does: insert k,v into the cache for each layer
        if kv_cache is not None:
            head_dim = self.config.n_embd // self.config.n_head
            for layer_idx in range(self.config.n_layer):
                k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
                v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
                kv_cache.insert_kv(layer_idx, k, v)
        # Uniform logits -> equal probability for all tokens
        logits = torch.zeros(B, T, self.vocab_size)
        return logits

class ByteTokenizer:
    """
    Simple byte-level tokenizer for testing.
    Tokens 0-255 are raw bytes, 256+ are special tokens.
    """
    def __init__(self):
        # Special tokens start at 256
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        self._bos = 261

    def encode_special(self, s):
        return self._special_tokens[s]

    def get_bos_token_id(self):
        return self._bos

    def encode(self, s, prepend=None):
        tokens = list(s.encode("utf-8")) # bytes 0-255
        if prepend is not None:
            tokens = [prepend] + tokens
        return tokens

    def decode(self, tokens):
        # Filter out special tokens before decoding
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")

def test_kv_cache_resize():
    """

@@ -64,3 +142,46 @@ def test_kv_cache_resize():
        original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
        assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
        assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"

def test_multi_sample_first_token_diversity():
    """
    Test that when generating multiple samples, each sample gets an independently
    sampled first token (not a broadcast of the same token to all rows).

    Previously, the first token after prefill was sampled once and broadcast to all
    rows, causing all samples to start identically. The fix expands the prefill logits
    to num_samples and samples independently for each row.

    With uniform logits over 262 tokens and 16 samples, the probability that all
    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
    all identical, it indicates tokens are being broadcast instead of independently sampled.
    """
    model = MockModel(vocab_size=262)
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)

    # Generate 16 samples with temperature=1.0 (stochastic sampling)
    prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
    num_samples = 16

    # Collect the first generated token from each sample
    first_tokens = []
    gen = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1, # We only need the first token
        temperature=1.0,
        seed=42,
    )
    for token_column, token_masks in gen:
        first_tokens = token_column # This is the first (and only) yield

    # With uniform distribution and 16 samples, they should NOT all be identical
    # If they are all identical, the bug exists (broadcasting instead of sampling)
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )

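The docstring's probability claim is easy to verify directly: with 16 independent draws from a uniform distribution over 262 tokens, the chance that all 16 match the first one is (1/262)^15. A one-line check (plain Python, not part of the test itself):

p_all_same = (1 / 262) ** 15
print(f"{p_all_same:.2e}")   # ~5.3e-37, i.e. on the order of 10^-36
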
@@ -21,6 +21,7 @@ python -m pytest tests/test_rustbpe.py -v -s
import regex as re
from collections import Counter, defaultdict
import time
import warnings
import rustbpe
import tiktoken
import pytest

@@ -633,3 +634,85 @@ def test_interface(enwik8_small):
    ids_reloaded = tok_reloaded.encode(encode_text)
    assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
    print("✅ Save/load through temporary directory OK")

def test_batch_encode_correctness(enwik8_small):
    """Quick correctness test for batch_encode()"""
    text = enwik8_small
    vocab_size = 512

    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([text], vocab_size)

    # Test with various batch sizes and edge cases
    test_texts = [
        "Hello world",
        "The quick brown fox",
        "jumps over the lazy dog",
        "", # empty string
        "a", # single char
    ]

    # Compare batch vs individual encoding
    individual = [tokenizer.encode(t) for t in test_texts]
    batched = tokenizer.batch_encode(test_texts)

    assert individual == batched, "Batch encoding should match individual encoding"
    print("✅ batch_encode() correctness verified")

@pytest.mark.slow
def test_batch_encode_performance(enwik8_large):
    """
    Benchmark batch_encode() vs sequential encode() loop.
    Demonstrates parallelization speedup.
    """
    # Setup
    text = enwik8_large # 10MB dataset
    vocab_size = 2048

    # Train tokenizer
    print("\nTraining tokenizer...")
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([text], vocab_size)

    # Create test batch: split text into chunks
    chunk_size = 50_000 # ~50KB per chunk
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
    chunks = chunks[:20] # Use first 20 chunks (~1MB total)

    print(f"\nBatch encoding benchmark:")
    print(f" Number of texts: {len(chunks)}")
    print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")

    # Benchmark 1: Sequential encoding (baseline)
    print("\n [1/3] Sequential encode() loop...")
    sequential_results, sequential_time = time_function(
        lambda: [tokenizer.encode(chunk) for chunk in chunks]
    )
    print(f" Time: {sequential_time:.4f}s")

    # Benchmark 2: Parallel batch_encode()
    print(" [2/3] Parallel batch_encode()...")
    batch_results, batch_time = time_function(
        tokenizer.batch_encode, chunks
    )
    print(f" Time: {batch_time:.4f}s")

    # Verify correctness
    print(" [3/3] Verifying correctness...")
    assert len(batch_results) == len(sequential_results), "Result count mismatch"
    for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
        assert seq == batch, f"Mismatch at index {i}"
    print(" ✓ All results match")

    # Report speedup
    speedup = sequential_time / batch_time
    print(f"\n Performance Results:")
    print(f" Sequential: {sequential_time:.4f}s")
    print(f" Batch: {batch_time:.4f}s")
    print(f" Speedup: {speedup:.2f}x")

    # Warn if speedup is low (can vary by machine/load)
    if speedup < 1.5:
        warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")

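The benchmark above leans on a time_function helper defined elsewhere in the test module; its implementation is not part of this diff. A hypothetical stand-in that matches how it is called here (returns the function's result plus elapsed seconds) might look like:

import time

def time_function(fn, *args):
    t0 = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - t0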