# Mirror of https://github.com/karpathy/nanochat.git
# Synced 2025-12-06 04:12:13 +00:00
# 387 lines, 13 KiB, Python
"""
Unit tests for auto-discovery batch size functionality.

Run with: pytest tests/test_auto_batch_size.py -v
"""
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import tempfile
|
|
import os
|
|
import json
|
|
|
|
# Import the module to test
|
|
from nanochat.auto_batch_size import (
|
|
discover_batch_size,
|
|
_perform_discovery,
|
|
_test_batch_size,
|
|
_get_cache_key,
|
|
_load_from_cache,
|
|
_save_to_cache,
|
|
)
|
|
|
|
|
|
class SimpleTestModel(nn.Module):
    """Tiny single-layer network used as a stand-in model by the tests."""

    def __init__(self, hidden_size=1024):
        super().__init__()
        # One square linear projection: hidden_size -> hidden_size.
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, y=None):
        """Run the linear layer on ``x`` (cast to float).

        Returns the layer output when ``y`` is omitted; otherwise returns the
        mean squared error between the output and ``y`` (cast to float).
        """
        projected = self.layer(x.float())
        if y is None:
            return projected
        squared_error = (projected - y.float()).pow(2)
        return squared_error.mean()
|
|
|
|
|
|
# ============================================================================
|
|
# Test 1: Exponential Search Logic
|
|
# ============================================================================
|
|
|
|
def test_exponential_search():
    """Check the doubling probe sequence against a mocked OOM boundary.

    ``_test_batch_size`` is patched so that sizes below 64 "fit" and 64 or
    more "fail". The doubling loop is simulated here in the test body (it does
    not call ``_perform_discovery``), and the sequence of probed sizes plus
    the 32/64 boundary are verified.
    """
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    # Mock _test_batch_size to return True up to 32, False at 64
    with patch('nanochat.auto_batch_size._test_batch_size') as mock_test:
        def side_effect(model, bs, seq_len, dev):
            return bs < 64

        mock_test.side_effect = side_effect

        # Simulate exponential search behavior: double until the first failure.
        tried_sizes = []
        batch_size = 1
        while batch_size <= 128:
            works = mock_test(model, batch_size, max_seq_len, device)
            tried_sizes.append(batch_size)
            if not works:
                break
            batch_size *= 2

        # Verify exponential progression: 1, 2, 4, 8, 16, 32, 64
        assert tried_sizes == [1, 2, 4, 8, 16, 32, 64], \
            f"Expected [1, 2, 4, 8, 16, 32, 64], got {tried_sizes}"

        # Verify we found the boundary (32 works, 64 fails). Identity checks
        # make sure the mock hands back real booleans, not merely truthy values.
        assert mock_test(model, 32, max_seq_len, device) is True
        assert mock_test(model, 64, max_seq_len, device) is False
|
|
|
|
|
|
# ============================================================================
|
|
# Test 2: Binary Search Refinement
|
|
# ============================================================================
|
|
|
|
def test_binary_search_refinement():
    """Check that an upper-biased bisection over [32, 64] lands on the boundary.

    ``_test_batch_size`` is patched so that sizes up to and including 52
    succeed. The bisection itself is simulated in the test body.
    """
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    # Mock OOM boundary at batch_size=52
    with patch('nanochat.auto_batch_size._test_batch_size') as mock_test:
        def fits(model, bs, seq_len, dev):
            return bs <= 52

        mock_test.side_effect = fits

        # Simulate binary search between 32 and 64
        probes = []
        lo, hi = 32, 64
        while lo < hi:
            # Round the midpoint up so the loop always makes progress when lo = hi - 1.
            candidate = (lo + hi + 1) // 2
            probes.append(candidate)
            if mock_test(model, candidate, max_seq_len, device):
                lo = candidate
            else:
                hi = candidate - 1

        found = lo

        # Should have tried: 48, 56, 52
        for midpoint in (48, 56, 52):
            assert midpoint in probes, f"Should try midpoint {midpoint}"

        # Should converge to 52
        assert found == 52, f"Expected 52, got {found}"
