mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 13:45:21 +00:00
Add test coverage for all major components:
- GPT model: architecture, generation, MQA, rotary embeddings (19 tests)
- Inference engine: KV cache, sampling, tool use (17 tests)
- Optimizers: Muon and AdamW functionality (10 tests)
- Checkpoint management: save/load, metadata (5 tests)
- Data loading and utilities (13 tests)

docs: update README with test documentation and learning guide
- Add For Students section with structured learning path
- Document architectural decisions and key concepts
- Add test usage instructions
118 lines
3.1 KiB
Python
118 lines
3.1 KiB
Python
"""
Tests for data loading functionality.

Note: The actual dataloader requires CUDA and parquet files, so these tests
are simplified to exercise the core concepts.

Run with:
python -m pytest tests/test_dataloader.py -v -s
"""
|
|
|
|
import torch
|
|
import pytest
|
|
|
|
|
|
def test_batch_creation():
    """Test creating (inputs, targets) batches from a flat token sequence.

    Takes ``batch_size * seq_len + 1`` consecutive tokens. The first
    ``batch_size * seq_len`` become the inputs, and the same span shifted by
    one token becomes the targets. Both are reshaped to ``(batch_size, seq_len)``.
    """
    # Simulate what the dataloader does internally
    tokens = list(range(100))
    batch_size = 4
    seq_len = 10

    # Need batch_size * seq_len + 1 tokens: the extra token is the target of
    # the very last input position.
    needed = batch_size * seq_len + 1
    assert len(tokens) >= needed

    # Create inputs and targets
    inputs = torch.tensor(tokens[:batch_size * seq_len]).view(batch_size, seq_len)
    targets = torch.tensor(tokens[1:needed]).view(batch_size, seq_len)

    # Check shapes
    assert inputs.shape == (batch_size, seq_len)
    assert targets.shape == (batch_size, seq_len)

    # Check that targets are shifted by 1 everywhere, not just at one cell.
    # Flattening also covers row boundaries: targets[r, -1] must equal
    # inputs[r + 1, 0].
    assert targets[0, 0] == inputs[0, 1]
    assert torch.equal(targets.flatten()[:-1], inputs.flatten()[1:])

    # Both views must come from the expected slices of the source stream.
    assert torch.equal(inputs.flatten(), torch.tensor(tokens[:needed - 1]))
    assert torch.equal(targets.flatten(), torch.tensor(tokens[1:needed]))
|
|
|
|
|
|
def test_token_buffer_simulation():
    """Test FIFO token buffering with a deque.

    Tokens are appended on the right and consumed from the left. The consumed
    tokens must therefore come out in exactly the order they were added, and
    the buffer must retain the unconsumed tail in order.
    """
    from collections import deque

    total = 100
    needed = 50

    token_buffer = deque()

    # Simulate adding tokens
    token_buffer.extend(range(total))
    assert len(token_buffer) == total

    # Simulate consuming tokens from the left (oldest first)
    consumed = [token_buffer.popleft() for _ in range(needed)]

    # The whole consumed run must be in insertion order, not just its endpoints
    assert consumed == list(range(needed))
    # What is left is exactly the untouched tail, still in order
    assert list(token_buffer) == list(range(needed, total))
|
|
|
|
|
|
def test_distributed_rank_sharding():
    """Test that strided rank sharding partitions the shards across ranks.

    Rank ``r`` takes shards ``r, r + world_size, r + 2 * world_size, ...``.
    Across all ranks, every shard must be assigned exactly once: no shard is
    skipped and no two ranks read the same shard.
    """
    total_shards = 8
    world_size = 4

    assigned = []
    for rank in range(world_size):
        # Each rank gets every world_size'th shard, starting at its own rank
        shards = list(range(rank, total_shards, world_size))
        assert len(shards) == total_shards // world_size
        # Every shard this rank receives is one it owns under the stride rule
        assert all(shard % world_size == rank for shard in shards)
        assigned.extend(shards)

    # Disjoint and complete: each shard index appears exactly once overall
    assert sorted(assigned) == list(range(total_shards))
|
|
|
|
|
|
def test_sequence_packing():
    """Check that a flat token tensor packs row-major into (batch_size, seq_len)."""
    # Mirror the reshape step the dataloader performs
    batch_size, seq_len = 2, 4
    flat = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

    packed = flat.view(batch_size, seq_len)
    assert packed.shape == (batch_size, seq_len)

    # Row-major layout: row r holds tokens r * seq_len .. (r + 1) * seq_len - 1
    first_row, second_row = packed[0], packed[1]
    assert first_row[0] == 0
    assert first_row[-1] == 3
    assert second_row[0] == 4
    assert second_row[-1] == 7
|
|
|
|
|
|
def test_input_target_alignment():
    """Test that targets are the inputs shifted by one token in the source stream.

    For a window of ``seq_len`` positions starting at offset 0, the input at
    position ``i`` is ``tokens[i]`` and its target is the following token,
    ``tokens[i + 1]``.
    """
    seq_len = 10
    tokens = list(range(20))

    # Inputs: the first seq_len tokens, tokens[:seq_len].
    # Targets: the same window shifted by one, tokens[1:seq_len + 1], so the
    # pair needs seq_len + 1 tokens from the stream in total.
    inputs = tokens[:seq_len]
    targets = tokens[1:seq_len + 1]

    assert len(inputs) == len(targets) == seq_len

    # Each target should be the next token after the corresponding input in the
    # source stream. Compare against the stream itself rather than relying on
    # the tokens happening to be consecutive integers.
    for i, target in enumerate(targets):
        assert inputs[i] == tokens[i]
        assert target == tokens[i + 1]

    # Overlap: target i is input i + 1 for every position except the last
    assert targets[:-1] == inputs[1:]
|
|
|
|
|
|
def test_bos_token_prepending():
    """Check that prepending a BOS token adds exactly one leading token."""
    # Mimic the tokenizer's prepend behaviour
    bos_token = 255
    text_tokens = [10, 20, 30, 40]

    tokens_with_bos = [bos_token, *text_tokens]

    first, *_ = tokens_with_bos
    assert first == bos_token
    assert len(text_tokens) + 1 == len(tokens_with_bos)
|