# nanochat/tests/test_engine_extra_new.py
# (file-viewer metadata: 203 lines, 6.6 KiB, Python)

import runpy
from contextlib import contextmanager
from types import SimpleNamespace
import pytest
import torch
import nanochat.engine as engine
def test_timeout_eval_and_calculator_paths(monkeypatch):
    """Cover ``timeout``, ``eval_with_timeout`` and ``use_calculator`` in one pass.

    Checks that a fast body runs under the real ``engine.timeout``, that
    ``eval_with_timeout`` returns the value of a valid expression and ``None``
    for an invalid one (or when the timeout context manager itself raises),
    and that ``use_calculator`` accepts/rejects the sample expressions below.
    """
    real_timeout = engine.timeout

    # The real context manager must let a quick body finish normally.
    with engine.timeout(1, "1+1"):
        assert 1 + 1 == 2

    assert engine.eval_with_timeout("1+2", max_time=1) == 3
    assert engine.eval_with_timeout("bad +", max_time=1) is None

    @contextmanager
    def exploding_timeout(_duration, _formula):
        # Raise before yielding so the failure happens on context entry.
        raise Exception("t")
        yield  # pragma: no cover

    # With a timeout that blows up on entry, evaluation must degrade to None.
    monkeypatch.setattr(engine, "timeout", exploding_timeout)
    assert engine.eval_with_timeout("1+1", max_time=1) is None
    monkeypatch.setattr(engine, "timeout", real_timeout)

    # Thousands separators are tolerated in arithmetic input.
    assert engine.use_calculator("1,000 + 2") == 1002

    # Expressions the calculator is expected to refuse.
    for rejected in ("2 ** 8", "x$y", "__import__('os')", "'abc'"):
        assert engine.use_calculator(rejected) is None

    # A string ``.count`` call is accepted.
    assert engine.use_calculator("'ababa'.count('a')") == 3
def test_timeout_handler_line(monkeypatch):
    """Invoke the signal handler installed by ``engine.timeout`` directly.

    ``signal.signal`` is replaced with a recorder and ``signal.alarm`` with a
    no-op, so no real alarm is scheduled; the captured handler is then called
    by hand and is expected to raise.
    """
    installed = {}

    def record_handler(_sig, fn):
        # Keep the handler that engine.timeout tries to register.
        installed["fn"] = fn

    def ignore_alarm(_n):
        return None

    monkeypatch.setattr(engine.signal, "signal", record_handler)
    monkeypatch.setattr(engine.signal, "alarm", ignore_alarm)

    with pytest.raises(Exception):
        with engine.timeout(1, "slow_expr"):
            installed["fn"](0, None)
def test_sample_next_token_branches():
    """Exercise the greedy, top-k and unrestricted paths of ``sample_next_token``.

    Uses a single row of three logits whose largest entry is at index 2, so
    the greedy result is known and the top-2 candidates are indices 1 and 2.
    A negative temperature is expected to trip an assertion.
    """
    logits = torch.tensor([[0.1, 0.2, 0.9]], dtype=torch.float32)
    rng = torch.Generator(device="cpu").manual_seed(123)
    vocab_size = logits.size(-1)

    # temperature == 0.0: the highest logit (index 2) must be selected.
    greedy = engine.sample_next_token(logits, rng, temperature=0.0, top_k=None)
    assert greedy.shape == (1, 1)
    assert greedy.item() == 2

    # top_k == 2: only the two highest-logit indices are admissible.
    top_two = engine.sample_next_token(logits, rng, temperature=1.0, top_k=2)
    assert top_two.shape == (1, 1)
    assert top_two.item() in {1, 2}

    # No top_k: any index may be drawn, but it must be a valid vocabulary
    # index -- a shape-only check would let an out-of-range id slip through.
    unrestricted = engine.sample_next_token(logits, rng, temperature=1.0, top_k=None)
    assert unrestricted.shape == (1, 1)
    assert 0 <= unrestricted.item() < vocab_size

    with pytest.raises(AssertionError):
        engine.sample_next_token(logits, rng, temperature=-1.0)
