diff --git a/nanochat/engine.py b/nanochat/engine.py index 916a9cf..1d541c7 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -107,8 +107,9 @@ class KVCache: assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): + # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim if ix in [0, 1, 3, 5]: - # num_layers, batch_size, num_heads, head_dim must match + # num_layers, k/v, num_heads, head_dim must match assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}" elif ix == 2: # batch_size can be expanded