diff --git a/nanochat/engine.py b/nanochat/engine.py index cc207e8..ef2867a 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -359,7 +359,7 @@ class Engine: ) else: logits = self.model.forward(ids, kv_cache=kv_cache_decode) - logits = logits[:, -1, :] + logits = logits[:, -1, :] # (B, vocab_size) def generate_batch(self, tokens, num_samples=1, **kwargs): """