diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 1b8ea93..79c6085 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 from nanochat.common import get_dist_info, print0
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
-from nanochat.engine import KVCache
 
 @dataclass
 class GPTConfig:
@@ -294,7 +293,7 @@ class GPT(nn.Module):
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
-        Efficient autoregressive streaming inference with KV caching.
+        Naive autoregressive streaming inference.
         To make it super simple, let's assume:
         - batch size is 1
         - ids and the yielded tokens are simple Python lists and ints
@@ -305,25 +304,10 @@ class GPT(nn.Module):
         if temperature > 0:
             rng = torch.Generator(device=device)
             rng.manual_seed(seed)
-
-        # Initialize KV cache
-        m = self.config
-        kv_cache = KVCache(
-            batch_size=1,
-            num_heads=m.n_kv_head,
-            seq_len=len(tokens) + max_tokens,
-            head_dim=m.n_embd // m.n_head,
-            num_layers=m.n_layer,
-        )
-
-        # Prefill: forward pass on full prompt to populate KV cache and get initial logits
         ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
-        logits = self.forward(ids, kv_cache=kv_cache) # (B, T, vocab_size)
-        logits = logits[:, -1, :] # (B, vocab_size)
-
-        # Generation loop: process one token at a time
         for _ in range(max_tokens):
-            # Sample from existing logits (from prefill or previous iteration)
+            logits = self.forward(ids) # (B, T, vocab_size)
+            logits = logits[:, -1, :] # (B, vocab_size)
             if top_k is not None:
                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
@@ -333,9 +317,6 @@ class GPT(nn.Module):
                 next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
             else:
                 next_ids = torch.argmax(logits, dim=-1, keepdim=True)
+            ids = torch.cat((ids, next_ids), dim=1)
             token = next_ids.item()
             yield token
-
-        # Forward pass on only the new token to get next logits
-        logits = self.forward(next_ids, kv_cache=kv_cache) # (B, 1, vocab_size)
-        logits = logits[:, -1, :] # (B, vocab_size)
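
Note: the loop this patch lands on is the standard cache-free decode: every step re-runs the forward pass over the entire sequence so far, so decode compute grows roughly quadratically with the number of generated tokens, whereas the removed KV-cache path only processed the single newest token per step. Below is a minimal runnable sketch of the same cache-free pattern against a toy model, for illustration only; ToyLM, its sizes, and the simplified sampling (no top_k, no greedy branch, no seeding) are assumptions for the sketch, not nanochat code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyLM(nn.Module):
    """Tiny stand-in language model: embed -> one causal attention layer -> logits."""
    def __init__(self, vocab_size=256, n_embd=32):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.attn = nn.MultiheadAttention(n_embd, num_heads=4, batch_first=True)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, ids):
        x = self.wte(ids)  # (B, T, n_embd)
        T = ids.size(1)
        # Causal mask: True entries are disallowed (future positions)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        x, _ = self.attn(x, x, x, attn_mask=mask)
        return self.lm_head(x)  # (B, T, vocab_size)

@torch.inference_mode()
def generate(model, tokens, max_tokens, temperature=1.0):
    ids = torch.tensor([tokens], dtype=torch.long)  # add batch dim
    for _ in range(max_tokens):
        # No KV cache: re-run the full forward over the whole sequence each
        # step, so step t re-processes all len(tokens) + t tokens.
        logits = model(ids)[:, -1, :]  # (B, vocab_size), last position only
        probs = F.softmax(logits / temperature, dim=-1)
        next_ids = torch.multinomial(probs, num_samples=1)  # (B, 1)
        ids = torch.cat((ids, next_ids), dim=1)  # grow the context in place of a cache
        yield next_ids.item()

model = ToyLM()
print(list(generate(model, [1, 2, 3], max_tokens=8)))

The only state carried between steps is the growing ids tensor itself, which is what makes the patched generate so short; the trade-off is that the prompt is re-encoded on every iteration instead of once at prefill.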