This commit is contained in:
deepbuilder 2025-11-28 15:22:42 -05:00 committed by GitHub
commit 606314b05e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -107,17 +107,23 @@ class KVCache:
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
# ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
if ix in [0, 1, 3, 5]:
# num_layers, k/v, num_heads, head_dim must match
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
elif ix == 2:
# batch_size can be expanded
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
elif ix == 4:
# seq_len: self must be longer than other
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
@ -143,11 +149,11 @@ class KVCache:
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1