diff --git a/nanochat/engine.py b/nanochat/engine.py index d461fe37..f585b87b 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -188,7 +188,9 @@ class Engine: ids = torch.tensor([tokens], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) logits = logits[:, -1, :] - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) + # Sample num_samples independent tokens from the same distribution + logits_repeated = logits.repeat(num_samples, 1) + next_ids = sample_next_token(logits_repeated, rng, temperature, top_k) # (num_samples, 1) sampled_tokens = next_ids[:, 0].tolist() # 2) Replicate the KV cache for each sample/row @@ -341,37 +343,3 @@ if __name__ == "__main__": print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}") break print(f"Match: {reference_ids == generated_tokens}") - - # Test multi-sample generation for token diversity - print("\n" + "=" * 60) - print("Testing token broadcasting fix...") - print("=" * 60) - - # Generate 10 samples with stochastic sampling - first_tokens = [] - for token_column, token_masks in engine.generate( - prompt_tokens, - num_samples=10, - temperature=1.0, - top_k=50, - max_tokens=1, - seed=42 - ): - # Extract first token from each sample - first_tokens = token_column - break # Only need the first iteration - - # Calculate diversity metrics - unique_tokens = len(set(first_tokens)) - - # Print results - print(f"Generated 10 samples") - print(f"First tokens: {first_tokens}") - print(f"Unique first tokens: {unique_tokens}/10") - - # Display pass/fail verdict - if unique_tokens > 1: - print(f"✅ PASSED: Multiple unique first tokens ({unique_tokens}/10)") - print("Note: With temperature=1.0, expect 5-8 unique tokens out of 10") - else: - print(f"❌ FAILED: All samples have the same first token (broadcasting bug still exists)")