From a770dcef2eda7e62802cecd65e5037009132c4f9 Mon Sep 17 00:00:00 2001
From: deepbuilder <129704853+deepbuilder@users.noreply.github.com>
Date: Fri, 28 Nov 2025 15:00:14 -0500
Subject: [PATCH 1/2] Fix kv_cache indexing to explicitly include head dimension

---
 nanochat/engine.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index d749d94..d13e2b8 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -143,11 +143,11 @@ class KVCache:
             self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
             self.kv_shape = self.kv_cache.shape
         # Insert k, v into the cache
-        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
-        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
+        self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
+        self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
         # Return the full cached keys/values up to current position (as a view)
-        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
-        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
+        key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
+        value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
         # Increment pos after the last layer of the Transformer processes
         if layer_idx == self.kv_cache.size(0) - 1:
             self.pos = t1

From 06677c30e01624fbfc502d1a90bc8b4d5c1e442e Mon Sep 17 00:00:00 2001
From: deepbuilder <129704853+deepbuilder@users.noreply.github.com>
Date: Fri, 28 Nov 2025 15:22:18 -0500
Subject: [PATCH 2/2] Refactor dimension validation for KV cache

---
 nanochat/engine.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index d13e2b8..dc43faf 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -107,17 +107,23 @@ class KVCache:
         # 1) validate the shapes
         assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
         assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
-        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
-            # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
-            if ix in [0, 1, 3, 5]:
-                # num_layers, k/v, num_heads, head_dim must match
-                assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
-            elif ix == 2:
-                # batch_size can be expanded
-                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
-            elif ix == 4:
-                # seq_len: self must be longer than other
-                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
+
+        # Extract dimensions explicitly
+        self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
+        other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
+
+        # Validate dimensions
+        assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
+        assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
+        assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
+        assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
+
+        # Batch size can be expanded (other can be 1, self can be larger)
+        assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
+
+        # Sequence length: self must be longer than other
+        assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
+
         # 2) initialize the cache
         dtype, device = other.kv_cache.dtype, other.kv_cache.device
         self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
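
Note (not part of the patches): a minimal sketch of why the explicit trailing head-dimension index in PATCH 1/2 is behavior-preserving. It assumes the 6-D cache layout (num_layers, 2, batch_size, num_heads, seq_len, head_dim) implied by the diff; the concrete sizes below are made up for illustration.

    # Illustrative shapes only; the real cache is allocated in nanochat/engine.py.
    import torch

    num_layers, batch_size, num_heads, seq_len, head_dim = 2, 1, 4, 16, 8
    kv_cache = torch.zeros(num_layers, 2, batch_size, num_heads, seq_len, head_dim)

    k = torch.randn(batch_size, num_heads, 3, head_dim)  # 3 new tokens
    t0, t1 = 0, 3

    # Omitting the trailing index (old code) and writing it explicitly (new code)
    # touch exactly the same elements and return identical views.
    kv_cache[0, 0, :, :, t0:t1] = k
    old_view = kv_cache[0, 0, :, :, :t1]
    kv_cache[0, 0, :, :, t0:t1, :] = k
    new_view = kv_cache[0, 0, :, :, :t1, :]
    assert torch.equal(old_view, new_view)

The change is therefore purely for readability: the extra ':' makes the full 6-D indexing visible at the call site, matching the per-dimension validation introduced in PATCH 2/2.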