Fix kv_cache indexing to explicitly include head dimension

deepbuilder 2025-11-28 15:00:14 -05:00 committed by GitHub
parent 4a87a0d19f
commit a770dcef2e

@@ -143,11 +143,11 @@ class KVCache:
             self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
             self.kv_shape = self.kv_cache.shape
         # Insert k, v into the cache
-        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
-        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
+        self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
+        self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
         # Return the full cached keys/values up to current position (as a view)
-        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
-        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
+        key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
+        value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
         # Increment pos after the last layer of the Transformer processes
         if layer_idx == self.kv_cache.size(0) - 1:
             self.pos = t1
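
Note: the trailing ", :" is behaviorally a no-op here. Assigning a (batch, heads, time, head_dim) tensor into the 4-D slice writes the same values either way, so the change is purely for readability, spelling out the full 6-D cache layout at each call site. A minimal sketch of that equivalence, with hypothetical sizes, assuming the (num_layers, 2, batch, num_heads, seq_len, head_dim) layout implied by the dim=4 concatenation above:

    import torch

    # Hypothetical sizes; the layout is assumed from the dim=4 cat above.
    num_layers, batch, num_heads, seq_len, head_dim = 2, 1, 4, 8, 16
    cache_a = torch.zeros(num_layers, 2, batch, num_heads, seq_len, head_dim)
    cache_b = cache_a.clone()

    k = torch.randn(batch, num_heads, 3, head_dim)  # three new timesteps
    layer_idx, t0, t1 = 0, 0, 3

    cache_a[layer_idx, 0, :, :, t0:t1] = k      # implicit head dimension
    cache_b[layer_idx, 0, :, :, t0:t1, :] = k   # explicit head dimension

    assert torch.equal(cache_a, cache_b)        # identical writes either way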