# Mirror of https://github.com/karpathy/nanochat.git
# Synced 2026-03-21 12:23:13 +00:00
import runpy
from contextlib import contextmanager
from types import SimpleNamespace

import pytest
import torch

import nanochat.engine as engine
|
def test_timeout_eval_and_calculator_paths(monkeypatch):
    """Cover engine.timeout, eval_with_timeout, and the use_calculator guardrails."""
    # Happy path: the timeout context manager is a no-op for a fast body.
    with engine.timeout(1, "1+1"):
        assert 1 + 1 == 2

    assert engine.eval_with_timeout("1+2", max_time=1) == 3
    # Syntactically invalid expressions are swallowed and reported as None.
    assert engine.eval_with_timeout("bad +", max_time=1) is None

    @contextmanager
    def boom(_duration, _formula):
        raise Exception("t")
        yield  # pragma: no cover

    # Scope the override with monkeypatch.context() rather than manually
    # saving/restoring engine.timeout: if an assert inside this region fails,
    # the context manager still undoes the patch before the later asserts run.
    with monkeypatch.context() as patched:
        patched.setattr(engine, "timeout", boom)
        # An exception raised by the timeout machinery itself also yields None.
        assert engine.eval_with_timeout("1+1", max_time=1) is None

    # Calculator accepts simple arithmetic (commas stripped)...
    assert engine.use_calculator("1,000 + 2") == 1002
    # ...but rejects exponentiation, stray symbols, and dunder access.
    assert engine.use_calculator("2 ** 8") is None
    assert engine.use_calculator("x$y") is None
    assert engine.use_calculator("__import__('os')") is None
    assert engine.use_calculator("'abc'") is None
    # String-method counting is allowed.
    assert engine.use_calculator("'ababa'.count('a')") == 3
|
def test_timeout_handler_line(monkeypatch):
    """Drive the SIGALRM handler directly to cover the timeout-fired branch."""
    hooks = {}
    # Capture the handler engine.timeout installs instead of letting the real
    # signal machinery run; disarm alarm() entirely.
    monkeypatch.setattr(engine.signal, "signal", lambda _sig, handler: hooks.__setitem__("fn", handler))
    monkeypatch.setattr(engine.signal, "alarm", lambda _seconds: None)
    with pytest.raises(Exception):
        with engine.timeout(1, "slow_expr"):
            # Invoke the captured handler as if the alarm had fired.
            hooks["fn"](0, None)
|
def test_sample_next_token_branches():
    """Exercise greedy, top-k, and full-distribution sampling plus the temperature guard."""
    logits = torch.tensor([[0.1, 0.2, 0.9]], dtype=torch.float32)
    rng = torch.Generator(device="cpu").manual_seed(123)

    # temperature == 0 means greedy argmax: index 2 holds the largest logit.
    greedy = engine.sample_next_token(logits, rng, temperature=0.0, top_k=None)
    assert greedy.shape == (1, 1)
    assert greedy.item() == 2

    # top_k=2 restricts sampling to the two largest logits (indices 1 and 2).
    topk = engine.sample_next_token(logits, rng, temperature=1.0, top_k=2)
    assert topk.shape == (1, 1)
    assert topk.item() in {1, 2}

    # Unrestricted sampling still yields a (1, 1) token tensor.
    full = engine.sample_next_token(logits, rng, temperature=1.0, top_k=None)
    assert full.shape == (1, 1)

    # Negative temperatures are rejected by assertion.
    with pytest.raises(AssertionError):
        engine.sample_next_token(logits, rng, temperature=-1.0)
|
def test_kv_cache_prefill_assertions():
    """KVCache.prefill must reject incompatible source/destination caches."""
    # Destination that already holds a token cannot be prefilled.
    dst = engine.KVCache(1, 1, 4, 2, 1, "cpu", torch.float32)
    src = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    src.advance(1)
    dst.advance(1)
    with pytest.raises(AssertionError):
        dst.prefill(src)  # non-empty destination

    # Caches whose head counts differ cannot be combined.
    wide = engine.KVCache(1, 2, 2, 2, 1, "cpu", torch.float32)
    narrow = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    with pytest.raises(AssertionError):
        wide.prefill(narrow)  # head mismatch

    # Destination sequence capacity must fit the source.
    short = engine.KVCache(1, 1, 1, 2, 1, "cpu", torch.float32)
    longer = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    with pytest.raises(AssertionError):
        short.prefill(longer)  # seq len too small
