"""
Tests for the GPT model architecture.
Run with:
python -m pytest tests/test_gpt.py -v -s --timeout=60
"""

import pytest
import torch

from nanochat.gpt import GPT, GPTConfig, norm, apply_rotary_emb, repeat_kv


@pytest.fixture
def small_config():
    """A small GPT config for fast testing."""
    return GPTConfig(
        sequence_len=128,
        vocab_size=256,
        n_layer=2,
        n_head=4,
        n_kv_head=2,  # test MQA with a 2:1 ratio
        n_embd=64,
    )


@pytest.fixture
def tiny_config():
    """An even tinier config for quick tests."""
    return GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=2,
        n_kv_head=1,  # test MQA with a 2:1 ratio
        n_embd=32,
    )


def prepare_model_for_testing(model):
    """Prepare model for CPU testing by converting to float32 but keeping rotary embeddings in bfloat16."""
    model = model.float()
    # .float() also upcasts the rotary tables; restore them to bfloat16 to
    # match how the model builds them (the rotary application casts back to
    # the activation dtype, so this is safe on CPU)
    model.cos = model.cos.bfloat16()
    model.sin = model.sin.bfloat16()
    return model


def test_gpt_config():
    """Test that GPTConfig initializes with correct defaults."""
    config = GPTConfig()
    assert config.sequence_len == 1024
    assert config.vocab_size == 50304
    assert config.n_layer == 12
    assert config.n_head == 6
    assert config.n_kv_head == 6
    assert config.n_embd == 768


def test_norm_function():
    """Test the RMSNorm function."""
    x = torch.randn(2, 4, 8)
    y = norm(x)
    # Check shape is preserved
    assert y.shape == x.shape
    # Check it's actually normalized (approximately)
    # RMSNorm: y = x / rms(x) where rms = sqrt(mean(x^2))
    expected_rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True))
    expected_y = x / expected_rms
    torch.testing.assert_close(y, expected_y, rtol=1e-4, atol=1e-4)
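    # Sanity check on top of the closed-form comparison: the normalized output
    # should itself have (approximately) unit RMS along the last dimension
    y_rms = torch.sqrt(torch.mean(y ** 2, dim=-1))
    torch.testing.assert_close(y_rms, torch.ones_like(y_rms), rtol=1e-4, atol=1e-4)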


def test_apply_rotary_emb():
    """Test rotary embeddings application."""
    batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
    # apply_rotary_emb expects the (batch, seq_len, num_heads, head_dim)
    # layout used before the attention transpose
    x = torch.randn(batch_size, seq_len, num_heads, head_dim)
    # Create simple cos/sin for testing; the tables cover half the head dim
    cos = torch.ones(1, seq_len, 1, head_dim // 2)
    sin = torch.zeros(1, seq_len, 1, head_dim // 2)
    y = apply_rotary_emb(x, cos, sin)
    # Check shape is preserved
    assert y.shape == x.shape
    # Check dtype is preserved
    assert y.dtype == x.dtype
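    # With cos=1 and sin=0 the rotation is the identity, so the values should
    # pass through (approximately) unchanged
    torch.testing.assert_close(y, x, rtol=1e-4, atol=1e-4)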


def test_repeat_kv():
    """Test the repeat_kv function for MQA."""
    bs, n_kv_heads, slen, head_dim = 2, 2, 8, 16
    n_rep = 3  # repeat each KV head 3 times
    x = torch.randn(bs, n_kv_heads, slen, head_dim)
    y = repeat_kv(x, n_rep)
    # Check output shape
    assert y.shape == (bs, n_kv_heads * n_rep, slen, head_dim)
    # Check that heads are repeated correctly
    for i in range(n_kv_heads):
        for j in range(n_rep):
            torch.testing.assert_close(y[:, i * n_rep + j], x[:, i])
    # Test n_rep=1 (no-op case)
    y_no_rep = repeat_kv(x, 1)
    torch.testing.assert_close(y_no_rep, x)
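    # Equivalent formulation: repeating each head n_rep times contiguously is
    # exactly repeat_interleave along the head dimension
    torch.testing.assert_close(y, x.repeat_interleave(n_rep, dim=1))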


def test_gpt_initialization(small_config):
    """Test that GPT model initializes correctly."""
    model = GPT(small_config)
    # Check model has the right components
    assert hasattr(model, 'transformer')
    assert hasattr(model, 'lm_head')
    assert len(model.transformer.h) == small_config.n_layer
    # Check parameter count is reasonable
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params > 0
    print(f"Small model has {num_params:,} parameters")


def test_gpt_forward_shape(tiny_config):
    """Test that forward pass produces correct output shapes."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    batch_size = 2
    seq_len = 16
    # Create input tokens
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    # Forward pass without targets (inference mode)
    with torch.no_grad():
        logits = model(idx)
    # Check output shape
    assert logits.shape == (batch_size, seq_len, tiny_config.vocab_size)
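    # Logits should also be finite everywhere at initialization
    assert torch.isfinite(logits).all()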


def test_gpt_forward_with_targets(tiny_config):
    """Test forward pass with targets (training mode)."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()
    batch_size = 2
    seq_len = 16
    # Create input tokens and targets
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    # Forward pass with targets
    loss = model(idx, targets=targets)
    # Check loss is a scalar
    assert loss.shape == ()
    assert loss.item() > 0  # cross-entropy loss should be positive
    assert not torch.isnan(loss)
    assert not torch.isinf(loss)


def test_gpt_backward(tiny_config):
    """Test that backward pass works and gradients flow."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()
    batch_size = 2
    seq_len = 8
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    # Forward and backward
    loss = model(idx, targets=targets)
    loss.backward()
    # Check that gradients are computed for all parameters
    for name, param in model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"No gradient for {name}"
            assert not torch.isnan(param.grad).any(), f"NaN gradient in {name}"
            assert not torch.isinf(param.grad).any(), f"Inf gradient in {name}"
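    # Gradients should be clearable the standard way between steps
    model.zero_grad(set_to_none=True)
    assert all(p.grad is None for p in model.parameters())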


def test_gpt_generate(tiny_config):
    """Test autoregressive generation."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    # Start with a few tokens
    initial_tokens = [1, 2, 3]
    max_new_tokens = 10
    # Generate tokens
    generated = []
    for token in model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42):
        generated.append(token)
    # Check we generated the right number of tokens
    assert len(generated) == max_new_tokens
    # Check all tokens are valid
    for token in generated:
        assert 0 <= token < tiny_config.vocab_size


def test_gpt_generate_deterministic(tiny_config):
    """Test that generation is deterministic with same seed."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    initial_tokens = [1, 2, 3]
    max_new_tokens = 5
    # Generate twice with same seed
    gen1 = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42))
    gen2 = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42))
    assert gen1 == gen2, "Generation should be deterministic with same seed"
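    # A different seed is not guaranteed to diverge on an untrained model, so
    # only smoke-check that it still yields a full-length, valid sample
    gen3 = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=123))
    assert len(gen3) == max_new_tokens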


def test_gpt_generate_greedy(tiny_config):
    """Test greedy decoding (temperature=0)."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    initial_tokens = [1, 2, 3]
    max_new_tokens = 5
    # Greedy decoding
    generated = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=0.0))
    assert len(generated) == max_new_tokens
    for token in generated:
        assert 0 <= token < tiny_config.vocab_size
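    # Greedy decoding should be deterministic by construction (argmax, no
    # sampling), so a repeat run should reproduce the tokens without a seed
    generated_again = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=0.0))
    assert generated_again == generated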


def test_gpt_estimate_flops(small_config):
    """Test FLOP estimation."""
    model = GPT(small_config)
    flops = model.estimate_flops()
    # FLOPs should be positive
    assert flops > 0
    print(f"Estimated FLOPs per token: {flops:,}")


def test_gpt_setup_optimizers(tiny_config):
    """Test optimizer setup."""
    model = GPT(tiny_config)
    model.init_weights()
    # Setup optimizers
    optimizers = model.setup_optimizers()
    # Should return a list of 2 optimizers
    assert len(optimizers) == 2
    # Check they have parameter groups
    for opt in optimizers:
        assert len(opt.param_groups) > 0
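    # The two optimizers (Muon and AdamW) are expected to split the model's
    # parameters between them; here we only assert containment, i.e. that
    # every parameter handed to an optimizer actually belongs to the model
    model_params = {id(p) for p in model.parameters()}
    opt_params = {id(p) for opt in optimizers for group in opt.param_groups for p in group['params']}
    assert opt_params <= model_params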


def test_gpt_mqa_shapes():
    """Test that multi-query / grouped-query attention produces correct shapes."""
    # Use a config with more query heads than KV heads
    config = GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=8,
        n_kv_head=2,  # 4:1 ratio
        n_embd=64,
    )
    model = GPT(config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    batch_size = 2
    seq_len = 16
    idx = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    with torch.no_grad():
        logits = model(idx)
    assert logits.shape == (batch_size, seq_len, config.vocab_size)


def test_gpt_long_sequence(small_config):
    """Test with sequences up to max length."""
    model = GPT(small_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()
    batch_size = 1
    seq_len = small_config.sequence_len  # full sequence length
    idx = torch.randint(0, small_config.vocab_size, (batch_size, seq_len))
    with torch.no_grad():
        logits = model(idx)
    assert logits.shape == (batch_size, seq_len, small_config.vocab_size)


def test_gpt_embedding_dtype():
    """Test that embeddings are cast to bfloat16."""
    config = GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=2,
        n_kv_head=2,
        n_embd=32,
    )
    model = GPT(config)
    # Check that embeddings are in bfloat16
    assert model.transformer.wte.weight.dtype == torch.bfloat16


def test_gpt_rotary_embeddings_dtype(tiny_config):
    """Test that rotary embeddings are in bfloat16."""
    model = GPT(tiny_config)
    model.init_weights()
    assert model.cos.dtype == torch.bfloat16
    assert model.sin.dtype == torch.bfloat16


def test_gpt_loss_reduction_modes(tiny_config):
    """Test different loss reduction modes."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()
    batch_size = 2
    seq_len = 8
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    # Test 'mean' reduction (default)
    loss_mean = model(idx, targets=targets, loss_reduction='mean')
    assert loss_mean.shape == ()
    # Test 'none' reduction: per-token losses come back flattened to (B*T,)
    loss_none = model(idx, targets=targets, loss_reduction='none')
    assert loss_none.shape == (batch_size * seq_len,)
    # The mean of loss_none should be close to loss_mean
    torch.testing.assert_close(loss_none.mean(), loss_mean, rtol=1e-4, atol=1e-4)


def test_gpt_with_ignore_index(tiny_config):
    """Test that -1 targets are ignored in loss."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()
    batch_size = 2
    seq_len = 8
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    # Compute loss with all valid targets
    loss_all = model(idx, targets=targets)
    # Mask out ALL targets except the first token
    targets_masked = targets.clone()
    targets_masked[:, 1:] = -1
    loss_masked = model(idx, targets=targets_masked)
    # Both should be valid, finite losses
    assert loss_all.item() > 0
    assert loss_masked.item() > 0
    assert not torch.isnan(loss_masked)
    assert not torch.isinf(loss_masked)
    # Masking ALL targets leaves zero valid tokens. Raising is an acceptable
    # way to handle that; if a loss is returned instead, PyTorch's convention
    # for a mean over zero elements is NaN (0/0), and an explicit zero would
    # also be defensible. The assert lives in the else-clause so the except
    # cannot swallow an assertion failure.
    targets_all_masked = torch.full_like(targets, -1)
    try:
        loss_all_masked = model(idx, targets=targets_all_masked)
    except Exception:
        pass
    else:
        assert torch.isnan(loss_all_masked) or loss_all_masked.item() == 0
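    # Final cross-check, assuming 'none' reduction returns per-position losses
    # in order: the per-token losses at the kept positions (the first column
    # only) should average to the mean-reduced masked loss
    per_token = model(idx, targets=targets_masked, loss_reduction='none')
    kept = per_token.view(batch_size, seq_len)[:, 0]
    torch.testing.assert_close(kept.mean(), loss_masked, rtol=1e-4, atol=1e-4)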