mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Fix kv_cache indexing to explicitly include head dimension
This commit is contained in:
parent
4a87a0d19f
commit
a770dcef2e
|
|
@ -143,11 +143,11 @@ class KVCache:
|
||||||
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||||
self.kv_shape = self.kv_cache.shape
|
self.kv_shape = self.kv_cache.shape
|
||||||
# Insert k, v into the cache
|
# Insert k, v into the cache
|
||||||
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
|
||||||
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
|
||||||
# Return the full cached keys/values up to current position (as a view)
|
# Return the full cached keys/values up to current position (as a view)
|
||||||
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
|
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
|
||||||
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
|
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
|
||||||
# Increment pos after the last layer of the Transformer processes
|
# Increment pos after the last layer of the Transformer processes
|
||||||
if layer_idx == self.kv_cache.size(0) - 1:
|
if layer_idx == self.kv_cache.size(0) - 1:
|
||||||
self.pos = t1
|
self.pos = t1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user