Validate engine prompt tokens

This commit is contained in:
陈家名 2026-04-27 16:43:25 +08:00
parent 0aaca56805
commit 9ef5c4c12c
2 changed files with 12 additions and 1 deletions

View File

@ -175,7 +175,7 @@ class Engine:
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
assert isinstance(tokens, list) and len(tokens) > 0 and all(isinstance(token, int) for token in tokens), "expecting non-empty list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.

View File

@ -5,6 +5,7 @@ python -m pytest tests/test_engine.py -v
"""
import torch
import pytest
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass
@ -198,6 +199,16 @@ def test_multi_sample_first_token_diversity():
)
def test_generate_rejects_invalid_prompt_tokens():
    """Prompt tokens must be a non-empty list of ints."""
    engine = Engine(MockModel(), ByteTokenizer())
    invalid_prompts = [
        [],                   # empty prompt
        [261, "not-an-int"],  # contains a non-int token
    ]
    for bad_prompt in invalid_prompts:
        with pytest.raises(AssertionError, match="expecting non-empty list of ints"):
            next(engine.generate(bad_prompt, max_tokens=1))
def test_seed_reproducibility():
"""Same seed must produce identical output."""
model = MockModel()