From b062b422ac6c4578b986eae487645aef42eb9152 Mon Sep 17 00:00:00 2001 From: Haowei Zhang Date: Mon, 27 Oct 2025 02:23:08 -0700 Subject: [PATCH 1/4] Fix kv cache, given resize will destroys the logical structure --- nanochat/engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index fee06a1..307590b 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -135,9 +135,10 @@ class KVCache: if t1 > self.kv_cache.size(4): t_needed = t1 + 1024 # as much as we need plus buffer of 1024 t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 - current_shape = list(self.kv_cache.shape) - current_shape[4] = t_needed - self.kv_cache.resize_(current_shape) + additional_shape = list(self.kv_cache.shape) + additional_shape[4] = t_needed - self.kv_cache.size(4) + additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) + self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() # Insert k, v into the cache self.kv_cache[layer_idx, 0, :, :, t0:t1] = k self.kv_cache[layer_idx, 1, :, :, t0:t1] = v From 2b9c0855592dbd128e5a44109a95ba261c1e0201 Mon Sep 17 00:00:00 2001 From: Haowei Zhang Date: Mon, 27 Oct 2025 02:47:13 -0700 Subject: [PATCH 2/4] update the kv_shape --- nanochat/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nanochat/engine.py b/nanochat/engine.py index 307590b..44ed16b 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -139,6 +139,7 @@ class KVCache: additional_shape[4] = t_needed - self.kv_cache.size(4) additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous() + self.kv_shape = self.kv_cache.shape # Insert k, v into the cache self.kv_cache[layer_idx, 0, :, :, t0:t1] = k self.kv_cache[layer_idx, 1, :, :, t0:t1] = v From f1db6b47127f90b7b2cd67e136ddfe5ae7e889ac Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 28 Oct 2025 15:17:43 +0000 Subject: [PATCH 3/4] delete czar call for help, i'm working through the inbound on that now. add current LLM policy which just asks for disclosure atm --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 94b982b..4f09c3b 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ python -m pytest tests/test_rustbpe.py -v -s nanochat is nowhere finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card. -I am looking for someone to be the "nanochat repo czar" to help me manage the nanochat repo and its issues and PRs and be the first round of defense. Examples of work include merging simple fixes (docs, typos, clear and simple bugs etc.), rejecting vibe coded PRs, managing the Issues/PRs, doing brief "sanity check testing" of PRs on the two officially supported platforms (Linux/GPU and Macbook), organizing information into brief updates and highlights for me. We'd be in touch on DMs on Discord or X or whatever is easiest. For your services to the repo you will be listed and linked to under acknowledgements as the nanochat repo czar. Position is at-will so you can contribute for a while and then "resign" at any time later, totally ok and thank you for your help, just me know. Apply via DM to me on X, thank you! +Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand. ## Acknowledgements From baf0b3fddaa8c8d9bafd1f3dda449eb65cb98976 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 28 Oct 2025 16:54:17 +0000 Subject: [PATCH 4/4] also add a test that failed before the fix and passes now with the fix for kv cache resize --- tests/test_engine.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/test_engine.py diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..7403b36 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,66 @@ +""" +Test Engine class. Example run: + +python -m pytest tests/test_engine.py -v +""" + +import torch +from nanochat.engine import KVCache + +def test_kv_cache_resize(): + """ + The KV cache was not resized correctly, more information here: + https://github.com/karpathy/nanochat/pull/186 + This test reproduces the issue and will be merged alongside the fix. + """ + + batch_size = 2 + num_heads = 3 + seq_len = 4 + head_dim = 5 + num_layers = 6 + + kv_cache = KVCache( + batch_size=batch_size, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + num_layers=num_layers + ) + + # Insert a single token with a distinct fill value to all layers + def insert_token(token_idx): + for layer_idx in range(num_layers): + k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32) + v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32) + kv_cache.insert_kv(layer_idx, k, v) + + # Insert 4 tokens (fills the initial seq_len=4) + for i in range(4): + insert_token(i) + + # Record the original state of the cache + original_cache = kv_cache.kv_cache.clone() + original_seq_len = original_cache.shape[4] + + # Insert the 5th token, which will trigger a resize + insert_token(4) + # Verify that the cache actually resized + new_seq_len = kv_cache.kv_cache.shape[4] + assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}" + + # Verify that the original 4 tokens are still intact after resize + for layer_idx in range(num_layers): + for token_idx in range(4): + # Check that resized cache matches expected values + expected_k = float(token_idx) + expected_v = float(token_idx * 100) + actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :] + actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :] + assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" + assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" + # And that the original cache matches resized cache + original_k = original_cache[layer_idx, 0, :, :, token_idx, :] + original_v = original_cache[layer_idx, 1, :, :, token_idx, :] + assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original" + assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"