Fix test_dataloader.py to test actual dataloader implementation

Rimom Costa 2025-10-13 19:39:10 +01:00
parent 44764ffff0
commit 992e73b055


@@ -1,117 +1,243 @@
"""
Tests for data loading functionality.
Note: The actual dataloader requires CUDA and parquet files, so these tests
are simplified to test the core concepts.
Run with:
python -m pytest tests/test_dataloader.py -v -s
"""
import torch
import pytest
from collections import deque
from unittest.mock import Mock, patch, MagicMock
from nanochat.dataloader import tokenizing_distributed_data_loader

@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    tokenizer.get_bos_token_id.return_value = 255
    tokenizer.encode.return_value = [
        [255, 1, 2, 3, 4],
        [255, 5, 6, 7, 8],
        [255, 9, 10, 11, 12],
    ]
    return tokenizer

@pytest.fixture
def mock_parquet_data():
    """Mock parquet data iterator."""
    def mock_iter(split, start, step):
        # Yield a few batches of mock documents
        for i in range(3):
            yield [f"Document {i*3}", f"Document {i*3+1}", f"Document {i*3+2}"]
    return mock_iter
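
# Both fixtures mirror the interfaces the tests below assume: parquets_iter_batched
# is called as parquets_iter_batched(split, start, step) and yields batches of
# document strings, and tokenizer.encode(docs, prepend=...) returns one token
# list per document.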

def test_dataloader_initialization():
    """Test that dataloader can be initialized with proper mocks."""
    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
        with patch('nanochat.dataloader.parquets_iter_batched'):
            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
                with patch('torch.Tensor.to', return_value=torch.tensor([[1, 2], [3, 4]])):
                    mock_dist.return_value = (False, 0, 0, 1)
                    mock_tokenizer = Mock()
                    mock_tokenizer.get_bos_token_id.return_value = 255
                    mock_tokenizer.encode.return_value = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
                    mock_tok.return_value = mock_tokenizer
                    loader = tokenizing_distributed_data_loader(B=2, T=4, split="train")
                    assert loader is not None
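
# The 4-tuple mocked for get_dist_info above is assumed to be
# (ddp, rank, local_rank, world_size); (False, 0, 0, 1) therefore simulates a
# single, non-distributed process.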

def test_dataloader_batch_shapes():
    """Test that dataloader produces correct batch shapes."""
    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
                with patch('torch.empty') as mock_empty:
                    mock_dist.return_value = (False, 0, 0, 1)
                    # Mock parquet to return documents
                    mock_parquet.return_value = iter([
                        ["doc1", "doc2", "doc3"],
                        ["doc4", "doc5", "doc6"],
                    ])
                    # Mock tokenizer to return tokens
                    mock_tokenizer = Mock()
                    mock_tokenizer.get_bos_token_id.return_value = 255
                    # Return enough tokens for at least one batch (B=2, T=3 needs 7 tokens)
                    mock_tokenizer.encode.return_value = [
                        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                        [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
                        [21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
                    ]
                    mock_tok.return_value = mock_tokenizer
                    B, T = 2, 3
                    # Mock the scratch buffer
                    mock_empty.return_value = torch.zeros(B * T + 1, dtype=torch.int64)
                    # Mock cuda tensors
                    with patch('torch.Tensor.to') as mock_to:
                        def to_side_effect(*args, **kwargs):
                            # Return a properly shaped tensor
                            if kwargs.get('device') == 'cuda':
                                shape = (B, T)
                                dtype = kwargs.get('dtype', torch.int64)
                                return torch.zeros(shape, dtype=dtype)
                            return torch.zeros((B, T))
                        mock_to.side_effect = to_side_effect
                        loader = tokenizing_distributed_data_loader(B=B, T=T, split="train")
                        inputs, targets = next(loader)
                        assert inputs.shape == (B, T)
                        assert targets.shape == (B, T)
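
# torch.empty and torch.Tensor.to are both patched above because the dataloader
# is assumed to fill a CPU scratch buffer of B*T+1 tokens and then move the
# input/target views to the GPU with .to(device='cuda'); stubbing .to() lets
# this test run on machines without CUDA.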

def test_dataloader_token_shifting():
    """Test that targets are shifted by 1 position from inputs."""
    B, T = 2, 4
    needed_tokens = B * T + 1  # 9 tokens
    # Create a sequence where we can verify the shift
    tokens = list(range(100, 100 + needed_tokens))  # [100, 101, 102, ..., 108]
    # Simulate what the dataloader does
    inputs = torch.tensor(tokens[:-1]).view(B, T)
    targets = torch.tensor(tokens[1:]).view(B, T)
    assert inputs.shape == (B, T)
    assert targets.shape == (B, T)
    # Check shifting: targets[i] should equal inputs[i+1] (within the flat view)
    inputs_flat = inputs.reshape(-1)
    targets_flat = targets.reshape(-1)
    # First element of targets should be second element from original sequence
    assert targets_flat[0].item() == 101
    assert inputs_flat[0].item() == 100
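
# For reference, a minimal sketch of the slicing this test simulates (assumed,
# with illustrative names; the actual buffer handling in nanochat may differ):
#
#   scratch = torch.empty(B * T + 1, dtype=torch.int64)
#   # ...fill scratch with B*T+1 consecutive tokens from the stream...
#   inputs = scratch[:-1].view(B, T)
#   targets = scratch[1:].view(B, T)
#
# so targets[b, t] is always the token that immediately follows inputs[b, t].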

def test_dataloader_distributed_sharding():
    """Test that different ranks get different shards."""
    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
                with patch('torch.empty') as mock_empty:
                    # Simulate rank 1 out of 4
                    mock_dist.return_value = (True, 1, 1, 4)
                    # Track what start/step values parquets_iter_batched is called with
                    calls = []
                    def track_parquet_call(split, start, step):
                        calls.append((split, start, step))
                        return iter([["doc1", "doc2"]])
                    mock_parquet.side_effect = track_parquet_call
                    mock_tokenizer = Mock()
                    mock_tokenizer.get_bos_token_id.return_value = 255
                    mock_tokenizer.encode.return_value = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
                    mock_tok.return_value = mock_tokenizer
                    # Mock the scratch buffer
                    mock_empty.return_value = torch.zeros(7, dtype=torch.int64)
                    with patch('torch.Tensor.to', return_value=torch.zeros((2, 3))):
                        loader = tokenizing_distributed_data_loader(B=2, T=3, split="train")
                        next(loader)
                    # Verify parquets_iter_batched was called with correct rank/world_size
                    assert len(calls) > 0
                    split, start, step = calls[0]
                    assert start == 1  # rank 1
                    assert step == 4  # world_size 4
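
# For reference: rank-strided sharding means rank r of world_size w reads
# document batches r, r + w, r + 2w, ...; e.g. with 8 batches and 4 ranks,
# each rank gets exactly 2, with no overlap:
#
#   for rank in range(4):
#       assert list(range(rank, 8, 4)) == [rank, rank + 4]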

def test_dataloader_token_buffer_accumulation():
    """Test token buffer accumulation logic."""
    token_buffer = deque()
    B, T = 2, 3
    needed_tokens = B * T + 1  # 7 tokens
    # Simulate adding tokens from documents
    doc1_tokens = [1, 2, 3]
    doc2_tokens = [4, 5, 6, 7, 8]
    token_buffer.extend(doc1_tokens)
    assert len(token_buffer) < needed_tokens
    token_buffer.extend(doc2_tokens)
    assert len(token_buffer) >= needed_tokens
    # Extract tokens for one batch
    batch_tokens = []
    for _ in range(needed_tokens):
        batch_tokens.append(token_buffer.popleft())
    assert len(batch_tokens) == needed_tokens
    assert batch_tokens == [1, 2, 3, 4, 5, 6, 7]
    assert len(token_buffer) == 1  # One token remaining
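
# A minimal sketch of the producer/consumer loop this test mirrors (assumed
# shape of the real loop, with illustrative names):
#
#   while len(token_buffer) < needed_tokens:
#       for doc_tokens in tokenizer.encode(next(doc_batches), prepend=bos):
#           token_buffer.extend(doc_tokens)
#   batch_tokens = [token_buffer.popleft() for _ in range(needed_tokens)]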

def test_dataloader_split_validation():
    """Test that invalid split values raise an error."""
    # The function is a generator, so the assertion only runs when next() is called
    with pytest.raises(AssertionError, match="split must be"):
        loader = tokenizing_distributed_data_loader(B=2, T=4, split="invalid")
        next(loader)  # This triggers the function body to execute
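
# pytest.raises must wrap the next() call, not just the construction: calling a
# generator function only creates the generator object, and nothing in its body
# (including the split assertion) runs until the first next(). For example:
#
#   def gen():
#       assert False  # not raised at construction...
#       yield 1
#   g = gen()   # no error here
#   next(g)     # ...raised here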

def test_dataloader_bos_token_prepending():
    """Test that BOS tokens are properly prepended."""
    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
                with patch('torch.empty') as mock_empty:
                    mock_dist.return_value = (False, 0, 0, 1)
                    mock_parquet.return_value = iter([["doc1"]])
                    mock_tokenizer = Mock()
                    bos_token = 255
                    mock_tokenizer.get_bos_token_id.return_value = bos_token
                    mock_tokenizer.encode.return_value = [[bos_token, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
                    mock_tok.return_value = mock_tokenizer
                    # Mock the scratch buffer
                    mock_empty.return_value = torch.zeros(7, dtype=torch.int64)
                    with patch('torch.Tensor.to', return_value=torch.zeros((2, 3))):
                        loader = tokenizing_distributed_data_loader(B=2, T=3, split="train")
                        next(loader)
                    # Verify tokenizer.encode was called with prepend=bos_token
                    mock_tokenizer.encode.assert_called()
                    call_kwargs = mock_tokenizer.encode.call_args[1]
                    assert 'prepend' in call_kwargs
                    assert call_kwargs['prepend'] == bos_token
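
# Note: call_args[1] is the kwargs dict of the most recent call; on Python 3.8+
# the equivalent, more readable spelling is mock_tokenizer.encode.call_args.kwargs.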

def test_dataloader_needed_tokens_calculation():
    """Test that the dataloader calculates needed tokens correctly."""
    B, T = 4, 16
    needed_tokens = B * T + 1
    # We need B*T tokens for inputs, plus 1 for the last target
    assert needed_tokens == 65
    # The scratch buffer should be exactly this size
    scratch = torch.empty(needed_tokens, dtype=torch.int64)
    assert scratch.shape == (65,)
    # After slicing, inputs should be B*T and targets should be B*T
    inputs = scratch[:-1]
    targets = scratch[1:]
    assert len(inputs) == B * T
    assert len(targets) == B * T