From 31aeda19d10e4686cf6fcdcc079676909f2a9910 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Wed, 31 Dec 2025 11:49:46 +0100 Subject: [PATCH] Fix temperature test --- tests/test_engine.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 01a30ee..75ad7b8 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -206,11 +206,10 @@ def test_temperature_zero_determinism(): engine = Engine(model, ByteTokenizer()) prompt = [261, 72, 101, 108, 108, 111] - for seed in [1, 42, 123, 999]: - r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed) - assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed." + r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1) + r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42) + r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123) + assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed." def test_max_tokens_respected():