diff --git a/nanochat/engine.py b/nanochat/engine.py index d13e2b8..dc43faf 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -107,17 +107,23 @@ class KVCache: # 1) validate the shapes assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" - for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): - # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim - if ix in [0, 1, 3, 5]: - # num_layers, k/v, num_heads, head_dim must match - assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}" - elif ix == 2: - # batch_size can be expanded - assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}" - elif ix == 4: - # seq_len: self must be longer than other - assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}" + + # Extract dimensions explicitly + self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape + other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape + + # Validate dimensions + assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}" + assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}" + assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}" + assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}" + + # Batch size can be expanded (other can be 1, self can be larger) + assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)" + + # Sequence length: self must be longer than other + assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}" + # 2) initialize the cache dtype, device = other.kv_cache.dtype, other.kv_cache.device self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)