diff --git a/nanochat/gpt.py b/nanochat/gpt.py index c55e893..6d75a31 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -489,7 +489,7 @@ class GPT(nn.Module): for _ in range(max_tokens): logits = self.forward(ids) # (B, T, vocab_size) logits = logits[:, -1, :] # (B, vocab_size) - if top_k is not None: + if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') if temperature > 0: