diff --git a/nanochat/engine.py b/nanochat/engine.py index 4724c8f..e5e33d0 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -115,7 +115,11 @@ class KVCache: def advance(self, num_tokens): """Advance the cache position by num_tokens.""" - self.cache_seqlens += num_tokens + # Validate that we don't exceed max sequence length + new_seqlens = self.cache_seqlens + num_tokens + if torch.any(new_seqlens > self.max_seq_len): + raise ValueError(f"Cache overflow: attempted to advance beyond max_seq_len={self.max_seq_len}") + self.cache_seqlens.copy_(new_seqlens) def prefill(self, other): """