Fix generate() crash when top_k=0 (#467)

Prevent a crash in generate(): skip the top-k filtering step entirely when top_k is set to 0, since torch.topk with k=0 returns an empty tensor and the subsequent threshold indexing fails.
This commit is contained in:
Harsh Gupta 2026-01-30 22:51:02 +05:30 committed by GitHub
parent 02baa15405
commit 2e17723817
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -440,7 +440,7 @@ class GPT(nn.Module):
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0: