diff --git a/nanochat/engine.py b/nanochat/engine.py index f9c5d9a..53ced30 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -106,7 +106,12 @@ class KVCache: def insert_kv(self, layer_idx, k, v): # Lazy initialize the cache here because we need to know the dtype/device if self.kv_cache is None: + # Pre-allocate a larger cache to avoid frequent resizing + self.kv_shape = list(self.kv_shape) + self.kv_shape[4] *= 2 # Double the sequence length for pre-allocation self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device) + self.kv_shape[4] //= 2 # a bit of a hack to restore the original shape for future checks + # Insert new keys/values to the cache and return the full cache so far B, H, T_add, D = k.size() t0, t1 = self.pos, self.pos + T_add