Fix temperature test

This commit is contained in:
Sofie Van Landeghem 2025-12-31 11:49:46 +01:00 committed by GitHub
parent bc81d6a460
commit 31aeda19d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -206,11 +206,10 @@ def test_temperature_zero_determinism():
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for seed in [1, 42, 123, 999]:
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
def test_max_tokens_respected():