diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 2280de6..11ecbd3 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
@@ -305,7 +306,22 @@ class GPT(nn.Module):
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
+        # Initialize KV cache for efficient generation
+        kv_length_hint = len(tokens) + max_tokens
+        kv_cache = KVCache(
+            batch_size=1,
+            num_heads=self.config.n_kv_head,
+            seq_len=kv_length_hint,
+            head_dim=self.config.n_embd // self.config.n_head,
+            num_layers=self.config.n_layer
+        )
         ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
+
+        # Prefill phase: process entire prompt in a single forward pass
+        logits = self.forward(ids, kv_cache=kv_cache) # (B, T_prompt, vocab_size)
+        logits = logits[:, -1, :] # (B, vocab_size) - only need last token's logits
+
+        # Incremental decoding: generate tokens one at a time using cached K/V pairs
         for _ in range(max_tokens):
             logits = self.forward(ids) # (B, T, vocab_size)
             logits = logits[:, -1, :] # (B, vocab_size)
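
Note that in the visible part of the hunk, the first decode step still calls `self.forward(ids)` on the full sequence without passing the cache, so the K/V pairs written during prefill are not yet exploited there. For the cache to pay off, each decode step would sample from the current logits and then forward only the newly sampled token. Below is a minimal sketch of that pattern; it assumes `forward` accepts a `kv_cache` keyword and appends each new position's K/V pairs to it (as the prefill line above implies), `decode_with_kv_cache` is a hypothetical helper rather than code from this patch, and sampling is simplified to plain temperature sampling with an optional top-k left out.

```python
import torch

def decode_with_kv_cache(model, ids, kv_cache, max_tokens, temperature=1.0, rng=None):
    # Hypothetical helper: prefill once, then decode one token per step.
    # Assumes model.forward(ids, kv_cache=...) returns (B, T, vocab_size) logits
    # and appends the new positions' K/V pairs to kv_cache.
    logits = model.forward(ids, kv_cache=kv_cache)[:, -1, :]  # (B, vocab_size)
    for _ in range(max_tokens):
        # Sample (or greedily pick) the next token from the current logits.
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1, generator=rng)  # (B, 1)
        else:
            next_id = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
        ids = torch.cat([ids, next_id], dim=1)
        # Incremental step: only the newly sampled token is forwarded; the cache
        # already holds K/V for every earlier position.
        logits = model.forward(next_id, kv_cache=kv_cache)[:, -1, :]
    return ids
```

Forwarding a single token per step keeps the per-step attention input constant in size, which is the usual motivation for splitting generation into a prefill pass plus incremental decoding.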