diff --git a/tests/test_engine.py b/tests/test_engine.py
index 683f89b..01a30ee 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -185,3 +185,52 @@ def test_multi_sample_first_token_diversity():
         f"With uniform logits, this is statistically impossible (~10^-36 probability) "
         f"unless tokens are being broadcast instead of independently sampled."
     )
+
+
+def test_seed_reproducibility():
+    """Same seed must produce identical output."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]  # special token + "Hello"
+
+    for seed in [1, 42, 123, 999]:
+        r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
+        assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
+
+
+def test_temperature_zero_determinism():
+    """Temperature=0 is deterministic regardless of seed."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for seed in [1, 42, 123, 999]:
+        r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
+        r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
+        r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
+        assert r1 == r2 == r3, "Temperature=0 must produce identical output for the same prompt regardless of seed."
+
+
+def test_max_tokens_respected():
+    """Generation stops at the max_tokens limit."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for max_tokens in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
+        num_generated_tokens = len(results[0]) - len(prompt)
+        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected at most max_tokens={max_tokens}."
+
+
+def test_num_samples_count():
+    """num_samples=N produces exactly N sequences."""
+    model = MockModel()
+    engine = Engine(model, ByteTokenizer())
+    prompt = [261, 72, 101, 108, 108, 111]
+
+    for num_samples in [1, 4, 16, 64]:
+        results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
+        assert len(results) == num_samples, f"Expected {num_samples} sequences with num_samples={num_samples}, got {len(results)}"