diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..d461fe3 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -341,3 +341,37 @@ if __name__ == "__main__": print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}") break print(f"Match: {reference_ids == generated_tokens}") + + # Test multi-sample generation for token diversity + print("\n" + "=" * 60) + print("Testing token broadcasting fix...") + print("=" * 60) + + # Generate 10 samples with stochastic sampling + first_tokens = [] + for token_column, token_masks in engine.generate( + prompt_tokens, + num_samples=10, + temperature=1.0, + top_k=50, + max_tokens=1, + seed=42 + ): + # Extract first token from each sample + first_tokens = token_column + break # Only need the first iteration + + # Calculate diversity metrics + unique_tokens = len(set(first_tokens)) + + # Print results + print(f"Generated 10 samples") + print(f"First tokens: {first_tokens}") + print(f"Unique first tokens: {unique_tokens}/10") + + # Display pass/fail verdict + if unique_tokens > 1: + print(f"✅ PASSED: Multiple unique first tokens ({unique_tokens}/10)") + print("Note: With temperature=1.0, expect 5-8 unique tokens out of 10") + else: + print(f"❌ FAILED: All samples have the same first token (broadcasting bug still exists)")