|
|
|
|
|
|
# ============================================================================
|
|
# Test 3: Safety Margin Application
|
|
# ============================================================================
|
|
|
|
def test_safety_margin():
    """Check the truncating ``int(max_batch * margin)`` safety-margin arithmetic."""
    max_batch = 60
    margins = [0.85, 0.90, 0.95]
    expected = [51, 54, 57]  # int(60 * margin)

    for idx, margin in enumerate(margins):
        exp = expected[idx]
        result = int(max_batch * margin)
        assert result == exp, f"Margin {margin}: expected {exp}, got {result}"

    # Test with discover_batch_size
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        # Mock returns max batch before margin
        mock_discover.return_value = max_batch

        # The actual function should apply the margin internally.
        # For now, only the calculation is re-checked here.
        scaled = [int(max_batch * margin) for margin in margins]
        assert scaled == expected
|
|
|
|
|
|
# ============================================================================
|
|
# Test 4: Cache Mechanism
|
|
# ============================================================================
|
|
|
|
def test_cache_hit():
    """Round-trip a batch size through ``_save_to_cache`` / ``_load_from_cache``.

    ``get_base_dir`` is patched to return a temporary directory so that the
    test does not touch the real base directory.
    """
    cache_components = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }
    cached_batch_size = 32

    with tempfile.TemporaryDirectory() as tmpdir:
        # Mock get_base_dir to use tmpdir
        base_dir_patch = patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir)
        with base_dir_patch:
            # Store the value, then read it back with the same components.
            _save_to_cache(cache_components, cached_batch_size)
            loaded_size = _load_from_cache(cache_components)

            assert loaded_size == cached_batch_size, \
                f"Expected {cached_batch_size}, got {loaded_size}"
|
|
|
|
|
|
def test_cache_miss():
    """Check that ``_load_from_cache`` returns ``None`` when nothing was saved."""
    cache_components = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        # Point the module at a fresh, empty directory.
        with patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir):
            # Try to load from empty cache
            loaded_size = _load_from_cache(cache_components)

            assert loaded_size is None, "Expected cache miss"
|
|
|
|
|
|
def test_cache_key_includes_components():
    """Check that ``_get_cache_key`` distinguishes differing inputs and is repeatable."""
    baseline = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }
    # Same GPU and sequence length, different model.
    other_model = {
        'model_config': {'n_layer': 20, 'n_embd': 1280},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }
    # Same model and GPU, different sequence length.
    other_seq_len = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 1024,
    }

    key1, key2, key3 = (
        _get_cache_key(components)
        for components in (baseline, other_model, other_seq_len)
    )

    assert key1 != key2, "Different model configs should have different keys"
    assert key1 != key3, "Different max_seq_len should have different keys"
    assert key2 != key3, "All different components should have different keys"

    # Same components should give same key
    key1_again = _get_cache_key(baseline)
    assert key1 == key1_again, "Same components should give same key"
|
|
|
|
|
|
# ============================================================================
|
|
# Test 5: DDP Broadcast Simulation
|
|
# ============================================================================
|
|
|
|
def test_ddp_broadcast():
    """Check the rank-0 flow: discovery runs once and a broadcast is issued.

    Both ``_perform_discovery`` and ``dist.broadcast`` are mocked, so only the
    call flow is verified; the returned value is not asserted.
    """
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    discovered_size = 12

    # Mock distributed operations
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover, \
            patch('nanochat.auto_batch_size.dist.broadcast') as mock_broadcast:
        mock_discover.return_value = discovered_size

        # Test rank 0 (performs discovery)
        discover_batch_size(
            model, max_seq_len, device,
            ddp_rank=0, ddp_world_size=4
        )

        # Rank 0 should perform discovery
        mock_discover.assert_called_once()

        # Should broadcast the result
        assert mock_broadcast.called

        # Result should be the discovered size.
        # Note: actual broadcast simulation is complex,
        # this tests the logic flow
