mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
also add a test for the kv cache resize fix: it failed before the fix and passes now with it
parent f1db6b4712
commit baf0b3fdda
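For context on the bug: the test below fails if a resize allocates a larger buffer without carrying the existing entries over. As a minimal sketch (assuming the (num_layers, 2, batch, heads, seq, head_dim) layout that the test's indexing implies; grow_kv_cache is a hypothetical helper, not nanochat's actual implementation), a correct grow-and-copy looks like:

    import torch

    def grow_kv_cache(cache: torch.Tensor, new_seq_len: int) -> torch.Tensor:
        # Grow along the sequence dimension (dim 4) of a cache shaped
        # (num_layers, 2, batch, heads, seq, head_dim), preserving old entries.
        old_seq_len = cache.shape[4]
        assert new_seq_len > old_seq_len
        new_shape = list(cache.shape)
        new_shape[4] = new_seq_len
        grown = torch.zeros(new_shape, dtype=cache.dtype, device=cache.device)
        # Dropping this copy is exactly the corruption the test catches.
        grown[..., :old_seq_len, :] = cache
        return grown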
66  tests/test_engine.py  (new file)

@@ -0,0 +1,66 @@
"""
Test Engine class. Example run:

python -m pytest tests/test_engine.py -v
"""

import torch
from nanochat.engine import KVCache


def test_kv_cache_resize():
    """
    The KV cache was not resized correctly; more information here:
    https://github.com/karpathy/nanochat/pull/186
    This test reproduces the issue and will be merged alongside the fix.
    """
    batch_size = 2
    num_heads = 3
    seq_len = 4
    head_dim = 5
    num_layers = 6

    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )

    # Insert a single token with a distinct fill value into all layers
    def insert_token(token_idx):
        for layer_idx in range(num_layers):
            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
            kv_cache.insert_kv(layer_idx, k, v)

    # Insert 4 tokens (fills the initial seq_len=4)
    for i in range(4):
        insert_token(i)

    # Record the original state of the cache
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # Insert the 5th token, which will trigger a resize
    insert_token(4)

    # Verify that the cache actually resized
    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"

    # Verify that the original 4 tokens are still intact after the resize
    for layer_idx in range(num_layers):
        for token_idx in range(4):
            # Check that the resized cache matches the expected values
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            # And that the original cache matches the resized cache
            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
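To run only this test rather than the whole file, pytest's standard node-id selection works:

    python -m pytest tests/test_engine.py::test_kv_cache_resize -v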