mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-04 06:35:23 +00:00
add test for seed variation in sampling
Add test for seed variation in sampling with temperature > 0.
This commit is contained in:
parent
31aeda19d1
commit
57ffd35e0a
|
|
@ -233,3 +233,28 @@ def test_num_samples_count():
|
|||
for num_samples in [1, 4, 16, 64]:
|
||||
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
|
||||
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
|
||||
|
||||
|
||||
def test_seed_variation_in_sampling():
    """With temperature > 0, different seeds should introduce sampling variation.

    Calls ``engine.generate_batch`` once per seed on the same prompt with
    ``temperature=1.0`` and ``max_tokens=5``, records the first returned
    sequence of each call, and passes when at least two distinct sequences
    were observed across all seeds.
    """
    model = MockModel()
    engine = Engine(model, ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"

    # A set collapses identical sequences, so its size is the number of
    # distinct outputs seen across seeds.
    outputs = set()

    for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
        results, _ = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        # tuple() makes the sequence hashable so it can be stored in the set.
        outputs.add(tuple(results[0]))

    # Sanity check: sampling actually introduces variation
    # NOTE: adjacent string literals are concatenated verbatim, so each piece
    # must carry its own trailing space to keep the message readable.
    assert len(outputs) > 1, (
        f"All seeds produced the same output: {outputs}. "
        "With temperature > 0 and different seeds, this is statistically impossible; "
        "this implies an issue within engine."
    )
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user