From a770dcef2eda7e62802cecd65e5037009132c4f9 Mon Sep 17 00:00:00 2001 From: deepbuilder <129704853+deepbuilder@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:00:14 -0500 Subject: [PATCH] Fix kv_cache indexing to explicitly include head dimension --- nanochat/engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index d749d94..d13e2b8 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -143,11 +143,11 @@ class KVCache: self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() self.kv_shape = self.kv_cache.shape # Insert k, v into the cache - self.kv_cache[layer_idx, 0, :, :, t0:t1] = k - self.kv_cache[layer_idx, 1, :, :, t0:t1] = v + self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k + self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v # Return the full cached keys/values up to current position (as a view) - key_view = self.kv_cache[layer_idx, 0, :, :, :t1] - value_view = self.kv_cache[layer_idx, 1, :, :, :t1] + key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :] + value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :] # Increment pos after the last layer of the Transformer processes if layer_idx == self.kv_cache.size(0) - 1: self.pos = t1