Merge pull request #3 from Dianababaei/feat/gpt-prefill-phase-kv-caching

Add KV cache initialization and a prefill phase to the GPT class for efficient generation
Dianababaei 2025-11-03 13:33:16 +03:30 committed by GitHub
commit 8927ec79c8


@@ -22,7 +22,6 @@ import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.engine import KVCache
@dataclass
class GPTConfig:
@@ -305,16 +304,13 @@ class GPT(nn.Module):
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
# Initialize KV cache for efficient generation
kv_length_hint = len(tokens) + max_tokens
kv_cache = KVCache(
batch_size=1,
num_heads=self.config.n_kv_head,
seq_len=kv_length_hint,
head_dim=self.config.n_embd // self.config.n_head,
num_layers=self.config.n_layer
)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
# Prefill phase: process entire prompt in a single forward pass
logits = self.forward(ids, kv_cache=kv_cache) # (B, T_prompt, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) - only need last token's logits
# Incremental decoding: generate tokens one at a time using cached K/V pairs
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
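
Note on the decode loop above: the last lines of the excerpt still run the full sequence through forward() without the cache, which appears to be the pre-existing per-step path. Below is a minimal sketch (not the committed diff) of how the loop typically consumes the KV cache after the prefill pass, assuming forward() accepts the kv_cache keyword shown in the diff and that generate() yields sampled token ids; the sampling details are illustrative.

# Sketch only: incremental decoding that reuses the KV cache from the prefill pass.
ids = torch.tensor([tokens], dtype=torch.long, device=device)       # (1, T_prompt)
logits = self.forward(ids, kv_cache=kv_cache)[:, -1, :]             # prefill, keep last position
for _ in range(max_tokens):
    if temperature > 0:
        probs = F.softmax(logits / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, generator=rng)  # (1, 1) sampled token
    else:
        next_id = logits.argmax(dim=-1, keepdim=True)                     # (1, 1) greedy token
    yield next_id.item()
    # Feed only the new token; the cached K/V pairs already cover the prefix.
    logits = self.forward(next_id, kv_cache=kv_cache)[:, -1, :]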