def test_kv_cache_prefill_assertions():
    """Check that ``KVCache.prefill`` asserts on three invalid combinations.

    The cases are: a destination that has already been advanced, a mismatch
    in the second constructor argument (heads, per the inline notes), and a
    destination whose third constructor argument (sequence length, per the
    inline notes) is smaller than the source's.
    """
    # Case 1: destination is non-empty.
    used_dest = engine.KVCache(1, 1, 4, 2, 1, "cpu", torch.float32)
    used_src = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    used_src.advance(1)
    used_dest.advance(1)
    with pytest.raises(AssertionError):
        used_dest.prefill(used_src)

    # Case 2: head mismatch between destination and source.
    two_head_dest = engine.KVCache(1, 2, 2, 2, 1, "cpu", torch.float32)
    one_head_src = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    with pytest.raises(AssertionError):
        two_head_dest.prefill(one_head_src)

    # Case 3: destination sequence length too small for the source.
    short_dest = engine.KVCache(1, 1, 1, 2, 1, "cpu", torch.float32)
    longer_src = engine.KVCache(1, 1, 2, 2, 1, "cpu", torch.float32)
    with pytest.raises(AssertionError):
        short_dest.prefill(longer_src)
class _TinyTok:
    """Minimal tokenizer stand-in with fixed ids and canned encode/decode results."""

    def __init__(self):
        # Tool/assistant markers take consecutive ids 100..104; BOS is id 0.
        marker_names = (
            "<|python_start|>",
            "<|python_end|>",
            "<|output_start|>",
            "<|output_end|>",
            "<|assistant_end|>",
        )
        self.special = dict(zip(marker_names, range(100, 105)))
        self.special["<|bos|>"] = 0

    def encode_special(self, s):
        """Return the id of special token ``s``; unknown names raise ``KeyError``."""
        return self.special[s]

    def get_bos_token_id(self):
        """Return the beginning-of-sequence id (always 0)."""
        return 0

    def encode(self, s, prepend=None):
        """Encode ``"2"`` as ``[9]`` and anything else as ``[55]``.

        When ``prepend`` is not ``None`` it is placed in front of the result.
        """
        body = [9] if s == "2" else [55]
        if prepend is None:
            return body
        return [prepend, *body]

    def decode(self, ids):
        """Return ``"1+1"`` for any truthy ``ids`` and ``""`` otherwise."""
        if ids:
            return "1+1"
        return ""
class _TinyModel:
    """Minimal CPU model stand-in whose forward pass returns all-zero logits."""

    def __init__(self):
        self.config = SimpleNamespace(
            n_kv_head=1,
            n_head=1,
            n_embd=4,
            n_layer=1,
            sequence_len=32,
        )
        self._device = torch.device("cpu")
        self.vocab = 200

    def get_device(self):
        """Return the device this fake model reports (CPU)."""
        return self._device

    def forward(self, ids, kv_cache=None):
        """Return float32 zeros of shape ``(batch, steps, vocab)`` for 2-D ``ids``.

        If a ``kv_cache`` is supplied it is advanced by the number of steps
        in ``ids`` before the logits are built.
        """
        batch, steps = ids.shape
        if kv_cache is not None:
            kv_cache.advance(steps)
        logits_shape = (batch, steps, self.vocab)
        return torch.zeros(logits_shape, dtype=torch.float32)
