mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-16 01:02:18 +00:00
test: add comprehensive edge case test suite for sampling with deterministic and stochastic validation
This commit is contained in:
parent
737165ce44
commit
8c8f08955a
|
|
@ -188,7 +188,9 @@ class Engine:
|
|||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
|
||||
logits = logits[:, -1, :]
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
# Sample num_samples independent tokens from the same distribution
|
||||
logits_repeated = logits.repeat(num_samples, 1)
|
||||
next_ids = sample_next_token(logits_repeated, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# 2) Replicate the KV cache for each sample/row
|
||||
|
|
@ -217,7 +219,9 @@ class Engine:
|
|||
|
||||
# Get sampled tokens - either from prefill or from forward pass
|
||||
if first_iteration:
|
||||
# sampled_tokens already contains num_samples independently sampled tokens
|
||||
# Use the tokens we already sampled from prefill
|
||||
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
||||
# TODO: we should sample a token for each row instead of broadcasting
|
||||
first_iteration = False
|
||||
else:
|
||||
# Forward the model and get the next token for each row
|
||||
|
|
@ -339,3 +343,81 @@ if __name__ == "__main__":
|
|||
print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
|
||||
break
|
||||
print(f"Match: {reference_ids == generated_tokens}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Comprehensive Edge Case Test Suite
|
||||
# -------------------------------------------------------------------------
|
||||
print("\n" + "="*80)
|
||||
print("COMPREHENSIVE EDGE CASE TEST SUITE")
|
||||
print("="*80)
|
||||
|
||||
# Test 1: Single Sample Regression
|
||||
print("\n[TEST 1] Single Sample Regression")
|
||||
print("-" * 40)
|
||||
print("Config: num_samples=1, temperature=1.0, max_tokens=5")
|
||||
test_prompt = tokenizer.encode("Hello", prepend=bos_token_id)
|
||||
results, masks = engine.generate_batch(test_prompt, num_samples=1, temperature=1.0, max_tokens=5, seed=42)
|
||||
generated_sequence = results[0]
|
||||
print(f"Generated tokens: {generated_sequence[:15]}") # Show first 10 tokens after prompt
|
||||
print(f"First 10 generated: {generated_sequence[len(test_prompt):len(test_prompt)+10]}")
|
||||
print("✓ PASS: Single sample generation works")
|
||||
|
||||
# Test 2: Deterministic Sampling (Temperature=0)
|
||||
print("\n[TEST 2] Deterministic Sampling (Temperature=0)")
|
||||
print("-" * 40)
|
||||
print("Config: num_samples=5, temperature=0.0, max_tokens=1")
|
||||
results, masks = engine.generate_batch(test_prompt, num_samples=5, temperature=0.0, max_tokens=1, seed=42)
|
||||
first_tokens = [result[len(test_prompt)] for result in results]
|
||||
print(f"First tokens from all samples: {first_tokens}")
|
||||
unique_tokens = len(set(first_tokens))
|
||||
print(f"Unique tokens: {unique_tokens} / {len(first_tokens)}")
|
||||
if unique_tokens == 1:
|
||||
print("✓ EXPECTED: All samples identical with temperature=0 (argmax behavior)")
|
||||
else:
|
||||
print("✗ UNEXPECTED: Samples differ with temperature=0")
|
||||
|
||||
# Test 3: Stochastic Sampling (Temperature>0)
|
||||
print("\n[TEST 3] Stochastic Sampling (Temperature=1.0)")
|
||||
print("-" * 40)
|
||||
print("Config: num_samples=10, temperature=1.0, max_tokens=1")
|
||||
results, masks = engine.generate_batch(test_prompt, num_samples=10, temperature=1.0, max_tokens=1, seed=42)
|
||||
first_tokens = [result[len(test_prompt)] for result in results]
|
||||
print(f"First tokens from all samples: {first_tokens}")
|
||||
unique_tokens = len(set(first_tokens))
|
||||
print(f"Unique tokens: {unique_tokens} / {len(first_tokens)}")
|
||||
print(f"Diversity: {unique_tokens / len(first_tokens) * 100:.1f}%")
|
||||
if unique_tokens > 1:
|
||||
print("✓ PASS: High temperature produces diverse samples")
|
||||
else:
|
||||
print("✗ FAIL: Expected diversity with temperature=1.0")
|
||||
|
||||
# Test 4a: Top-K Sampling (No Constraint)
|
||||
print("\n[TEST 4a] Top-K Sampling (No Constraint)")
|
||||
print("-" * 40)
|
||||
print("Config: num_samples=10, temperature=1.0, max_tokens=1, top_k=None")
|
||||
results, masks = engine.generate_batch(test_prompt, num_samples=10, temperature=1.0, max_tokens=1, top_k=None, seed=42)
|
||||
first_tokens = [result[len(test_prompt)] for result in results]
|
||||
print(f"First tokens from all samples: {first_tokens}")
|
||||
unique_tokens = len(set(first_tokens))
|
||||
print(f"Unique tokens: {unique_tokens} / {len(first_tokens)}")
|
||||
print(f"Diversity: {unique_tokens / len(first_tokens) * 100:.1f}%")
|
||||
print("✓ PASS: Unconstrained top_k allows full vocabulary sampling")
|
||||
|
||||
# Test 4b: Top-K Sampling (Constrained)
|
||||
print("\n[TEST 4b] Top-K Sampling (Constrained)")
|
||||
print("-" * 40)
|
||||
print("Config: num_samples=20, temperature=1.0, max_tokens=1, top_k=10")
|
||||
results, masks = engine.generate_batch(test_prompt, num_samples=20, temperature=1.0, max_tokens=1, top_k=10, seed=42)
|
||||
first_tokens = [result[len(test_prompt)] for result in results]
|
||||
print(f"First tokens from all samples: {first_tokens}")
|
||||
unique_tokens = len(set(first_tokens))
|
||||
print(f"Unique tokens: {unique_tokens} / {len(first_tokens)}")
|
||||
print(f"Diversity: {unique_tokens / len(first_tokens) * 100:.1f}%")
|
||||
if unique_tokens <= 10:
|
||||
print(f"✓ PASS: Unique tokens ({unique_tokens}) ≤ top_k (10)")
|
||||
else:
|
||||
print(f"✗ FAIL: Unique tokens ({unique_tokens}) > top_k (10)")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("TEST SUITE COMPLETE")
|
||||
print("="*80)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user