diff --git a/nanochat/engine.py b/nanochat/engine.py index d749d94..d13e2b8 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -143,11 +143,11 @@ class KVCache: self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() self.kv_shape = self.kv_cache.shape # Insert k, v into the cache - self.kv_cache[layer_idx, 0, :, :, t0:t1] = k - self.kv_cache[layer_idx, 1, :, :, t0:t1] = v + self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k + self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v # Return the full cached keys/values up to current position (as a view) - key_view = self.kv_cache[layer_idx, 0, :, :, :t1] - value_view = self.kv_cache[layer_idx, 1, :, :, :t1] + key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :] + value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :] # Increment pos after the last layer of the Transformer processes if layer_idx == self.kv_cache.size(0) - 1: self.pos = t1