"""
Test Engine class. Example run:

python -m pytest tests/test_engine.py -v
"""

import torch
from nanochat.engine import KVCache


def test_kv_cache_resize():
    """
    Regression test: growing the KV cache must keep already-inserted tokens.

    The KV cache was not resized correctly, more information here:
    https://github.com/karpathy/nanochat/pull/186
    This test reproduces the issue and will be merged alongside the fix.

    The cache is filled to its initial capacity, snapshotted, pushed one token
    past capacity, and then every pre-existing key/value slot is compared with
    both its known fill value and the snapshot.
    """
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 3, 4, 5, 6

    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )

    # Shape of one token's worth of keys (or values) as passed to insert_kv.
    token_shape = (batch_size, num_heads, 1, head_dim)

    def insert_token(token_idx):
        """Write one token into every layer; keys hold idx, values idx * 100."""
        key_fill = float(token_idx)
        value_fill = float(token_idx * 100)
        for layer_idx in range(num_layers):
            kv_cache.insert_kv(
                layer_idx,
                torch.full(token_shape, fill_value=key_fill, dtype=torch.float32),
                torch.full(token_shape, fill_value=value_fill, dtype=torch.float32),
            )

    # Fill the cache exactly up to its initial capacity (seq_len tokens).
    for position in range(seq_len):
        insert_token(position)

    # Snapshot the pre-resize contents; dim 4 is indexed as the token position
    # below, so its size is taken as the current sequence capacity.
    snapshot = kv_cache.kv_cache.clone()
    original_seq_len = snapshot.shape[4]

    # One token past capacity: this insertion is the one that triggers a resize.
    insert_token(seq_len)

    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, (
        f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
    )

    # Every token written before the resize must have survived it untouched.
    for layer_idx in range(num_layers):
        # Index 0 selects keys and index 1 selects values on the second dim.
        resized_keys = kv_cache.kv_cache[layer_idx, 0]
        resized_values = kv_cache.kv_cache[layer_idx, 1]
        snapshot_keys = snapshot[layer_idx, 0]
        snapshot_values = snapshot[layer_idx, 1]

        for token_idx in range(seq_len):
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = resized_keys[:, :, token_idx, :]
            actual_v = resized_values[:, :, token_idx, :]

            # First check against the known fill values...
            assert (actual_k == expected_k).all(), (
                f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            )
            assert (actual_v == expected_v).all(), (
                f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            )

            # ...then against the snapshot taken before the resize.
            original_k = snapshot_keys[:, :, token_idx, :]
            original_v = snapshot_values[:, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"