mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
405 lines
17 KiB
Python
405 lines
17 KiB
Python
"""
|
|
Engine for efficient inference of our models.
|
|
|
|
Everything works around token sequences:
|
|
- The user can send token sequences to the engine
|
|
- The engine returns the next token
|
|
|
|
Notes:
|
|
- The engine knows nothing about tokenization, it's purely token id sequences.
|
|
|
|
The whole thing is made as efficient as possible.
|
|
"""
|
|
|
|
import signal
|
|
import warnings
|
|
from collections import deque
|
|
from contextlib import contextmanager, nullcontext
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from nanochat.checkpoint_manager import load_model
|
|
from nanochat.common import autodetect_device_type, compute_init
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Calculator tool helpers
|
|
@contextmanager
def timeout(duration, formula):
    """
    Abort the enclosed block with TimeoutError after `duration` seconds.

    Implemented with SIGALRM, so it is only available on Unix and may only be
    entered from the main thread (signal.signal raises ValueError elsewhere).
    TimeoutError is an Exception subclass, so callers catching Exception still
    see it. On exit the alarm is always cancelled and the previously installed
    SIGALRM handler is restored, even when the enclosed block raises.

    Args:
        duration: whole number of seconds before the alarm fires.
        formula: text included in the timeout message, for diagnostics.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"'{formula}': timed out after {duration} seconds")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Cancel any pending alarm so it cannot fire later in unrelated code.
        signal.alarm(0)
        # signal.signal returns None when the old handler was not installed from
        # Python; such a handler cannot be re-installed, so leave ours in place.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
|
def eval_with_timeout(formula, max_time=3):
    """
    eval() `formula` with no builtins, giving up after `max_time` seconds.

    Returns the value of the expression, or None if anything goes wrong
    (syntax error, unknown name, arithmetic error, timeout, ...). Wrong
    calculator usage by the model is expected, so failures are swallowed.

    NOTE: an empty __builtins__ is not a real sandbox; callers are expected to
    pre-filter `formula` before calling this.
    """
    no_builtins = {"__builtins__": {}}
    try:
        # Timeout first, then silence SyntaxWarnings raised while compiling formula.
        with timeout(max_time, formula), warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            return eval(formula, no_builtins, {})
    except Exception:
        # Make sure no alarm is left pending after a failed evaluation.
        signal.alarm(0)
        return None
|
|
|
|
|
|
def use_calculator(expr):
    """
    Evaluate a calculator tool expression emitted by the model, conservatively.

    Two forms are accepted:
    - Pure arithmetic built from digits, ``+ - * / .``, parentheses and spaces.
      Commas are treated as thousands separators and dropped first
      ("1,000 + 2" -> "1000 + 2"). The power operator ``**`` is rejected.
    - String ``.count(...)`` expressions such as ``"strawberry".count("r")``,
      built only from ASCII letters, digits, quotes, parentheses, dots,
      underscores, commas and spaces, and containing none of a blocklist of
      dangerous names.

    Returns the evaluated value, or None if the expression is rejected or fails
    to evaluate.

    NOTE: this is a filter in front of eval(), not a real sandbox; safety relies
    on the character whitelist and the blocklist below.
    """
    # 1) Pure math: strip thousands separators, then whitelist characters.
    math_expr = expr.replace(",", "")
    math_chars = set("0123456789*+-/.() ")
    if all(ch in math_chars for ch in math_expr):
        if "**" in math_expr:  # disallow power operator
            return None
        return eval_with_timeout(math_expr)

    # 2) String operations. Commas are NOT stripped here: they may be part of a
    # string literal (e.g. '"a,b,c".count(",")'), and removing them would make
    # the expression silently evaluate to a wrong answer.
    allowed_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._, ")
    if not all(ch in allowed_chars for ch in expr):
        return None

    # Disallow dangerous patterns (checked case-insensitively)
    dangerous_patterns = [
        '__',
        'import',
        'exec',
        'eval',
        'compile',
        'open',
        'file',
        'input',
        'raw_input',
        'globals',
        'locals',
        'vars',
        'dir',
        'getattr',
        'setattr',
        'delattr',
        'hasattr',
    ]
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.

    Storage is a single tensor of shape
    (num_layers, 2, batch_size, num_heads, seq_len, head_dim), where index 0/1
    of the second axis holds keys/values. It is allocated lazily on the first
    insert (to adopt the dtype/device of the incoming keys) and grows along the
    seq_len axis when an insert would overflow it.
    Note that the .pos advances automatically after the last layer of the Transformer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0  # current position in time in the cache

    def reset(self):
        """Rewind to position 0; the allocated buffer is kept and gets overwritten."""
        self.pos = 0

    def get_pos(self):
        """Return the number of time steps currently stored in the cache."""
        return self.pos

    def prefill(self, other):
        """
        Prefill given another KV cache. Optionally expand along batch dim.
        This is used when we do batch 1 prefill and then want to generate
        multiple samples in parallel from there.

        Only the first `other.pos` time steps of `other` are copied; `other`'s
        buffer may be longer than that (e.g. after growing in insert_kv).
        """
        # 1) validate the shapes
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
            # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
            if ix in (0, 1, 3, 5):
                # num_layers, k/v, num_heads, head_dim must match
                assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
            elif ix == 2:
                # batch_size can be expanded
                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
        # seq_len: we only need room for the valid prefix of `other`, not its whole buffer
        assert self.kv_shape[4] >= other.pos, f"Seq len mismatch: {self.kv_shape[4]} < {other.pos}"
        # 2) initialize the cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # 3) copy the valid prefix over (a batch dim of 1 broadcasts across our rows)
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache[:, :, :, :, :other.pos, :]
        # 4) update the pos
        self.pos = other.pos

    def insert_kv(self, layer_idx, k, v):
        """
        Write `k`/`v` for `layer_idx` at the current position and return views of
        all cached keys/values for that layer up to and including the new steps.

        `k` and `v` are unpacked as 4-D (B, H, T_add, D). The cache grows along
        the time axis if needed; .pos advances only after the last layer inserts.
        """
        # Lazy initialize the cache here because we need to know the dtype/device
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        # Insert new keys/values to the cache and return the full cache so far
        _, _, T_add, _ = k.size()
        t0, t1 = self.pos, self.pos + T_add
        # Dynamically grow the cache if needed
        if t1 > self.kv_cache.size(4):
            t_needed = t1 + 1024  # as much as we need plus buffer of 1024
            t_needed = (t_needed + 1023) & ~1023  # then round up to the nearest multiple of 1024
            additional_shape = list(self.kv_cache.shape)
            additional_shape[4] = t_needed - self.kv_cache.size(4)
            additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
            self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
            self.kv_shape = self.kv_cache.shape
        # Insert k, v into the cache
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        # Return the full cached keys/values up to current position (as a view)
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        # Increment pos after the last layer of the Transformer processes
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """
    Pick one next token id per row of `logits`, shaped (B, vocab_size).

    - temperature == 0: greedy argmax, `rng` unused.
    - otherwise: logits are divided by `temperature`, softmaxed and sampled
      with `rng`; if `top_k` is given, only the top_k highest logits (capped at
      vocab_size) are eligible.

    Returns a tensor of token ids of shape (B, 1).
    """
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    if top_k is None:
        full_probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(full_probs, num_samples=1, generator=rng)

    keep = min(top_k, logits.size(-1))
    top_logits, top_ids = torch.topk(logits, keep, dim=-1)
    top_probs = F.softmax(top_logits / temperature, dim=-1)
    picked = torch.multinomial(top_probs, num_samples=1, generator=rng)
    # map positions within the top-k back to vocabulary ids
    return top_ids.gather(1, picked)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class RowState:
    """Per-row bookkeeping used by Engine.generate while a sample is being generated."""

    def __init__(self, current_tokens=None):
        # Tokens of this row so far; a falsy argument (None or []) becomes a fresh list.
        self.current_tokens = current_tokens if current_tokens else []
        # Token ids that must be emitted next (in order) instead of sampled ones.
        self.forced_tokens = deque()
        # Set while between <|python_start|> and <|python_end|> tokens.
        self.in_python_block = False
        # Tokens of the python expression collected inside the current block.
        self.python_expr_tokens = []
        # Becomes True once this row emits a terminal token.
        self.completed = False
|
|
|
|
|
|
class Engine:
    """
    Streaming inference engine around a GPT-style model with a KV cache.

    Requirements visible in this class:
    - The model must expose forward(ids, kv_cache=...), get_device(), and a
      .config with n_kv_head, n_embd, n_head, n_layer and sequence_len.
    - The tokenizer supplies the special tokens that drive the calculator
      tool-use state machine. It is also used to decode calculator expressions
      and encode their results.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer  # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """
        Stream `num_samples` continuations of the prompt `tokens` (a list of ints).

        The prompt is prefilled once with batch size 1. Its KV cache is then
        replicated across `num_samples` rows for decoding.

        Each step yields ``(token_column, token_masks)``:
        - token_column: one next token id per row;
        - token_masks: per row, 1 if the token was sampled, or 0 if it was
          force-injected (calculator output).

        Generation stops after `max_tokens` steps (if given) or once every row
        has produced <|assistant_end|> or <|bos|>. Rows that finish early keep
        receiving tokens until all rows are done; callers should ignore those
        tokens (generate_batch does).
        """
        assert isinstance(tokens, list) and len(tokens) > 0 and isinstance(tokens[0], int), "expecting non-empty list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = self.tokenizer.encode_special
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>")  # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id()  # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        # Expand the (1, vocab_size) last-step logits to (num_samples, vocab_size)
        # so every row draws its own first token. Otherwise all rows would share a
        # single draw. expand() returns a view, so no logits are copied.
        logits = logits[:, -1, :].expand(num_samples, -1)
        next_ids = sample_next_token(logits, rng, temperature, top_k)  # (num_samples, 1)
        sampled_tokens = next_ids[:, 0].tolist()

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else m.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill  # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        first_iteration = True
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            if first_iteration:
                # This step's tokens were already sampled (one per row) from the prefill logits
                first_iteration = False
            else:
                # Forward the model on the previous step's tokens and sample one per row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)  # (B, T, vocab_size)
                logits = logits[:, -1, :]  # (B, vocab_size) at last time step
                next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
                sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = []  # contains the next token id along each row
            token_masks = []  # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0  # are there tokens waiting to be forced in deque?
                token_masks.append(0 if is_forced else 1)  # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            # Queue <|output_start|> result <|output_end|> to be force-fed next
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            num_generated += 1
            # Prepare ids for next iteration: one token per row, shape (B, 1)
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that returns the final token sequences.

        Returns ``(results, masks)``:
        - results: a list of `num_samples` token sequences (lists of ints), each
          being the prompt followed by that row's generated tokens;
        - masks: a parallel list of 0/1 lists (0 for prompt and forced tokens,
          1 for sampled tokens).

        Terminal tokens (<|assistant_end|>, <|bos|>) are not included in the
        results. Extra keyword arguments are forwarded to generate().
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if completed[i]:
                    continue  # ignore tokens a row receives after it finished
                if token == assistant_end or token == bos:
                    completed[i] = True
                else:
                    results[i].append(token)
                    masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks
|
|
|
|
|
|
if __name__ == "__main__":
|
|
"""
|
|
Quick inline test to make sure that the naive/slow model.generate function
|
|
is equivalent to the faster Engine.generate function here.
|
|
"""
|
|
import time
|
|
|
|
# init compute
|
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
|
device_type = autodetect_device_type()
|
|
autocast_ctx = (
|
|
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
|
)
|
|
|
|
# load the model and tokenizer
|
|
model, tokenizer, meta = load_model("base", device, phase="eval")
|
|
bos_token_id = tokenizer.get_bos_token_id()
|
|
# common hyperparameters
|
|
kwargs = dict(max_tokens=64, temperature=0.0)
|
|
# set the starting prompt
|
|
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
|
|
# generate the reference sequence using the model.generate() function
|
|
generated_tokens = []
|
|
torch.cuda.synchronize()
|
|
t0 = time.time()
|
|
stream = model.generate(prompt_tokens, **kwargs)
|
|
with autocast_ctx:
|
|
for token in stream:
|
|
generated_tokens.append(token)
|
|
chunk = tokenizer.decode([token])
|
|
print(chunk, end="", flush=True)
|
|
print()
|
|
torch.cuda.synchronize()
|
|
t1 = time.time()
|
|
print(f"Reference time: {t1 - t0:.2f}s")
|
|
reference_ids = generated_tokens
|
|
# generate tokens with Engine
|
|
generated_tokens = []
|
|
engine = Engine(model, tokenizer)
|
|
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
|
torch.cuda.synchronize()
|
|
t0 = time.time()
|
|
with autocast_ctx:
|
|
for token_column, token_masks in stream:
|
|
token = token_column[0] # only print out the first row
|
|
generated_tokens.append(token)
|
|
chunk = tokenizer.decode([token])
|
|
print(chunk, end="", flush=True)
|
|
print()
|
|
torch.cuda.synchronize()
|
|
t1 = time.time()
|
|
print(f"Engine time: {t1 - t0:.2f}s")
|
|
# compare the two sequences
|
|
for i in range(len(reference_ids)):
|
|
if reference_ids[i] != generated_tokens[i]:
|
|
print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
|
|
break
|
|
print(f"Match: {reference_ids == generated_tokens}")
|