This commit is contained in:
Wooram Son 2026-03-23 16:53:02 +01:00 committed by GitHub
commit 8ac78a6d54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 0 deletions

View File

@ -172,10 +172,20 @@ class Engine:
self.model = model
self.tokenizer = tokenizer # needed for tool use
def _truncate_tokens_for_context(self, tokens):
max_context = self.model.config.sequence_len
if len(tokens) <= max_context:
return tokens
bos = self.tokenizer.get_bos_token_id()
if tokens[0] == bos and max_context > 0:
return [bos] + tokens[-(max_context - 1):]
return tokens[-max_context:]
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
tokens = self._truncate_tokens_for_context(tokens)
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
@ -285,6 +295,7 @@ class Engine:
Returns a list of token sequences (list of lists of ints).
Terminal tokens (assistant_end, bos) are not included in the results.
"""
tokens = self._truncate_tokens_for_context(tokens)
assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
bos = self.tokenizer.get_bos_token_id()
results = [tokens.copy() for _ in range(num_samples)]

View File

@ -47,6 +47,24 @@ class MockModel:
return logits
class StrictLengthModel(MockModel):
    """Mock model that fails if the prompt exceeds sequence_len."""

    def __init__(self, vocab_size=262, sequence_len=8):
        super().__init__(vocab_size=vocab_size)
        self.config = MockConfig(sequence_len=sequence_len)
        # Record what the engine feeds the model during prefill
        # (kv_cache present and still at position 0) so tests can inspect it.
        self.prefill_lengths = []
        self.prefill_tokens = []

    def forward(self, ids, kv_cache=None):
        _, seq = ids.shape
        limit = self.config.sequence_len
        if seq > limit:
            raise RuntimeError(f"sequence length exceeded: {seq} > {limit}")
        is_prefill = kv_cache is not None and kv_cache.get_pos() == 0
        if is_prefill:
            self.prefill_lengths.append(seq)
            self.prefill_tokens.append(ids[0].tolist())
        return super().forward(ids, kv_cache=kv_cache)
class ByteTokenizer:
"""
Simple byte-level tokenizer for testing.
@ -265,3 +283,16 @@ def test_different_seeds_introduce_variation_when_temperature_nonzero():
# Sanity check: sampling actually introduces variation
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
def test_generate_truncates_overlong_prompts_to_model_context_window():
    """Overlong prompts should be truncated to the latest context window before prefill."""
    strict_model = StrictLengthModel(sequence_len=8)
    engine = Engine(strict_model, ByteTokenizer())
    # BOS (261) followed by ten tokens: 11 total, exceeding the window of 8.
    overlong = [261] + list(range(1, 11))
    sequences, _ = engine.generate_batch(overlong, temperature=0.0, max_tokens=1)
    # The model must only ever see BOS plus the latest 7 tokens.
    expected_window = [261, 4, 5, 6, 7, 8, 9, 10]
    assert strict_model.prefill_lengths == [8]
    assert strict_model.prefill_tokens == [expected_window]
    assert sequences[0][:8] == expected_window