Compare commits

...

3 Commits

Author SHA1 Message Date
Abdulaziz Gabitov
920fa8663d
Merge 7bd999ba02 into bc1fca39f3 2025-11-15 22:38:03 +03:00
Andrej Karpathy
bc1fca39f3 mqa -> gqa to reduce confusion 2025-11-15 15:43:37 +00:00
Azekowka
7bd999ba02 feat(engine.py): Sample unique initial tokens for each sequence in a batch
Before, when initiating a batch generation, the first sampled token was broadcasted to all sequences. This change now expands the output logits after a single efficient prefill, allowing for the sampling of a unique starting token for each sequence in the batch.
2025-10-29 22:02:28 +05:00
2 changed files with 4 additions and 4 deletions

View File

@ -218,6 +218,8 @@ class Engine:
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
# expand the logits to be of size (num_samples, vocab_size)
logits = logits.expand(num_samples, -1)
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
@ -248,8 +250,6 @@ class Engine:
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row

View File

@ -8,7 +8,7 @@ Notable features:
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference
- Group-Query Attention (GQA) support for more efficient inference
"""
import math
@ -29,7 +29,7 @@ class GPTConfig:
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768