Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 12:22:18 +00:00

feat(gpt): implement prefill phase for efficient prompt processing with KV-caching
Add a prefill phase that processes the entire prompt in a single forward pass before the generation loop, extracting logits only for the last token position and populating the KV cache. The mechanics are sketched below, after the commit metadata.
parent d0383978df
commit 1131c37a62
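
Why prefill helps, in brief: without a KV cache, every generated token re-runs the forward pass over the entire sequence, recomputing keys and values for all earlier positions; with a cache, the prompt's keys/values are computed once during prefill and each decode step only processes the single newest token. The following is a minimal sketch of that mechanic, not nanochat's actual attention code: a single head, no batching, toy random projection weights, and random tensors standing in for token embeddings.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_dim, prompt_len, max_new = 16, 8, 4
wq = torch.randn(head_dim, head_dim) * 0.1  # toy projection weights (illustrative)
wk = torch.randn(head_dim, head_dim) * 0.1
wv = torch.randn(head_dim, head_dim) * 0.1

# Preallocate the cache for prompt + generated tokens (cf. kv_length_hint in the diff).
k_cache = torch.zeros(prompt_len + max_new, head_dim)
v_cache = torch.zeros(prompt_len + max_new, head_dim)

# Prefill: one pass over the whole prompt populates the cache.
x = torch.randn(prompt_len, head_dim)       # stand-in for prompt embeddings
k_cache[:prompt_len] = x @ wk
v_cache[:prompt_len] = x @ wv

# Decode: each step projects and attends ONLY the newest token.
pos = prompt_len
for _ in range(max_new):
    x_t = torch.randn(head_dim)             # stand-in for the newly sampled token's embedding
    k_cache[pos] = x_t @ wk                 # append this token's K/V to the cache
    v_cache[pos] = x_t @ wv
    pos += 1
    q = x_t @ wq                            # one query vector, not prompt_len of them
    att = F.softmax(q @ k_cache[:pos].T / head_dim**0.5, dim=-1)
    out = att @ v_cache[:pos]               # attention output for the new token only
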
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
@@ -305,16 +304,13 @@ class GPT(nn.Module):
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
+        # Initialize KV cache for efficient generation
+        kv_length_hint = len(tokens) + max_tokens
+        kv_cache = KVCache(
+            batch_size=1,
+            num_heads=self.config.n_kv_head,
+            seq_len=kv_length_hint,
+            head_dim=self.config.n_embd // self.config.n_head,
+            num_layers=self.config.n_layer
+        )
         ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
+
+        # Prefill phase: process entire prompt in a single forward pass
+        logits = self.forward(ids, kv_cache=kv_cache) # (B, T_prompt, vocab_size)
+        logits = logits[:, -1, :] # (B, vocab_size) - only need last token's logits
+
+        # Incremental decoding: generate tokens one at a time using cached K/V pairs
         for _ in range(max_tokens):
             logits = self.forward(ids) # (B, T, vocab_size)
             logits = logits[:, -1, :] # (B, vocab_size)
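
Put together, the generate path this commit sets up looks roughly like the following. This is a hedged sketch, not nanochat's implementation: the KVCache constructor arguments and the forward(ids, kv_cache=...) prefill call mirror the diff above, while generate_with_prefill, the sampling details, and the single-token decode call are assumptions about code outside the visible hunk.

import torch
from nanochat.engine import KVCache

@torch.inference_mode()
def generate_with_prefill(model, tokens, max_tokens, temperature=1.0, seed=42):
    device = next(model.parameters()).device
    rng = None
    if temperature > 0:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
    # Size the cache once for prompt + all generated tokens so it never reallocates.
    kv_cache = KVCache(
        batch_size=1,
        num_heads=model.config.n_kv_head,
        seq_len=len(tokens) + max_tokens,
        head_dim=model.config.n_embd // model.config.n_head,
        num_layers=model.config.n_layer,
    )
    ids = torch.tensor([tokens], dtype=torch.long, device=device)
    # Prefill: one forward pass over the whole prompt fills every layer's cache;
    # only the last position's logits are needed to sample the first new token.
    logits = model.forward(ids, kv_cache=kv_cache)[:, -1, :]  # (1, vocab_size)
    out = []
    for _ in range(max_tokens):
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_id = logits.argmax(dim=-1, keepdim=True)
        out.append(next_id.item())
        # Decode step (assumed): feed only the new token; cached K/V cover the prefix.
        logits = model.forward(next_id, kv_cache=kv_cache)[:, -1, :]
    return out

One design note: sizing the cache with kv_length_hint = len(tokens) + max_tokens up front avoids any mid-generation reallocation, at the cost of reserving memory for the worst-case sequence length.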