Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 12:22:18 +00:00)

Merge pull request #4 from Dianababaei/feat/kv-cached-generation-loop-o-t-optimization
refactor: Update GPT generate method and modify GPTConfig class parameters

Commit: 333919d764
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
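The new import pulls in the KV cache that the generation loop below relies on. For orientation only, a buffer compatible with the constructor call used later in this diff (batch_size, num_heads, seq_len, head_dim, num_layers) could look roughly like the sketch below; the class name, insert method, and storage layout here are assumptions, not the actual nanochat.engine.KVCache.

import torch

class SimpleKVCache:
    # Illustrative stand-in for nanochat.engine.KVCache: a preallocated per-layer K/V buffer.
    # (Assumption: the real class exposes a similar constructor but different internals.)
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers,
                 device="cpu", dtype=torch.float32):
        shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv = torch.zeros(shape, device=device, dtype=dtype)
        self.pos = [0] * num_layers  # how many positions are filled, per layer

    def insert(self, layer_idx, k, v):
        # k, v: (batch_size, num_heads, t_new, head_dim) for the tokens of this forward pass
        t0, t1 = self.pos[layer_idx], self.pos[layer_idx] + k.size(2)
        self.kv[layer_idx, 0, :, :, t0:t1] = k
        self.kv[layer_idx, 1, :, :, t0:t1] = v
        self.pos[layer_idx] = t1
        # attention at this layer should then use the full cached prefix [0, t1)
        return self.kv[layer_idx, 0, :, :, :t1], self.kv[layer_idx, 1, :, :, :t1]

During prefill the whole prompt is written in one call per layer; during decoding each step writes a single position and attends over everything cached so far.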
@@ -293,7 +294,7 @@ class GPT(nn.Module):
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
-        Naive autoregressive streaming inference.
+        Efficient autoregressive streaming inference with KV caching.
         To make it super simple, let's assume:
         - batch size is 1
         - ids and the yielded tokens are simple Python lists and ints
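Since generate remains a Python generator that yields plain ints, callers stream tokens as they are produced. A hypothetical caller might look like the following; the tokenizer and model objects are placeholders for illustration and are not part of this diff.

prompt_tokens = tokenizer.encode("Hello")  # assumed tokenizer, illustration only
for token in model.generate(prompt_tokens, max_tokens=64, temperature=0.8, top_k=50):
    print(tokenizer.decode([token]), end="", flush=True)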
@@ -304,16 +305,25 @@ class GPT(nn.Module):
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
 
+        # Initialize KV cache
+        m = self.config
+        kv_cache = KVCache(
+            batch_size=1,
+            num_heads=m.n_kv_head,
+            seq_len=len(tokens) + max_tokens,
+            head_dim=m.n_embd // m.n_head,
+            num_layers=m.n_layer,
+        )
+
+        # Prefill: forward pass on full prompt to populate KV cache and get initial logits
         ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
-        # Prefill phase: process entire prompt in a single forward pass
-        logits = self.forward(ids, kv_cache=kv_cache) # (B, T_prompt, vocab_size)
-        logits = logits[:, -1, :] # (B, vocab_size) - only need last token's logits
-
-        # Incremental decoding: generate tokens one at a time using cached K/V pairs
+        logits = self.forward(ids, kv_cache=kv_cache) # (B, T, vocab_size)
+        logits = logits[:, -1, :] # (B, vocab_size)
+
+        # Generation loop: process one token at a time
         for _ in range(max_tokens):
-            logits = self.forward(ids) # (B, T, vocab_size)
-            logits = logits[:, -1, :] # (B, vocab_size)
+            # Sample from existing logits (from prefill or previous iteration)
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
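The top-k filter at the end of this hunk is unchanged, but it is worth spelling out: v[:, [-1]] is the smallest of the k largest logits (shape (B, 1)), and everything below that threshold is pushed to -inf so softmax assigns it zero probability. A standalone example with illustrative values:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])     # (B=1, vocab_size=5)
top_k = 2
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))  # v = [[3.0, 2.0]]
logits[logits < v[:, [-1]]] = -float('Inf')              # threshold = 2.0
print(logits)                                            # tensor([[2., -inf, -inf, 3., -inf]])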
@@ -323,6 +333,9 @@ class GPT(nn.Module):
                 next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
             else:
                 next_ids = torch.argmax(logits, dim=-1, keepdim=True)
-            ids = torch.cat((ids, next_ids), dim=1)
             token = next_ids.item()
             yield token
+
+            # Forward pass on only the new token to get next logits
+            logits = self.forward(next_ids, kv_cache=kv_cache) # (B, 1, vocab_size)
+            logits = logits[:, -1, :] # (B, vocab_size)
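Taken together, the prompt is now encoded once (prefill) and each subsequent step forwards only the newly sampled token against the cached keys/values, instead of re-forwarding the whole growing sequence. A rough count of token positions processed, for a prompt of length P and N generated tokens (plain arithmetic, not part of the diff):

P, N = 128, 256
naive = sum(P + t for t in range(N))  # old loop: re-forward the full sequence every step
cached = P + N                        # new loop: prefill once, then one position per step
print(naive, cached)                  # 65408 vs 384 positions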