add test for seed variation in sampling

Add test for seed variation in sampling with temperature > 0.
This commit is contained in:
Barış Özmen 2025-12-31 15:43:42 +03:00 committed by GitHub
parent 31aeda19d1
commit 57ffd35e0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -233,3 +233,28 @@ def test_num_samples_count():
for num_samples in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
def test_seed_variation_in_sampling():
    """With temperature > 0, different seeds should introduce sampling variation.

    Calls ``engine.generate_batch`` once per seed with ``temperature=1.0``
    and ``max_tokens=5``, collects the first returned sequence of each call,
    and asserts that at least two distinct sequences were observed.
    """
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
    seeds = (1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005)

    # A set de-duplicates identical generations; sequences are converted to
    # tuples so they are hashable.
    outputs = set()
    for seed in seeds:
        results, _ = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        # Only the first sequence of each call is compared across seeds.
        outputs.add(tuple(results[0]))

    # Sanity check: sampling actually introduces variation.
    # Adjacent string literals are concatenated verbatim, so each piece needs
    # its own trailing space to keep the failure message readable.
    assert len(outputs) > 1, (
        f"All seeds produced the same output: {outputs}. "
        "With temperature > 0 and different seeds, this is statistically "
        "very unlikely and implies an issue within the engine."
    )