diff --git a/nanochat/engine.py b/nanochat/engine.py index d749d94..f94ff19 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -218,6 +218,8 @@ class Engine: ids = torch.tensor([tokens], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) logits = logits[:, -1, :] + # expand the logits to be of size (num_samples, vocab_size) + logits = logits.expand(num_samples, -1) next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) sampled_tokens = next_ids[:, 0].tolist() @@ -248,8 +250,6 @@ class Engine: # Get sampled tokens - either from prefill or from forward pass if first_iteration: # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows - # TODO: we should sample a token for each row instead of broadcasting first_iteration = False else: # Forward the model and get the next token for each row