"""
|
|
Tests for the GPT model architecture.
|
|
|
|
Run with:
|
|
python -m pytest tests/test_gpt.py -v -s --timeout=60
|
|
"""
|
|
|
|
import torch
import pytest

from nanochat.gpt import GPT, GPTConfig, norm, apply_rotary_emb, repeat_kv


@pytest.fixture
def small_config():
    """A small GPT config for fast testing."""
    return GPTConfig(
        sequence_len=128,
        vocab_size=256,
        n_layer=2,
        n_head=4,
        n_kv_head=2,  # Test MQA with 2:1 ratio
        n_embd=64,
    )


@pytest.fixture
def tiny_config():
    """An even tinier config for quick tests."""
    return GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=2,
        n_kv_head=1,  # Test MQA with 2:1 ratio
        n_embd=32,
    )


def prepare_model_for_testing(model):
    """Prepare model for CPU testing by converting to float32 but keeping rotary embeddings in bfloat16."""
    model = model.float()
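    # Note (assumption, based on the dtype checks later in this file): model.float()
    # above casts every parameter and buffer to float32, so the rotary cos/sin buffers
    # are re-cast to bfloat16 here to match the dtype the model registers them with
    # (see test_gpt_rotary_embeddings_dtype).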
    model.cos = model.cos.bfloat16()
    model.sin = model.sin.bfloat16()
    return model


def test_gpt_config():
    """Test that GPTConfig initializes with correct defaults."""
    config = GPTConfig()
    assert config.sequence_len == 1024
    assert config.vocab_size == 50304
    assert config.n_layer == 12
    assert config.n_head == 6
    assert config.n_kv_head == 6
    assert config.n_embd == 768


def test_norm_function():
    """Test the RMSNorm function."""
    x = torch.randn(2, 4, 8)
    y = norm(x)
    # Check shape is preserved
    assert y.shape == x.shape
    # Check it's actually normalized (approximately)
    # RMSNorm: y = x / rms(x) where rms = sqrt(mean(x^2))
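    # Worked example of the formula above (ignoring any epsilon term the
    # implementation may add): for a row x = [3, 4], rms = sqrt((9 + 16) / 2)
    # = sqrt(12.5) ≈ 3.536, so norm(x) ≈ [0.849, 1.131].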
    expected_rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True))
    expected_y = x / expected_rms
    torch.testing.assert_close(y, expected_y, rtol=1e-4, atol=1e-4)


def test_apply_rotary_emb():
    """Test rotary embeddings application."""
    batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
    x = torch.randn(batch_size, seq_len, num_heads, head_dim)

    # Create simple cos/sin for testing
    cos = torch.ones(1, seq_len, 1, head_dim // 2)
    sin = torch.zeros(1, seq_len, 1, head_dim // 2)

    y = apply_rotary_emb(x, cos, sin)

    # Check shape is preserved
    assert y.shape == x.shape
    # Check dtype is preserved
    assert y.dtype == x.dtype
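    # cos=1 and sin=0 corresponds to a zero rotation angle, so a standard rotary
    # implementation would return x unchanged up to dtype casting; only shape and
    # dtype are asserted here, and an exact value check against x is left as an
    # assumption rather than added.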


def test_repeat_kv():
    """Test the repeat_kv function for MQA."""
    bs, n_kv_heads, slen, head_dim = 2, 2, 8, 16
    n_rep = 3  # Repeat each KV head 3 times

    x = torch.randn(bs, n_kv_heads, slen, head_dim)
    y = repeat_kv(x, n_rep)

    # Check output shape
    assert y.shape == (bs, n_kv_heads * n_rep, slen, head_dim)

    # Check that heads are repeated correctly
    for i in range(n_kv_heads):
        for j in range(n_rep):
            torch.testing.assert_close(y[:, i * n_rep + j], x[:, i])
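    # In other words, with n_kv_heads=2 and n_rep=3 the expected output head order is
    # [kv0, kv0, kv0, kv1, kv1, kv1]: each KV head expands into a contiguous block,
    # which is exactly what the nested loop above checks.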

    # Test n_rep=1 (no-op case)
    y_no_rep = repeat_kv(x, 1)
    torch.testing.assert_close(y_no_rep, x)


def test_gpt_initialization(small_config):
    """Test that GPT model initializes correctly."""
    model = GPT(small_config)

    # Check model has the right components
    assert hasattr(model, 'transformer')
    assert hasattr(model, 'lm_head')
    assert len(model.transformer.h) == small_config.n_layer

    # Check parameter count is reasonable
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params > 0
    print(f"Small model has {num_params:,} parameters")


def test_gpt_forward_shape(tiny_config):
    """Test that forward pass produces correct output shapes."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    batch_size = 2
    seq_len = 16

    # Create input tokens
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))

    # Forward pass without targets (inference mode)
    with torch.no_grad():
        logits = model(idx)

    # Check output shape
    assert logits.shape == (batch_size, seq_len, tiny_config.vocab_size)


def test_gpt_forward_with_targets(tiny_config):
    """Test forward pass with targets (training mode)."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()

    batch_size = 2
    seq_len = 16

    # Create input tokens and targets
    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))

    # Forward pass with targets
    loss = model(idx, targets=targets)

    # Check loss is a scalar
    assert loss.shape == ()
    assert loss.item() > 0  # Cross-entropy loss should be positive
    assert not torch.isnan(loss)
    assert not torch.isinf(loss)


def test_gpt_backward(tiny_config):
    """Test that backward pass works and gradients flow."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()

    batch_size = 2
    seq_len = 8

    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))

    # Forward and backward
    loss = model(idx, targets=targets)
    loss.backward()

    # Check that gradients are computed for all parameters
    for name, param in model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"No gradient for {name}"
            assert not torch.isnan(param.grad).any(), f"NaN gradient in {name}"
            assert not torch.isinf(param.grad).any(), f"Inf gradient in {name}"


def test_gpt_generate(tiny_config):
    """Test autoregressive generation."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    # Start with a few tokens
    initial_tokens = [1, 2, 3]
    max_new_tokens = 10

    # Generate tokens
    generated = []
    for token in model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42):
        generated.append(token)

    # Check we generated the right number of tokens
    assert len(generated) == max_new_tokens

    # Check all tokens are valid
    for token in generated:
        assert 0 <= token < tiny_config.vocab_size


def test_gpt_generate_deterministic(tiny_config):
    """Test that generation is deterministic with the same seed."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    initial_tokens = [1, 2, 3]
    max_new_tokens = 5

    # Generate twice with the same seed
    gen1 = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42))
    gen2 = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=1.0, seed=42))

    assert gen1 == gen2, "Generation should be deterministic with the same seed"


def test_gpt_generate_greedy(tiny_config):
    """Test greedy decoding (temperature=0)."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    initial_tokens = [1, 2, 3]
    max_new_tokens = 5

    # Greedy decoding
    generated = list(model.generate(initial_tokens, max_tokens=max_new_tokens, temperature=0.0))
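    # With temperature=0.0 the sampler presumably falls back to greedy argmax
    # decoding, which is why no seed is passed here; only the count and validity of
    # the generated tokens are checked.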

    assert len(generated) == max_new_tokens
    for token in generated:
        assert 0 <= token < tiny_config.vocab_size


def test_gpt_estimate_flops(small_config):
    """Test FLOP estimation."""
    model = GPT(small_config)
    flops = model.estimate_flops()

    # FLOPs should be positive
    assert flops > 0
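    # A common back-of-the-envelope figure (not asserted here) is roughly
    # 2 * num_params FLOPs per token for a forward pass, or about 6 * num_params
    # including the backward pass, plus attention terms that grow with sequence
    # length; the test only requires a positive value.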
    print(f"Estimated FLOPs per token: {flops:,}")


def test_gpt_setup_optimizers(tiny_config):
    """Test optimizer setup."""
    model = GPT(tiny_config)
    model.init_weights()

    # Setup optimizers
    optimizers = model.setup_optimizers()

    # Should return a list of 2 optimizers
    assert len(optimizers) == 2
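    # The two optimizers are Muon and AdamW (per the repo's description); presumably
    # Muon handles the 2D transformer matrices and AdamW the remaining parameters,
    # though only the count and non-empty param groups are verified here.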

    # Check they have parameter groups
    for opt in optimizers:
        assert len(opt.param_groups) > 0


def test_gpt_mqa_shapes():
    """Test that Multi-Query Attention produces correct shapes."""
    # Use a config where n_head and n_kv_head differ
    config = GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=8,
        n_kv_head=2,  # 4:1 ratio
        n_embd=64,
    )
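    # With n_head=8 and n_kv_head=2, each KV head is shared by 4 query heads
    # (grouped-query attention); repeat_kv, exercised above, is presumably what
    # expands the 2 KV heads to match the 8 query heads inside the attention layer.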

    model = GPT(config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    batch_size = 2
    seq_len = 16
    idx = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(idx)

    assert logits.shape == (batch_size, seq_len, config.vocab_size)


def test_gpt_long_sequence(small_config):
    """Test with sequences up to max length."""
    model = GPT(small_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.eval()

    batch_size = 1
    seq_len = small_config.sequence_len  # Full sequence length

    idx = torch.randint(0, small_config.vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(idx)

    assert logits.shape == (batch_size, seq_len, small_config.vocab_size)


def test_gpt_embedding_dtype():
    """Test that embeddings are cast to bfloat16."""
    config = GPTConfig(
        sequence_len=32,
        vocab_size=128,
        n_layer=1,
        n_head=2,
        n_kv_head=2,
        n_embd=32,
    )

    model = GPT(config)

    # Check that embeddings are in bfloat16
    assert model.transformer.wte.weight.dtype == torch.bfloat16


def test_gpt_rotary_embeddings_dtype(tiny_config):
    """Test that rotary embeddings are in bfloat16."""
    model = GPT(tiny_config)
    model.init_weights()

    assert model.cos.dtype == torch.bfloat16
    assert model.sin.dtype == torch.bfloat16


def test_gpt_loss_reduction_modes(tiny_config):
    """Test different loss reduction modes."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()

    batch_size = 2
    seq_len = 8

    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))

    # Test 'mean' reduction (default)
    loss_mean = model(idx, targets=targets, loss_reduction='mean')
    assert loss_mean.shape == ()

    # Test 'none' reduction
    loss_none = model(idx, targets=targets, loss_reduction='none')
    assert loss_none.shape == (batch_size * seq_len,)

    # The mean of loss_none should be close to loss_mean
    torch.testing.assert_close(loss_none.mean(), loss_mean, rtol=1e-4, atol=1e-4)
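    # Note: this equality holds because no target here equals the ignore index (-1).
    # Assuming cross_entropy-style semantics, 'mean' averages over valid tokens only,
    # so with ignored positions a plain .mean() over the 'none' vector would no
    # longer match.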


def test_gpt_with_ignore_index(tiny_config):
    """Test that -1 targets are ignored in loss."""
    model = GPT(tiny_config)
    model.init_weights()
    model = prepare_model_for_testing(model)
    model.train()

    batch_size = 2
    seq_len = 8

    idx = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, tiny_config.vocab_size, (batch_size, seq_len))

    # Compute loss with all valid targets
    loss_all = model(idx, targets=targets)

    # Mask out ALL targets except first token
    targets_masked = targets.clone()
    targets_masked[:, 1:] = -1
    loss_masked = model(idx, targets=targets_masked)

    # Both should be valid losses
    assert loss_all.item() > 0
    assert loss_masked.item() > 0
    # Loss should be finite
    assert not torch.isnan(loss_masked)
    assert not torch.isinf(loss_masked)

    # Test that all -1 targets produce no loss computation error
    targets_all_masked = torch.full_like(targets, -1)
    # This should still work (loss over 0 tokens)
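    # Hedged note: assuming cross_entropy-style ignore_index semantics, a 'mean' loss
    # over zero valid tokens typically yields NaN rather than raising. The broad
    # except below also catches AssertionError, so this final block acts as a smoke
    # test that the forward pass completes rather than a strict value check.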
    try:
        loss_all_masked = model(idx, targets=targets_all_masked)
        # If it works, loss should be finite or zero
        assert not torch.isnan(loss_all_masked) or loss_all_masked.item() == 0
    except Exception:
        # It's okay if this raises an error - that's one way to handle no valid targets
        pass