"""
Unit tests for auto-discovery batch size functionality.
Run with: pytest tests/test_auto_batch_size.py -v
"""
import os
import tempfile
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

# Import the module under test
from nanochat.auto_batch_size import (
    discover_batch_size,
    _perform_discovery,
    _test_batch_size,
    _get_cache_key,
    _load_from_cache,
    _save_to_cache,
)


class SimpleTestModel(nn.Module):
    """Simple model for testing."""

    def __init__(self, hidden_size=1024):
        super().__init__()
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, y=None):
        # Simplified forward pass: return an MSE loss when a target is given
        out = self.layer(x.float())
        if y is not None:
            loss = (out - y.float()).pow(2).mean()
            return loss
        return out


# ============================================================================
# Test 1: Exponential Search Logic
# ============================================================================
def test_exponential_search():
    """Test that exponential search finds the upper bound correctly."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    # Mock _test_batch_size to succeed below 64 and fail at 64 and above
    with patch('nanochat.auto_batch_size._test_batch_size') as mock_test:
        def side_effect(model, bs, seq_len, dev):
            return bs < 64
        mock_test.side_effect = side_effect
        # Simulate the exponential phase: double the batch size until failure
        tried_sizes = []
        batch_size = 1
        while batch_size <= 128:
            works = mock_test(model, batch_size, max_seq_len, device)
            tried_sizes.append(batch_size)
            if not works:
                break
            batch_size *= 2
        # Verify the exponential progression: 1, 2, 4, 8, 16, 32, 64
        assert tried_sizes == [1, 2, 4, 8, 16, 32, 64], \
            f"Expected [1, 2, 4, 8, 16, 32, 64], got {tried_sizes}"
        # Verify we found the boundary (32 works, 64 fails)
        assert mock_test(model, 32, max_seq_len, device) is True
        assert mock_test(model, 64, max_seq_len, device) is False
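

# Companion check (pure simulation, no nanochat internals involved): doubling
# from 1 reaches any boundary B after exactly floor(log2(B)) + 1 successful
# probes, which is why the exponential phase stays cheap for large budgets.
def test_exponential_search_probe_count():
    """Doubling from 1 makes floor(log2(B)) + 1 successful probes."""
    import math
    for boundary in [3, 17, 52, 100, 1000]:
        successful_probes = 0
        batch_size = 1
        while batch_size <= boundary:
            successful_probes += 1
            batch_size *= 2
        assert successful_probes == math.floor(math.log2(boundary)) + 1

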
# ============================================================================
# Test 2: Binary Search Refinement
# ============================================================================
def test_binary_search_refinement():
    """Test that binary search narrows down to the exact boundary."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    # Mock an OOM boundary at batch_size=52 (52 works, 53 fails)
    with patch('nanochat.auto_batch_size._test_batch_size') as mock_test:
        def side_effect(model, bs, seq_len, dev):
            return bs <= 52
        mock_test.side_effect = side_effect
        # Simulate binary search between 32 (known good) and 64 (known bad)
        tried_sizes = []
        low, high = 32, 64
        while low < high:
            mid = (low + high + 1) // 2
            tried_sizes.append(mid)
            if mock_test(model, mid, max_seq_len, device):
                low = mid
            else:
                high = mid - 1
        result = low
        # The midpoints 48, 56, and 52 must all be probed along the way
        assert 48 in tried_sizes, "Should try midpoint 48"
        assert 56 in tried_sizes, "Should try midpoint 56"
        assert 52 in tried_sizes, "Should try midpoint 52"
        # Should converge to exactly 52
        assert result == 52, f"Expected 52, got {result}"
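
# Worked trace of the search above (boundary at 52), for reference:
#   low=32, high=64 -> mid=48 works -> low=48
#   low=48, high=64 -> mid=56 fails -> high=55
#   low=48, high=55 -> mid=52 works -> low=52
#   low=52, high=55 -> mid=54 fails -> high=53
#   low=52, high=53 -> mid=53 fails -> high=52 -> low == high == 52
# Invariant: `low` always holds a size known to fit, so the loop terminates
# at the largest working batch size.

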
# ============================================================================
# Test 3: Safety Margin Application
# ============================================================================
def test_safety_margin():
    """Test that the safety margin is applied correctly."""
    margins = [0.85, 0.90, 0.95]
    max_batch = 60
    expected = [51, 54, 57]  # int(60 * margin)
    for margin, exp in zip(margins, expected):
        result = int(max_batch * margin)
        assert result == exp, f"Margin {margin}: expected {exp}, got {result}"
    # Check the same calculation against a mocked discovery result
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        # The mock returns the maximum batch size *before* the margin
        mock_discover.return_value = max_batch
        for margin, exp in zip(margins, expected):
            # discover_batch_size is expected to apply the margin internally;
            # for now we only verify the arithmetic
            applied = int(max_batch * margin)
            assert applied == exp
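
# Note: int() truncates toward zero, so the margin always rounds down, e.g.
# int(60 * 0.99) == 59. The chosen batch size therefore never exceeds
# margin * max_batch, which is the conservative direction when near OOM.

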
# ============================================================================
# Test 4: Cache Mechanism
# ============================================================================
def test_cache_hit():
    """Test that a cache hit skips discovery."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Components that uniquely identify a cache entry
        cache_components = {
            'model_config': {'n_layer': 12, 'n_embd': 768},
            'gpu': 'A100',
            'max_seq_len': 2048,
        }
        cached_batch_size = 32
        # Redirect the cache to the temporary directory
        with patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir):
            # Save to cache, then read it back
            _save_to_cache(cache_components, cached_batch_size)
            loaded_size = _load_from_cache(cache_components)
            assert loaded_size == cached_batch_size, \
                f"Expected {cached_batch_size}, got {loaded_size}"


def test_cache_miss():
    """Test that a cache miss returns None so discovery can run."""
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_components = {
            'model_config': {'n_layer': 12, 'n_embd': 768},
            'gpu': 'A100',
            'max_seq_len': 2048,
        }
        with patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir):
            # Loading from an empty cache should miss
            loaded_size = _load_from_cache(cache_components)
            assert loaded_size is None, "Expected cache miss"


def test_cache_key_includes_components():
    """Test that the cache key depends on every component."""
    components1 = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 2048,
    }
    components2 = {
        'model_config': {'n_layer': 20, 'n_embd': 1280},  # different model
        'gpu': 'A100',
        'max_seq_len': 2048,
    }
    components3 = {
        'model_config': {'n_layer': 12, 'n_embd': 768},
        'gpu': 'A100',
        'max_seq_len': 1024,  # different seq_len
    }
    key1 = _get_cache_key(components1)
    key2 = _get_cache_key(components2)
    key3 = _get_cache_key(components3)
    assert key1 != key2, "Different model configs should have different keys"
    assert key1 != key3, "Different max_seq_len should have different keys"
    assert key2 != key3, "All-different components should have different keys"
    # The key must be deterministic: same components, same key
    key1_again = _get_cache_key(components1)
    assert key1 == key1_again, "Same components should give the same key"
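
# As exercised by test_cache_corruption_handling below, cache entries live at
# <base_dir>/auto_batch_cache/<cache_key>.json. The payload schema is an
# implementation detail of nanochat.auto_batch_size; something like the
# following is assumed:
#   {"batch_size": 32}
# The tests above only rely on save/load round-trips and key uniqueness.

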
# ============================================================================
# Test 5: DDP Broadcast Simulation
# ============================================================================
def test_ddp_broadcast():
    """Test that rank 0 performs discovery and broadcasts the result."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    discovered_size = 12
    # Mock discovery and the distributed broadcast
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        mock_discover.return_value = discovered_size
        # Rank 0 performs discovery
        with patch('nanochat.auto_batch_size.dist.broadcast') as mock_broadcast:
            result = discover_batch_size(
                model, max_seq_len, device,
                ddp_rank=0, ddp_world_size=4
            )
            # Rank 0 should perform discovery exactly once...
            mock_discover.assert_called_once()
            # ...and broadcast the result to the other ranks
            assert mock_broadcast.called
            # The result should be the discovered size; faithfully simulating
            # the broadcast is complex, so this test only checks control flow.


def test_ddp_broadcast_rank_non_zero():
    """Test that non-zero ranks receive the broadcast value instead."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        with patch('nanochat.auto_batch_size.dist.broadcast') as mock_broadcast:
            # Simulate receiving a value from rank 0 via the broadcast
            def broadcast_side_effect(tensor, src):
                tensor.fill_(16)  # simulated received value
            mock_broadcast.side_effect = broadcast_side_effect
            result = discover_batch_size(
                model, max_seq_len, device,
                ddp_rank=1, ddp_world_size=4
            )
            # Rank 1 must NOT perform discovery itself...
            mock_discover.assert_not_called()
            # ...but must participate in the broadcast
            assert mock_broadcast.called
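
# For reference, the broadcast being mocked above presumably follows the
# standard torch.distributed pattern (a sketch, not nanochat's actual code):
#
#   tensor = torch.tensor([batch_size if ddp_rank == 0 else 0], device=device)
#   dist.broadcast(tensor, src=0)  # every rank now holds rank 0's value
#   batch_size = int(tensor.item())
#
# which is why fill_(16) in the side effect stands in for "rank 0 sent 16".

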
# ============================================================================
# Additional Tests
# ============================================================================
def test_min_max_batch_size_constraints():
    """Test that min/max constraints are forwarded to discovery."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        # Use a deliberately low maximum
        mock_discover.return_value = 4
        result = discover_batch_size(
            model, max_seq_len, device,
            min_batch_size=1, max_batch_size=4,
            ddp_rank=0, ddp_world_size=1
        )
        # The constraints should be passed through positionally
        call_args = mock_discover.call_args
        assert call_args[0][4] == 1  # min_batch_size
        assert call_args[0][5] == 4  # max_batch_size


def test_discover_with_no_cache():
    """Test discovery without caching."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 256
    with patch('nanochat.auto_batch_size._perform_discovery') as mock_discover:
        mock_discover.return_value = 16
        result = discover_batch_size(
            model, max_seq_len, device,
            use_cache=False,
            ddp_rank=0, ddp_world_size=1
        )
        # Discovery should run and its result should be returned
        mock_discover.assert_called_once()
        assert result == 16


def test_cache_corruption_handling():
    """Test that a corrupted cache file is handled gracefully."""
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_components = {
            'model_config': {'n_layer': 12},
            'gpu': 'A100',
            'max_seq_len': 2048,
        }
        with patch('nanochat.auto_batch_size.get_base_dir', return_value=tmpdir):
            # Write a corrupted cache file where the loader expects JSON
            cache_dir = os.path.join(tmpdir, "auto_batch_cache")
            os.makedirs(cache_dir, exist_ok=True)
            cache_key = _get_cache_key(cache_components)
            cache_file = os.path.join(cache_dir, f"{cache_key}.json")
            with open(cache_file, 'w') as f:
                f.write("invalid json {{{")
            # The loader should return None instead of raising
            loaded_size = _load_from_cache(cache_components)
            assert loaded_size is None, "Corrupted cache should return None"
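
# A corruption-tolerant loader presumably looks like this sketch (assumed,
# not nanochat's actual code):
#
#   try:
#       with open(cache_file) as f:
#           return json.load(f)["batch_size"]
#   except (OSError, json.JSONDecodeError, KeyError):
#       return None

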
# ============================================================================
# Integration-style unit test
# ============================================================================
def test_full_discovery_flow():
    """Run the full discovery flow end to end on CPU."""
    model = SimpleTestModel()
    device = torch.device('cpu')
    max_seq_len = 128  # small for CPU testing
    # Run actual discovery; on CPU this exercises the code path without OOM
    result = discover_batch_size(
        model, max_seq_len, device,
        safety_margin=0.85,
        min_batch_size=1,
        max_batch_size=16,  # keep small for CPU
        ddp_rank=0,
        ddp_world_size=1,
        use_cache=False,
    )
    # The result must respect the configured bounds
    assert 1 <= result <= 16, f"Result {result} out of bounds [1, 16]"
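
# Typical call site in a training script (a sketch; the keyword arguments
# mirror the signature exercised above):
#
#   device_batch_size = discover_batch_size(
#       model, max_seq_len, device,
#       safety_margin=0.85,
#       ddp_rank=ddp_rank, ddp_world_size=ddp_world_size,
#   )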


if __name__ == "__main__":
    # Run tests
    pytest.main([__file__, "-v", "--tb=short"])