diff --git a/tests/test_engine.py b/tests/test_engine.py index 01a30ee..75ad7b8 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -206,11 +206,10 @@ def test_temperature_zero_determinism(): engine = Engine(model, ByteTokenizer()) prompt = [261, 72, 101, 108, 108, 111] - for seed in [1, 42, 123, 999]: - r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed." + r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1) + r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42) + r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123) + assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed." def test_max_tokens_respected():