mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-08 16:59:59 +00:00
Validate engine prompt tokens
This commit is contained in:
parent
0aaca56805
commit
9ef5c4c12c
|
|
@ -175,7 +175,7 @@ class Engine:
|
|||
@torch.inference_mode()
|
||||
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
|
||||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
assert isinstance(tokens, list) and len(tokens) > 0 and all(isinstance(token, int) for token in tokens), "expecting non-empty list of ints"
|
||||
device = self.model.get_device()
|
||||
# NOTE: setting the dtype here and in this way is an ugly hack.
|
||||
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ python -m pytest tests/test_engine.py -v
|
|||
"""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
from nanochat.engine import KVCache, Engine
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
|
@ -198,6 +199,16 @@ def test_multi_sample_first_token_diversity():
|
|||
)
|
||||
|
||||
|
||||
def test_generate_rejects_invalid_prompt_tokens():
|
||||
"""Prompt tokens must be a non-empty list of ints."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
|
||||
for prompt in ([], [261, "not-an-int"]):
|
||||
with pytest.raises(AssertionError, match="expecting non-empty list of ints"):
|
||||
next(engine.generate(prompt, max_tokens=1))
|
||||
|
||||
|
||||
def test_seed_reproducibility():
|
||||
"""Same seed must produce identical output."""
|
||||
model = MockModel()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user