|
|
|
|
|
|
def test_ddp_broadcast_rank_non_zero():
    """Check the non-zero-rank flow: no discovery, but a broadcast is received.

    ``dist.broadcast`` is replaced by a stub that writes 16 into the tensor it
    is given, standing in for the value sent by rank 0.
    """
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    def broadcast_side_effect(tensor, src):
        # Simulated received value
        tensor.fill_(16)

    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover, \
            patch('nanochat.auto_batch_size.dist.broadcast') as mock_broadcast:
        # Simulate broadcast receiving value
        mock_broadcast.side_effect = broadcast_side_effect

        discover_batch_size(
            model, max_seq_len, device,
            ddp_rank=1, ddp_world_size=4
        )

        # Rank 1 should NOT perform discovery
        mock_discover.assert_not_called()

        # Should receive broadcast
        assert mock_broadcast.called
|
|
|
|
|
|
# ============================================================================
|
|
# Additional Tests
|
|
# ============================================================================
|
|
|
|
def test_min_max_batch_size_constraints():
    """Check that min/max limits are forwarded positionally to ``_perform_discovery``."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        # Test with very low max
        mock_discover.return_value = 4
        discover_batch_size(
            model, max_seq_len, device,
            min_batch_size=1, max_batch_size=4,
            ddp_rank=0, ddp_world_size=1
        )

        # Should be called with the constraints: this test expects them at
        # positional indices 4 and 5 of the mocked call.
        positional, _keywords = mock_discover.call_args
        assert positional[4] == 1  # min_batch_size
        assert positional[5] == 4  # max_batch_size
|
|
|
|
|
|
def test_discover_with_no_cache():
    """Check that ``use_cache=False`` runs discovery and returns its value."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256

    discovery_patch = patch('nanochat.auto_batch_size._perform_discovery')
    with discovery_patch as mock_discover:
        mock_discover.return_value = 16

        result = discover_batch_size(
            model, max_seq_len, device,
            use_cache=False,
            ddp_rank=0, ddp_world_size=1
        )

        # Should perform discovery exactly once and pass its value through.
        mock_discover.assert_called_once()
        assert result == 16
|
|
|
|
|
|
def test_cache_corruption_handling():
    """Check that an unparsable cache file yields ``None`` rather than an error.

    The test writes invalid JSON to ``<tmpdir>/auto_batch_cache/<key>.json``,
    the location this test expects the module to read from.
    """
    cache_components = {
        'model_config': {'n_layer': 12},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        with patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir):
            # Create corrupted cache file
            cache_dir = os.path.join(tmpdir, "auto_batch_cache")
            os.makedirs(cache_dir, exist_ok=True)

            cache_key = _get_cache_key(cache_components)
            cache_file = os.path.join(cache_dir, f"{cache_key}.json")

            # Write corrupted JSON
            with open(cache_file, 'w') as handle:
                handle.write("invalid json {{{")

            # Should return None instead of crashing
            loaded_size = _load_from_cache(cache_components)
            assert loaded_size is None, "Corrupted cache should return None"
|
|
|
|
|
|
# ============================================================================
|
|
# Integration-style unit test
|
|
# ============================================================================
|
|
|
|
def test_full_discovery_flow():
    """Run ``discover_batch_size`` unmocked on CPU and sanity-check the result range."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 128  # Small for CPU testing

    # Run actual discovery (on CPU, so it won't OOM)
    discovery_kwargs = dict(
        safety_margin=0.85,
        min_batch_size=1,
        max_batch_size=16,  # Keep small for CPU
        ddp_rank=0,
        ddp_world_size=1,
        use_cache=False,
    )
    result = discover_batch_size(model, max_seq_len, device, **discovery_kwargs)

    # Result should be within bounds
    assert 1 <= result <= 16, f"Result {result} out of bounds [1, 16]"

    # Result should be reasonable
    assert result >= 1, "Should find at least batch_size=1"
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest with verbose
    # output and short tracebacks.
    pytest_args = [__file__, "-v", "--tb=short"]
    pytest.main(pytest_args)
|