Merge branch 'haowei01-fix_kv_cache_due_to_resize'

Andrej Karpathy 2025-10-28 16:54:30 +00:00
commit fdda5826e3
2 changed files with 71 additions and 3 deletions

nanochat/engine.py

@@ -135,9 +135,11 @@ class KVCache:
         if t1 > self.kv_cache.size(4):
             t_needed = t1 + 1024 # as much as we need plus buffer of 1024
             t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
-            current_shape = list(self.kv_cache.shape)
-            current_shape[4] = t_needed
-            self.kv_cache.resize_(current_shape)
+            additional_shape = list(self.kv_cache.shape)
+            additional_shape[4] = t_needed - self.kv_cache.size(4)
+            additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
+            self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
+            self.kv_shape = self.kv_cache.shape
         # Insert k, v into the cache
         self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
         self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
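
Why the change: Tensor.resize_ only grows the underlying storage; it does not move existing elements to their new logical positions when a non-trailing dimension (here dim=4, the sequence length) grows, so previously cached keys/values get reinterpreted under the new strides and the grown region is uninitialized. Concatenating a fresh buffer along dim=4 keeps the old entries where they were. The snippet below is a minimal standalone sketch of that difference on a small 2-D tensor; it is an illustration, not part of this commit.

import torch

x = torch.arange(6.0).reshape(2, 3)  # rows are [0, 1, 2] and [3, 4, 5]
y = x.clone()

# In-place resize_ of a non-leading dimension: storage stays flat row-major,
# so old elements land at new (row, col) coordinates and the tail is uninitialized.
x.resize_(2, 4)
print(x[1, :3])  # tensor([4., 5., <uninitialized>]) instead of [3., 4., 5.]

# Concatenating an (uninitialized) buffer along the grown dimension keeps the
# old entries in place, which is what the fixed insert_kv does along dim=4.
extra = torch.empty(2, 1)
y = torch.cat([y, extra], dim=1)
assert torch.equal(y[:, :3], torch.arange(6.0).reshape(2, 3))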

tests/test_engine.py  (new file, 66 additions)

@@ -0,0 +1,66 @@
"""
Test Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""
import torch

from nanochat.engine import KVCache


def test_kv_cache_resize():
    """
    The KV cache was not resized correctly, more information here:
    https://github.com/karpathy/nanochat/pull/186
    This test reproduces the issue and will be merged alongside the fix.
    """
    batch_size = 2
    num_heads = 3
    seq_len = 4
    head_dim = 5
    num_layers = 6
    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers
    )

    # Insert a single token with a distinct fill value to all layers
    def insert_token(token_idx):
        for layer_idx in range(num_layers):
            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
            kv_cache.insert_kv(layer_idx, k, v)

    # Insert 4 tokens (fills the initial seq_len=4)
    for i in range(4):
        insert_token(i)

    # Record the original state of the cache
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # Insert the 5th token, which will trigger a resize
    insert_token(4)

    # Verify that the cache actually resized
    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"

    # Verify that the original 4 tokens are still intact after resize
    for layer_idx in range(num_layers):
        for token_idx in range(4):
            # Check that resized cache matches expected values
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            # And that the original cache matches resized cache
            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"