|
class _TinyTok:
|
|
def __init__(self):
|
|
self.special = {
|
|
"<|python_start|>": 100,
|
|
"<|python_end|>": 101,
|
|
"<|output_start|>": 102,
|
|
"<|output_end|>": 103,
|
|
"<|assistant_end|>": 104,
|
|
"<|bos|>": 0,
|
|
}
|
|
|
|
def encode_special(self, s):
|
|
return self.special[s]
|
|
|
|
def get_bos_token_id(self):
|
|
return 0
|
|
|
|
def encode(self, s, prepend=None):
|
|
out = [9] if s == "2" else [55]
|
|
if prepend is not None:
|
|
return [prepend] + out
|
|
return out
|
|
|
|
def decode(self, ids):
|
|
return "1+1" if ids else ""
|
|
|
|
|
|
class _TinyModel:
|
|
def __init__(self):
|
|
self.config = SimpleNamespace(n_kv_head=1, n_head=1, n_embd=4, n_layer=1, sequence_len=32)
|
|
self._device = torch.device("cpu")
|
|
self.vocab = 200
|
|
|
|
def get_device(self):
|
|
return self._device
|
|
|
|
def forward(self, ids, kv_cache=None):
|
|
b, t = ids.shape
|
|
if kv_cache is not None:
|
|
kv_cache.advance(t)
|
|
return torch.zeros((b, t, self.vocab), dtype=torch.float32)
|
|
|
|
|
|
def test_engine_generate_tool_forcing(monkeypatch):
    """python_end must trigger the calculator and force its output tokens."""
    tokenizer = _TinyTok()
    eng = engine.Engine(_TinyModel(), tokenizer)

    # Scripted sampler output, consumed one entry per call; the final entry
    # repeats once the script runs out.
    script = [
        [100],  # <|python_start|>
        [55],   # expression token
        [101],  # <|python_end|> -> forced output tokens get queued
        [104],  # ignored while the forced queue drains
        [104],
        [104],
        [104],  # finally consumed as <|assistant_end|>
    ]
    state = {"i": 0}

    def scripted_sample(_logits, _rng, _temperature, _top_k):
        token = script[min(state["i"], len(script) - 1)][0]
        state["i"] += 1
        return torch.tensor([[token]], dtype=torch.long)

    seen = {"expr": None}
    monkeypatch.setattr(engine, "sample_next_token", scripted_sample)
    # Record the expression passed to the calculator and return a fixed result.
    monkeypatch.setattr(engine, "use_calculator", lambda expr: seen.__setitem__("expr", expr) or 2)

    rows = list(eng.generate([1, 2], num_samples=1, max_tokens=10, temperature=0.0))
    tokens = [row[0][0] for row in rows]
    masks = [row[1][0] for row in rows]
    # The forced <|output_start|>/<|output_end|> markers must appear...
    assert 102 in tokens and 103 in tokens
    # ...and forced tokens carry a 0 mask.
    assert 0 in masks
    # decode() of the captured expression tokens is always "1+1" in _TinyTok.
    assert seen["expr"] == "1+1"
|
def test_engine_generate_batch_and_input_validation(monkeypatch):
    """generate_batch collects generate()'s stream; non-list prompts are rejected."""
    tokenizer = _TinyTok()
    eng = engine.Engine(_TinyModel(), tokenizer)

    # Sampler always emits BOS so generation stops immediately; passing
    # max_tokens=None also exercises the kv-length hint path.
    bos = tokenizer.get_bos_token_id()
    monkeypatch.setattr(
        engine,
        "sample_next_token",
        lambda *_args, **_kwargs: torch.tensor([[bos]], dtype=torch.long),
    )
    rows = list(eng.generate([7], num_samples=1, max_tokens=None, temperature=0.0))
    assert len(rows) == 1

    results, masks = eng.generate_batch([7], num_samples=1, max_tokens=3, temperature=0.0)
    assert results == [[7]]
    assert masks == [[0]]

    # A string prompt violates the list-of-ints contract.
    with pytest.raises(AssertionError):
        list(eng.generate("bad", num_samples=1, max_tokens=1))  # type: ignore[arg-type]
|
def test_engine_main_block_runs(monkeypatch):
    """Smoke-test the module's __main__ block with every external stubbed out."""
    # Entrypoints the __main__ block reaches into nanochat.common for.
    monkeypatch.setattr("nanochat.common.autodetect_device_type", lambda: "cpu")
    monkeypatch.setattr("nanochat.common.compute_init", lambda _d: (False, 0, 0, 1, torch.device("cpu")))

    class ScriptTok(_TinyTok):
        def encode(self, s, prepend=None):
            ids = [11, 12]
            return ids if prepend is None else [prepend] + ids

        def decode(self, ids):
            return "x"

    class ScriptModel(_TinyModel):
        def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
            del tokens, max_tokens, temperature, top_k, seed
            yield 11
            yield 12

    # Checkpoint loading returns the stubs; CUDA sync becomes a no-op on CPU.
    monkeypatch.setattr("nanochat.checkpoint_manager.load_model", lambda *a, **k: (ScriptModel(), ScriptTok(), {}))
    monkeypatch.setattr(engine.torch.cuda, "synchronize", lambda: None)
    runpy.run_module("nanochat.engine", run_name="__main__")