Merge pull request #6 from Dianababaei/docs/update-generate-docstring-kv-cache-optimization

Update GPT class generate method signature in nanochat/gpt.py
Authored by Dianababaei on 2025-11-03 16:07:15 +03:30; committed by GitHub
commit 878d8bbdfa

nanochat/gpt.py

@@ -22,7 +22,6 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
-from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
@@ -294,7 +293,7 @@ class GPT(nn.Module):
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
-        Efficient autoregressive streaming inference with KV caching.
+        Efficient autoregressive streaming inference with KV cache.
         To make it super simple, let's assume:
         - batch size is 1
         - ids and the yielded tokens are simple Python lists and ints
@@ -305,25 +304,10 @@
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
-        # Initialize KV cache
-        m = self.config
-        kv_cache = KVCache(
-            batch_size=1,
-            num_heads=m.n_kv_head,
-            seq_len=len(tokens) + max_tokens,
-            head_dim=m.n_embd // m.n_head,
-            num_layers=m.n_layer,
-        )
-        # Prefill: forward pass on full prompt to populate KV cache and get initial logits
         ids = torch.tensor([tokens], dtype=torch.long, device=device)  # add batch dim
-        logits = self.forward(ids, kv_cache=kv_cache)  # (B, T, vocab_size)
-        logits = logits[:, -1, :]  # (B, vocab_size)
-        # Generation loop: process one token at a time
         for _ in range(max_tokens):
-            # Sample from existing logits (from prefill or previous iteration)
+            logits = self.forward(ids)  # (B, T, vocab_size)
+            logits = logits[:, -1, :]  # (B, vocab_size)
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
@@ -333,9 +317,6 @@
                 next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
             else:
                 next_ids = torch.argmax(logits, dim=-1, keepdim=True)
+            ids = torch.cat((ids, next_ids), dim=1)
             token = next_ids.item()
             yield token
-            # Forward pass on only the new token to get next logits
-            logits = self.forward(next_ids, kv_cache=kv_cache)  # (B, 1, vocab_size)
-            logits = logits[:, -1, :]  # (B, vocab_size)
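
Taken together, the hunks above leave GPT.generate as a plain autoregressive loop: the full token sequence is re-forwarded through the model at every step, and each sampled token is appended to the input for the next iteration. The sketch below assembles that post-commit behaviour from the context and "+" lines of the diff; the standalone function form, the model/device parameters, and the temperature-scaled softmax line are illustrative assumptions, since those exact lines are not shown in the hunks.

import torch
import torch.nn.functional as F

@torch.inference_mode()
def generate_naive(model, tokens, max_tokens, temperature=1.0, top_k=None, seed=42, device="cuda"):
    # Mirrors GPT.generate after this commit: no KV cache, full recompute per step.
    rng = None
    if temperature > 0:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
    ids = torch.tensor([tokens], dtype=torch.long, device=device)  # add batch dim
    for _ in range(max_tokens):
        logits = model.forward(ids)   # (B, T, vocab_size): re-runs the whole sequence
        logits = logits[:, -1, :]     # (B, vocab_size): keep only the last position
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)  # assumed sampling block, not shown in the diff
            next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_ids = torch.argmax(logits, dim=-1, keepdim=True)
        ids = torch.cat((ids, next_ids), dim=1)  # grow the context; O(T^2) total compute without a cache
        token = next_ids.item()
        yield token

A call site consumes this as a stream, e.g. for tok in generate_naive(model, prompt_ids, 64): print(tok).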
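For contrast, the deleted path followed a prefill-then-decode pattern using nanochat's KVCache: one forward over the full prompt to populate the cache, then one forward per generated token. The sketch below reconstructs that flow from the removed "-" lines; it targets the pre-commit file (the KVCache import and the kv_cache= argument to forward no longer exist in gpt.py after this change), and the function name, model/device parameters, and sampling block are the same assumptions as above.

import torch
import torch.nn.functional as F
from nanochat.engine import KVCache  # import removed from gpt.py by this commit

@torch.inference_mode()
def generate_with_kv_cache(model, tokens, max_tokens, temperature=1.0, top_k=None, seed=42, device="cuda"):
    # Reconstruction of the deleted KV-cache path: prefill once, then decode token by token.
    rng = None
    if temperature > 0:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
    m = model.config
    kv_cache = KVCache(                       # constructor arguments taken verbatim from the deleted lines
        batch_size=1,
        num_heads=m.n_kv_head,
        seq_len=len(tokens) + max_tokens,     # room for the prompt plus every generated token
        head_dim=m.n_embd // m.n_head,
        num_layers=m.n_layer,
    )
    # Prefill: forward the full prompt once to populate the cache and get the first logits
    ids = torch.tensor([tokens], dtype=torch.long, device=device)  # add batch dim
    logits = model.forward(ids, kv_cache=kv_cache)[:, -1, :]       # (B, vocab_size)
    for _ in range(max_tokens):
        # Sample from the logits left over from the prefill or the previous decode step
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)  # assumed, as above
            next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_ids = torch.argmax(logits, dim=-1, keepdim=True)
        token = next_ids.item()
        yield token
        # Decode: forward only the new token; the cache supplies the keys/values of the past
        logits = model.forward(next_ids, kv_cache=kv_cache)[:, -1, :]  # from (B, 1, vocab_size)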