Refactor dimension validation for KV cache

deepbuilder 2025-11-28 15:22:18 -05:00 committed by GitHub
parent a770dcef2e
commit 06677c30e0


@@ -107,17 +107,23 @@ class KVCache:
         # 1) validate the shapes
         assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
         assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
-        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
-            # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
-            if ix in [0, 1, 3, 5]:
-                # num_layers, k/v, num_heads, head_dim must match
-                assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
-            elif ix == 2:
-                # batch_size can be expanded
-                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
-            elif ix == 4:
-                # seq_len: self must be longer than other
-                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
+
+        # Extract dimensions explicitly
+        self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
+        other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
+
+        # Validate dimensions
+        assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
+        assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
+        assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
+        assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
+
+        # Batch size can be expanded (other can be 1, self can be larger)
+        assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
+
+        # Sequence length: self must be at least as long as other
+        assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
+
         # 2) initialize the cache
         dtype, device = other.kv_cache.dtype, other.kv_cache.device
         self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
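
For context, the rewritten checks can be exercised on their own. The sketch below is illustrative only: the stand-alone validate_prefill_shapes helper and the example shapes are hypothetical and not part of this commit; the 6-tuple layout (num_layers, k/v, batch_size, num_heads, seq_len, head_dim) follows the comment in the removed code.

def validate_prefill_shapes(self_shape, other_shape):
    # Unpack both 6-tuples: (num_layers, k/v, batch_size, num_heads, seq_len, head_dim)
    self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self_shape
    other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other_shape
    # Structural dimensions must match exactly
    assert self_layers == other_layers and self_kv == other_kv
    assert self_heads == other_heads and self_head_dim == other_head_dim
    # Batch broadcasts: a batch-1 prefill can seed a larger batch
    assert self_batch == other_batch or other_batch == 1
    # The target cache must have room for the prefilled tokens
    assert self_seq >= other_seq

# Example: a batch-1 prefill of 8 tokens seeding a batch-4 cache with room for 128 tokens
validate_prefill_shapes((12, 2, 4, 8, 128, 64), (12, 2, 1, 8, 8, 64))

Relative to the removed index-based loop, unpacking into named variables trades generality over the tuple layout for assertions that name the offending dimension in their error messages.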