diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 82f13b6..2280de6 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -305,6 +305,15 @@ class GPT(nn.Module): if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) + # Initialize KV cache for efficient generation + kv_length_hint = len(tokens) + max_tokens + kv_cache = KVCache( + batch_size=1, + num_heads=self.config.n_kv_head, + seq_len=kv_length_hint, + head_dim=self.config.n_embd // self.config.n_head, + num_layers=self.config.n_layer + ) ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim for _ in range(max_tokens): logits = self.forward(ids) # (B, T, vocab_size)