diff --git a/nanochat/engine.py b/nanochat/engine.py
index 4724c8f..8942a46 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -166,10 +166,31 @@ class Engine:
         self.model = model
         self.tokenizer = tokenizer # needed for tool use
 
+    def _truncate_tokens_for_context(self, tokens):
+        """Clamp `tokens` to the model's context window, keeping the newest tokens.
+
+        A BOS-prefixed prompt keeps its BOS and fills the remainder of the
+        window with the most recent tokens, so the truncated prompt still
+        looks like the start of a sequence to the model.
+        """
+        max_context = self.model.config.sequence_len
+        if len(tokens) <= max_context:
+            return tokens
+        if max_context <= 0:
+            # Degenerate config guard: tokens[-0:] below would return everything.
+            return []
+        bos = self.tokenizer.get_bos_token_id()
+        # max_context > 1 is required here: with a window of exactly 1,
+        # [bos] + tokens[-0:] would return BOS plus the WHOLE prompt.
+        if tokens[0] == bos and max_context > 1:
+            return [bos] + tokens[-(max_context - 1):]
+        return tokens[-max_context:]
+
     @torch.inference_mode()
     def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
         """Same as generate, but does single prefill and then clones the KV cache."""
         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
+        tokens = self._truncate_tokens_for_context(tokens)
         device = self.model.get_device()
         # NOTE: setting the dtype here and in this way is an ugly hack.
         # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
@@ -279,6 +300,7 @@ class Engine:
         Returns a list of token sequences (list of lists of ints).
         Terminal tokens (assistant_end, bos) are not included in the results.
         """
+        tokens = self._truncate_tokens_for_context(tokens)
         assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
         bos = self.tokenizer.get_bos_token_id()
         results = [tokens.copy() for _ in range(num_samples)]
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 784ffcb..58407a9 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -47,6 +47,24 @@ class MockModel:
         return logits
 
 
+class StrictLengthModel(MockModel):
+    """Mock model that fails if the prompt exceeds sequence_len."""
+    def __init__(self, vocab_size=262, sequence_len=8):
+        super().__init__(vocab_size=vocab_size)
+        self.config = MockConfig(sequence_len=sequence_len)
+        self.prefill_lengths = []
+        self.prefill_tokens = []
+
+    def forward(self, ids, kv_cache=None):
+        B, T = ids.shape
+        if T > self.config.sequence_len:
+            raise RuntimeError(f"sequence length exceeded: {T} > {self.config.sequence_len}")
+        if kv_cache is not None and kv_cache.get_pos() == 0:
+            self.prefill_lengths.append(T)
+            self.prefill_tokens.append(ids[0].tolist())
+        return super().forward(ids, kv_cache=kv_cache)
+
+
 class ByteTokenizer:
     """
     Simple byte-level tokenizer for testing.
@@ -265,3 +283,16 @@ def test_different_seeds_introduce_variation_when_temperature_nonzero():
 
     # Sanity check: sampling actually introduces variation
     assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
+
+
+def test_generate_truncates_overlong_prompts_to_model_context_window():
+    """Overlong prompts should be truncated to the latest context window before prefill."""
+    model = StrictLengthModel(sequence_len=8)
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+    results, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=1)
+
+    assert model.prefill_lengths == [8]
+    assert model.prefill_tokens == [[261, 4, 5, 6, 7, 8, 9, 10]]
+    assert results[0][:8] == [261, 4, 5, 6, 7, 8, 9, 10]