Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 12:22:18 +00:00
feat(engine.py): Sample unique tokens per row in generation stream
Previously, when initiating a batch generation, the first sampled token was broadcast to all rows. This change updates the prefill logic so that a unique token is sampled for each row, improving the diversity and independence of the generated sequences within a batch.
parent 9a08bb4edb
commit 557b2d5840
@@ -185,7 +185,7 @@ class Engine:
             seq_len=len(tokens),
             **kv_model_kwargs,
         )
-        ids = torch.tensor([tokens], dtype=torch.long, device=device)
+        ids = torch.tensor([tokens.copy() for _ in range(num_samples)], dtype=torch.long, device=device)
         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
         logits = logits[:, -1, :]
         next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
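For illustration, here is a minimal, self-contained sketch of the behavior change in this hunk. It is not the engine's actual code: sample_next_token below is a simplified stand-in for nanochat's sampler, the random logits stand in for model.forward, and the token ids and sizes are hypothetical.

import torch

def sample_next_token(logits, generator, temperature=1.0, top_k=None):
    # Simplified stand-in for the engine's sampler: optional top-k filter,
    # then temperature-scaled multinomial sampling. Returns shape (B, 1).
    if top_k is not None:
        kth = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits / max(temperature, 1e-8), dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)

tokens = [1, 5, 9]        # hypothetical prompt token ids
num_samples = 4           # rows generated in parallel
vocab_size = 16           # toy vocabulary
device = "cpu"
rng = torch.Generator(device=device).manual_seed(0)

# Before: a single prefill row, so only one token was sampled and then
# broadcast to every row of the batch.
ids = torch.tensor([tokens], dtype=torch.long, device=device)   # (1, T)

# After: the prompt is replicated into num_samples rows, so the prefill
# logits carry a batch dimension and one token is drawn per row.
ids = torch.tensor([tokens.copy() for _ in range(num_samples)],
                   dtype=torch.long, device=device)             # (B, T)

logits = torch.randn(ids.size(0), ids.size(1), vocab_size,
                     generator=rng)                             # stand-in for model.forward
logits = logits[:, -1, :]                                       # (B, V)
next_ids = sample_next_token(logits, rng, temperature=1.0, top_k=8)  # (B, 1)
print(next_ids.squeeze(1).tolist())  # independent draws, one per row

With the replicated prompt, each row receives its own draw from the same first-token distribution, which is what makes the batch's sequences independent from the very first generated token.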
@@ -218,7 +218,7 @@ class Engine:
             # Get sampled tokens - either from prefill or from forward pass
             if first_iteration:
                 # Use the tokens we already sampled from prefill
-                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
+                # sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
                 # TODO: we should sample a token for each row instead of broadcasting
                 first_iteration = False
             else:
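And a short sketch of what the second hunk accomplishes in the decode loop, again with hypothetical stand-in values rather than the real Engine state:

num_samples = 4
sampled_tokens = [3, 11, 7, 3]   # per-row tokens, as from the prefill sample
first_iteration = True

if first_iteration:
    # Old behavior: sampled_tokens = [sampled_tokens[0]] * num_samples,
    # which collapsed every row onto the same first token. Since prefill
    # now samples one token per row, the list is already per-row and is
    # used as-is.
    assert len(sampled_tokens) == num_samples
    first_iteration = False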