mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-07 11:09:55 +00:00
Fix temperature test
This commit is contained in:
parent
bc81d6a460
commit
31aeda19d1
|
|
@ -206,11 +206,10 @@ def test_temperature_zero_determinism():
|
|||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111]
|
||||
|
||||
for seed in [1, 42, 123, 999]:
|
||||
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
|
||||
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
|
||||
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
|
||||
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
|
||||
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
|
||||
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
|
||||
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
|
||||
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
|
||||
|
||||
|
||||
def test_max_tokens_respected():
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user