From 9ef5c4c12c7d003a56beb6553e8dc71bb9ee656c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E5=90=8D?=
Date: Mon, 27 Apr 2026 16:43:25 +0800
Subject: [PATCH] Validate engine prompt tokens

---
 nanochat/engine.py   |  2 +-
 tests/test_engine.py | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..1291bef7 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -175,7 +175,7 @@ class Engine:
     @torch.inference_mode()
     def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
         """Same as generate, but does single prefill and then clones the KV cache."""
-        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
+        assert isinstance(tokens, list) and len(tokens) > 0 and all(isinstance(token, int) for token in tokens), "expecting non-empty list of ints"
         device = self.model.get_device()
         # NOTE: setting the dtype here and in this way is an ugly hack.
         # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 784ffcb9..4c61953d 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -5,6 +5,7 @@ python -m pytest tests/test_engine.py -v
 """
 
 import torch
+import pytest
 from nanochat.engine import KVCache, Engine
 from dataclasses import dataclass
 
@@ -198,6 +199,16 @@ def test_multi_sample_first_token_diversity():
     )
 
 
+def test_generate_rejects_invalid_prompt_tokens():
+    """Prompt tokens must be a non-empty list of ints."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+
+    for prompt in ([], [261, "not-an-int"]):
+        with pytest.raises(AssertionError, match="expecting non-empty list of ints"):
+            next(engine.generate(prompt, max_tokens=1))
+
+
 def test_seed_reproducibility():
     """Same seed must produce identical output."""
     model = MockModel()