From d4d014010ba4f7367db9c73dd2cbc87c72e5ce98 Mon Sep 17 00:00:00 2001 From: Harsh Gupta Date: Wed, 28 Jan 2026 14:31:30 +0530 Subject: [PATCH] Fix generate() crash when top_k=0 Prevent a crash in generate() by skipping top-k filtering when top_k is set to 0: torch.topk with k=0 returns an empty tensor, so the subsequent v[:, [-1]] indexing raises an IndexError. Treating top_k=0 the same as top_k=None (no filtering) avoids this. --- nanochat/gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index c55e893..6d75a31 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -489,7 +489,7 @@ class GPT(nn.Module): for _ in range(max_tokens): logits = self.forward(ids) # (B, T, vocab_size) logits = logits[:, -1, :] # (B, vocab_size) - if top_k is not None: + if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') if temperature > 0: