From b78bc3fd9f87a152cf64a1b6522e11c1dfe75599 Mon Sep 17 00:00:00 2001
From: Artemis Git Integration
Date: Mon, 3 Nov 2025 10:04:43 +0000
Subject: [PATCH] =?UTF-8?q?perf:=20optimize=20generation=20loop=20from=20O?=
 =?UTF-8?q?(T=C2=B2)=20to=20O(T)=20using=20KV-cache?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refactor generate() to process one token per iteration instead of
reprocessing the entire sequence. Reorder the loop to sample → yield →
forward with a single-token input, enabling the fast path in attention
(a single query token attends over the cached keys/values, so no causal
mask is required).
---
 nanochat/gpt.py | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 11ecbd3..1b8ea93 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
@@ -293,7 +294,7 @@ class GPT(nn.Module):
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
-        Naive autoregressive streaming inference.
+        Efficient autoregressive streaming inference with KV caching.
         To make it super simple, let's assume:
         - batch size is 1
         - ids and the yielded tokens are simple Python lists and ints
@@ -304,10 +305,25 @@
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
+
+        # Initialize KV cache
+        m = self.config
+        kv_cache = KVCache(
+            batch_size=1,
+            num_heads=m.n_kv_head,
+            seq_len=len(tokens) + max_tokens,
+            head_dim=m.n_embd // m.n_head,
+            num_layers=m.n_layer,
+        )
+
+        # Prefill: forward pass on full prompt to populate KV cache and get initial logits
         ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
+        logits = self.forward(ids, kv_cache=kv_cache) # (B, T, vocab_size)
+        logits = logits[:, -1, :] # (B, vocab_size)
+
+        # Generation loop: process one token at a time
         for _ in range(max_tokens):
-            logits = self.forward(ids) # (B, T, vocab_size)
-            logits = logits[:, -1, :] # (B, vocab_size)
+            # Sample from existing logits (from prefill or previous iteration)
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
@@ -317,6 +333,9 @@
                 next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
             else:
                 next_ids = torch.argmax(logits, dim=-1, keepdim=True)
-            ids = torch.cat((ids, next_ids), dim=1)
             token = next_ids.item()
             yield token
+
+            # Forward pass on only the new token to get next logits
+            logits = self.forward(next_ids, kv_cache=kv_cache) # (B, 1, vocab_size)
+            logits = logits[:, -1, :] # (B, vocab_size)
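
Note for reviewers: the diff imports KVCache from nanochat.engine, but only
its constructor signature is visible above. The following is a minimal sketch
of what such a per-layer cache could look like; the insert_kv method name, the
lazy buffer allocation, and the shared write position are assumptions for
illustration, not nanochat's actual implementation:

    import torch

    class KVCache:
        """Sketch of a per-layer key/value cache (interface assumed, see note)."""

        def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
            # (num_layers, 2, B, H, T_max, D): index 0 holds keys, 1 holds values
            self.shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
            self.kv = None  # allocated lazily so dtype/device match the model
            self.pos = 0    # number of tokens cached so far (shared by all layers)

        def insert_kv(self, layer_idx, k, v):
            # k, v: (B, H, T_new, D) projections for the tokens in this forward pass
            if self.kv is None:
                self.kv = torch.empty(self.shape, dtype=k.dtype, device=k.device)
            end = self.pos + k.size(2)
            self.kv[layer_idx, 0, :, :, self.pos:end] = k
            self.kv[layer_idx, 1, :, :, self.pos:end] = v
            if layer_idx == self.shape[0] - 1:  # advance once the last layer wrote
                self.pos = end
            # attention runs over everything cached so far, including the new tokens
            return self.kv[layer_idx, 0, :, :, :end], self.kv[layer_idx, 1, :, :, :end]

Advancing pos only after the last layer has written keeps a single shared
offset valid for every layer within one forward pass, whether that pass covers
the whole prompt (prefill) or a single decoded token.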
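
A hypothetical caller for the streaming path after this patch; tokenizer and
model stand in for whatever encoding and checkpoint-loading utilities the
surrounding project provides, and only model.generate(...) comes from the
diff above:

    # Illustrative only: `model` is a GPT instance, `tokenizer` any codec
    # with encode/decode; neither is constructed by this patch.
    prompt = tokenizer.encode("Once upon a time")  # list of ints
    for token in model.generate(prompt, max_tokens=64, temperature=0.8, top_k=50):
        print(tokenizer.decode([token]), end="", flush=True)  # stream as generated

Because sampling now happens before the next forward pass, every in-loop
forward takes a single token: the logits needed at each step are always left
over from the previous pass (or from prefill).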