def test_engine_generate_tool_forcing(monkeypatch):
    """Drive ``Engine.generate`` through a scripted python-tool call.

    Sampling is replaced with a fixed token script and ``use_calculator``
    with a stub that records its argument and returns 2. The test then checks
    that the output-start/output-end ids appear in the stream, that at least
    one emitted token carries mask 0, and that the calculator received the
    decoded expression.
    """
    tokenizer = _TinyTok()
    fake_model = _TinyModel()
    eng = engine.Engine(fake_model, tokenizer)

    script = (
        100,  # python_start
        55,   # expression token
        101,  # python_end -> enqueue forced output tokens
        104,  # ignored while forced queue drains
        104,
        104,
        104,  # finally consumed as assistant_end
    )
    cursor = 0

    def fake_sample(_logits, _rng, _temperature, _top_k):
        # Replay the script in order; once exhausted keep returning its last id.
        nonlocal cursor
        token_id = script[min(cursor, len(script) - 1)]
        cursor += 1
        return torch.tensor([[token_id]], dtype=torch.long)

    calls = {"expr": None}

    def fake_calculator(expr):
        # Remember what the engine asked to evaluate and answer with 2.
        calls["expr"] = expr
        return 2

    monkeypatch.setattr(engine, "sample_next_token", fake_sample)
    monkeypatch.setattr(engine, "use_calculator", fake_calculator)

    steps = list(eng.generate([1, 2], num_samples=1, max_tokens=10, temperature=0.0))
    emitted_tokens = [step[0][0] for step in steps]
    emitted_masks = [step[1][0] for step in steps]

    # output_start / output_end must have been forced into the stream.
    assert 102 in emitted_tokens
    assert 103 in emitted_tokens
    # Forced tokens are reported with mask 0.
    assert 0 in emitted_masks
    assert calls["expr"] == "1+1"
def test_engine_generate_batch_and_input_validation(monkeypatch):
    """Check immediate-stop generation, ``generate_batch`` output, and input validation.

    Sampling is stubbed to always return the BOS id. ``generate`` with
    ``max_tokens=None`` must then yield exactly one row, ``generate_batch``
    must return only the prompt with a zero mask, and a non-list prompt must
    trip an assertion.
    """
    tokenizer = _TinyTok()
    fake_model = _TinyModel()
    eng = engine.Engine(fake_model, tokenizer)

    def always_bos(*_a, **_k):
        return torch.tensor([[tokenizer.get_bos_token_id()]], dtype=torch.long)

    # Stop immediately via BOS; also exercises max_tokens=None kv-length hint path.
    monkeypatch.setattr(engine, "sample_next_token", always_bos)

    streamed = list(eng.generate([7], num_samples=1, max_tokens=None, temperature=0.0))
    assert len(streamed) == 1

    results, masks = eng.generate_batch([7], num_samples=1, max_tokens=3, temperature=0.0)
    assert results == [[7]]
    assert masks == [[0]]

    # A string prompt is not an accepted token list.
    with pytest.raises(AssertionError):
        list(eng.generate("bad", num_samples=1, max_tokens=1))  # type: ignore[arg-type]
def test_engine_main_block_runs(monkeypatch):
    """Run ``nanochat.engine`` as ``__main__`` with its external entry points stubbed.

    Device autodetection, compute initialisation, checkpoint loading and
    ``torch.cuda.synchronize`` are all replaced so the script section can run
    on CPU with the tiny fakes. The test passes if the run raises nothing.
    """

    class MainTok(_TinyTok):
        def encode(self, s, prepend=None):
            out = [11, 12]
            if prepend is not None:
                return [prepend] + out
            return out

        def decode(self, ids):
            return "x"

    class MainModel(_TinyModel):
        def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
            # Arguments are irrelevant to the fake; always stream the same two ids.
            del tokens, max_tokens, temperature, top_k, seed
            yield from (11, 12)

    def fake_autodetect():
        return "cpu"

    def fake_compute_init(_d):
        return (False, 0, 0, 1, torch.device("cpu"))

    def fake_load_model(*a, **k):
        return (MainModel(), MainTok(), {})

    def fake_synchronize():
        return None

    # Patch external entrypoints used by the __main__ block.
    monkeypatch.setattr("nanochat.common.autodetect_device_type", fake_autodetect)
    monkeypatch.setattr("nanochat.common.compute_init", fake_compute_init)
    monkeypatch.setattr("nanochat.checkpoint_manager.load_model", fake_load_model)
    monkeypatch.setattr(engine.torch.cuda, "synchronize", fake_synchronize)

    runpy.run_module("nanochat.engine", run_name="__main__")