diff --git a/tests/test_engine.py b/tests/test_engine.py
new file mode 100644
index 0000000..7403b36
--- /dev/null
+++ b/tests/test_engine.py
@@ -0,0 +1,66 @@
+"""
+Test the Engine class. Example run:
+
+python -m pytest tests/test_engine.py -v
+"""
+
+import torch
+from nanochat.engine import KVCache
+
+def test_kv_cache_resize():
+    """
+    The KV cache was not resized correctly; more information here:
+    https://github.com/karpathy/nanochat/pull/186
+    This test reproduces the issue and will be merged alongside the fix.
+    """
+
+    batch_size = 2
+    num_heads = 3
+    seq_len = 4
+    head_dim = 5
+    num_layers = 6
+
+    kv_cache = KVCache(
+        batch_size=batch_size,
+        num_heads=num_heads,
+        seq_len=seq_len,
+        head_dim=head_dim,
+        num_layers=num_layers
+    )
+
+    # Insert a single token with a distinct fill value into all layers
+    def insert_token(token_idx):
+        for layer_idx in range(num_layers):
+            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
+            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
+            kv_cache.insert_kv(layer_idx, k, v)
+
+    # Insert 4 tokens (fills the initial seq_len=4)
+    for i in range(4):
+        insert_token(i)
+
+    # Record the original state of the cache
+    original_cache = kv_cache.kv_cache.clone()
+    original_seq_len = original_cache.shape[4]
+
+    # Insert a 5th token, which must trigger a resize
+    insert_token(4)
+    # Verify that the cache actually grew along the sequence dimension
+    new_seq_len = kv_cache.kv_cache.shape[4]
+    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
+
+    # Verify that the original 4 tokens are still intact after the resize
+    for layer_idx in range(num_layers):
+        for token_idx in range(4):
+            # Check that the resized cache holds the expected fill values
+            expected_k = float(token_idx)
+            expected_v = float(token_idx * 100)
+            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
+            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
+            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
+            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
+            # And that the resized cache matches the original cache
+            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
+            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
+            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
+            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
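
For reviewers who want to run the test without the rest of the repo, here is a minimal stand-in for KVCache, inferred purely from the indexing the test uses; it is a sketch under assumptions, not nanochat's actual implementation. The assumed layout is a single tensor of shape (num_layers, 2, batch_size, num_heads, seq_len, head_dim), where dim 1 selects keys (0) or values (1); the internal write position, the doubling growth policy, and the advance-after-last-layer rule are all assumptions made so the test's call pattern works.

import torch

class KVCache:
    """Toy KV cache (assumed layout): one tensor shaped
    (num_layers, 2, batch_size, num_heads, seq_len, head_dim);
    dim 1 indexes keys (0) and values (1)."""

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.kv_cache = torch.zeros(num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.pos = 0  # next free slot along the sequence dimension

    def insert_kv(self, layer_idx, k, v):
        t = k.size(2)  # number of tokens being written
        needed = self.pos + t
        current = self.kv_cache.size(4)
        if needed > current:
            # Grow-and-copy: allocate a longer cache and carry the old
            # contents over -- the invariant test_kv_cache_resize asserts.
            new_len = max(needed, 2 * current)
            shape = list(self.kv_cache.shape)
            shape[4] = new_len
            grown = torch.zeros(shape, dtype=self.kv_cache.dtype, device=self.kv_cache.device)
            grown[..., :current, :] = self.kv_cache
            self.kv_cache = grown
        self.kv_cache[layer_idx, 0, :, :, self.pos:needed, :] = k
        self.kv_cache[layer_idx, 1, :, :, self.pos:needed, :] = v
        if layer_idx == self.kv_cache.size(0) - 1:
            # Advance only after every layer has written this token,
            # since the test inserts the same token into all layers in turn
            self.pos = needed

With this stand-in importable as nanochat.engine, the test passes: the 5th insert grows the sequence dimension from 4 to 8, and the per-token assertions confirm the copied prefix survived. A resize that drops or misaligns the old contents is exactly what those assertions are designed to trip.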