mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge ee9e262adc into 5019accc5b
This commit is contained in:
commit
8ac78a6d54
|
|
@ -172,10 +172,20 @@ class Engine:
|
|||
self.model = model
|
||||
self.tokenizer = tokenizer # needed for tool use
|
||||
|
||||
def _truncate_tokens_for_context(self, tokens):
|
||||
max_context = self.model.config.sequence_len
|
||||
if len(tokens) <= max_context:
|
||||
return tokens
|
||||
bos = self.tokenizer.get_bos_token_id()
|
||||
if tokens[0] == bos and max_context > 0:
|
||||
return [bos] + tokens[-(max_context - 1):]
|
||||
return tokens[-max_context:]
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
|
||||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
tokens = self._truncate_tokens_for_context(tokens)
|
||||
device = self.model.get_device()
|
||||
# NOTE: setting the dtype here and in this way is an ugly hack.
|
||||
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
|
||||
|
|
@ -285,6 +295,7 @@ class Engine:
|
|||
Returns a list of token sequences (list of lists of ints).
|
||||
Terminal tokens (assistant_end, bos) are not included in the results.
|
||||
"""
|
||||
tokens = self._truncate_tokens_for_context(tokens)
|
||||
assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
|
||||
bos = self.tokenizer.get_bos_token_id()
|
||||
results = [tokens.copy() for _ in range(num_samples)]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,24 @@ class MockModel:
|
|||
return logits
|
||||
|
||||
|
||||
class StrictLengthModel(MockModel):
    """Mock model that fails if the prompt exceeds sequence_len.

    Every prefill pass (the first forward call on a fresh KV cache) is
    recorded — both its length and its token ids — so tests can inspect
    exactly what the engine fed the model.
    """

    def __init__(self, vocab_size=262, sequence_len=8):
        super().__init__(vocab_size=vocab_size)
        self.config = MockConfig(sequence_len=sequence_len)
        # Bookkeeping of prefill calls, appended to inside forward().
        self.prefill_lengths = []
        self.prefill_tokens = []

    def forward(self, ids, kv_cache=None):
        _, seq_len = ids.shape
        limit = self.config.sequence_len
        if seq_len > limit:
            raise RuntimeError(f"sequence length exceeded: {seq_len} > {limit}")
        # A prefill is the first forward on an empty KV cache.
        is_prefill = kv_cache is not None and kv_cache.get_pos() == 0
        if is_prefill:
            self.prefill_lengths.append(seq_len)
            self.prefill_tokens.append(ids[0].tolist())
        return super().forward(ids, kv_cache=kv_cache)
|
||||
|
||||
|
||||
class ByteTokenizer:
|
||||
"""
|
||||
Simple byte-level tokenizer for testing.
|
||||
|
|
@ -265,3 +283,16 @@ def test_different_seeds_introduce_variation_when_temperature_nonzero():
|
|||
|
||||
# Sanity check: sampling actually introduces variation
|
||||
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
|
||||
|
||||
|
||||
def test_generate_truncates_overlong_prompts_to_model_context_window():
    """Overlong prompts should be truncated to the latest context window before prefill."""
    context = 8
    model = StrictLengthModel(sequence_len=context)
    engine = Engine(model, ByteTokenizer())
    # 11 tokens — 3 more than the window. The leading 261 is expected to be
    # treated as BOS and kept at the front after truncation.
    prompt = [261, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    results, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=1)

    expected_prefix = [261, 4, 5, 6, 7, 8, 9, 10]
    # Exactly one prefill happened, and it fit the window.
    assert model.prefill_lengths == [context]
    # The engine fed BOS + the latest (context - 1) prompt tokens.
    assert model.prefill_tokens == [expected_prefix]
    # The generated sequence starts from the truncated prompt.
    assert results[0][:context] == expected_prefix
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user