diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 15c32f7..3032462 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -1,117 +1,243 @@
 """
 Tests for data loading functionality.
 
-Note: The actual dataloader requires CUDA and parquet files, so these tests
-are simplified to test the core concepts.
-
 Run with: python -m pytest tests/test_dataloader.py -v -s
 """
 import torch
 import pytest
+from collections import deque
+from unittest.mock import Mock, patch, MagicMock
+from nanochat.dataloader import tokenizing_distributed_data_loader
 
 
-def test_batch_creation():
-    """Test creating batches from token sequences."""
-    # Simulate what the dataloader does internally
-    tokens = list(range(100))
-    batch_size = 4
-    seq_len = 10
+@pytest.fixture
+def mock_tokenizer():
+    """Create a mock tokenizer."""
+    tokenizer = Mock()
+    tokenizer.get_bos_token_id.return_value = 255
+    tokenizer.encode.return_value = [
+        [255, 1, 2, 3, 4],
+        [255, 5, 6, 7, 8],
+        [255, 9, 10, 11, 12],
+    ]
+    return tokenizer
+
+
+@pytest.fixture
+def mock_parquet_data():
+    """Mock parquet data iterator."""
+    def mock_iter(split, start, step):
+        # Yield a few batches of mock documents
+        for i in range(3):
+            yield [f"Document {i*3}", f"Document {i*3+1}", f"Document {i*3+2}"]
+    return mock_iter
+
+
+def test_dataloader_initialization():
+    """Test that dataloader can be initialized with proper mocks."""
+    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
+        with patch('nanochat.dataloader.parquets_iter_batched'):
+            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
+                with patch('torch.Tensor.to', return_value=torch.tensor([[1, 2], [3, 4]])):
+                    mock_dist.return_value = (False, 0, 0, 1)
+                    mock_tokenizer = Mock()
+                    mock_tokenizer.get_bos_token_id.return_value = 255
+                    mock_tokenizer.encode.return_value = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
+                    mock_tok.return_value = mock_tokenizer
+
+                    loader = tokenizing_distributed_data_loader(B=2, T=4, split="train")
+                    assert loader is not None
+
+
+def test_dataloader_batch_shapes():
+    """Test that dataloader produces correct batch shapes."""
+    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
+        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
+            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
+                with patch('torch.empty') as mock_empty:
+                    mock_dist.return_value = (False, 0, 0, 1)
+
+                    # Mock parquet to return documents
+                    mock_parquet.return_value = iter([
+                        ["doc1", "doc2", "doc3"],
+                        ["doc4", "doc5", "doc6"],
+                    ])
+
+                    # Mock tokenizer to return tokens
+                    mock_tokenizer = Mock()
+                    mock_tokenizer.get_bos_token_id.return_value = 255
+                    # Return enough tokens for at least one batch (B=2, T=3 needs 7 tokens)
+                    mock_tokenizer.encode.return_value = [
+                        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+                        [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+                        [21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
+                    ]
+                    mock_tok.return_value = mock_tokenizer
+
+                    B, T = 2, 3
+
+                    # Mock the scratch buffer
+                    mock_empty.return_value = torch.zeros(B * T + 1, dtype=torch.int64)
+
+                    # Mock cuda tensors
+                    with patch('torch.Tensor.to') as mock_to:
+                        def to_side_effect(*args, **kwargs):
+                            # Return a properly shaped tensor
+                            if kwargs.get('device') == 'cuda':
+                                shape = (B, T)
+                                dtype = kwargs.get('dtype', torch.int64)
+                                return torch.zeros(shape, dtype=dtype)
+                            return torch.zeros((B, T))
+
+                        mock_to.side_effect = to_side_effect
+
+                        loader = tokenizing_distributed_data_loader(B=B, T=T, split="train")
+                        inputs, targets = next(loader)
+
+                        assert inputs.shape == (B, T)
+                        assert targets.shape == (B, T)
+
+
+def test_dataloader_token_shifting():
+    """Test that targets are shifted by 1 position from inputs."""
+    B, T = 2, 4
+    needed_tokens = B * T + 1  # 9 tokens
 
-    # Need batch_size * seq_len + 1 tokens for inputs and targets
-    needed = batch_size * seq_len + 1
-    assert len(tokens) >= needed
+    # Create a sequence where we can verify the shift
+    tokens = list(range(100, 100 + needed_tokens))  # [100, 101, 102, ..., 108]
 
-    # Create inputs and targets
-    inputs = torch.tensor(tokens[:batch_size * seq_len]).view(batch_size, seq_len)
-    targets = torch.tensor(tokens[1:batch_size * seq_len + 1]).view(batch_size, seq_len)
+    # Simulate what the dataloader does
+    inputs = torch.tensor(tokens[:-1]).view(B, T)
+    targets = torch.tensor(tokens[1:]).view(B, T)
 
     # Check shapes
-    assert inputs.shape == (batch_size, seq_len)
-    assert targets.shape == (batch_size, seq_len)
+    assert inputs.shape == (B, T)
+    assert targets.shape == (B, T)
 
-    # Check that targets are shifted by 1
-    assert targets[0, 0] == inputs[0, 1]
+    # Check shifting: targets[i] should equal inputs[i+1] (within the flat view)
+    inputs_flat = inputs.reshape(-1)
+    targets_flat = targets.reshape(-1)
+
+    # First element of targets should be second element from original sequence
+    assert targets_flat[0].item() == 101
+    assert inputs_flat[0].item() == 100
 
 
-def test_token_buffer_simulation():
-    """Test token buffering logic."""
-    from collections import deque
-
+def test_dataloader_distributed_sharding():
+    """Test that different ranks get different shards."""
+    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
+        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
+            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
+                with patch('torch.empty') as mock_empty:
+                    # Simulate rank 1 out of 4
+                    mock_dist.return_value = (True, 1, 1, 4)
+
+                    # Track what start/step values parquets_iter_batched is called with
+                    calls = []
+                    def track_parquet_call(split, start, step):
+                        calls.append((split, start, step))
+                        return iter([["doc1", "doc2"]])
+
+                    mock_parquet.side_effect = track_parquet_call
+
+                    mock_tokenizer = Mock()
+                    mock_tokenizer.get_bos_token_id.return_value = 255
+                    mock_tokenizer.encode.return_value = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
+                    mock_tok.return_value = mock_tokenizer
+
+                    # Mock the scratch buffer
+                    mock_empty.return_value = torch.zeros(7, dtype=torch.int64)
+
+                    with patch('torch.Tensor.to', return_value=torch.zeros((2, 3))):
+                        loader = tokenizing_distributed_data_loader(B=2, T=3, split="train")
+                        next(loader)
+
+                    # Verify parquets_iter_batched was called with correct rank/world_size
+                    assert len(calls) > 0
+                    split, start, step = calls[0]
+                    assert start == 1  # rank 1
+                    assert step == 4  # world_size 4
+
+
+def test_dataloader_token_buffer_accumulation():
+    """Test token buffer accumulation logic."""
     token_buffer = deque()
+    B, T = 2, 3
+    needed_tokens = B * T + 1  # 7 tokens
 
-    # Simulate adding tokens
-    for i in range(100):
-        token_buffer.append(i)
+    # Simulate adding tokens from documents
+    doc1_tokens = [1, 2, 3]
+    doc2_tokens = [4, 5, 6, 7, 8]
 
-    assert len(token_buffer) == 100
+    token_buffer.extend(doc1_tokens)
+    assert len(token_buffer) < needed_tokens
 
-    # Simulate consuming tokens
-    needed = 50
-    consumed = []
-    for _ in range(needed):
-        consumed.append(token_buffer.popleft())
+    token_buffer.extend(doc2_tokens)
+    assert len(token_buffer) >= needed_tokens
 
-    assert len(consumed) == needed
-    assert len(token_buffer) == 50
-    assert consumed[0] == 0
-    assert consumed[-1] == 49
+    # Extract tokens for one batch
+    batch_tokens = []
+    for _ in range(needed_tokens):
+        batch_tokens.append(token_buffer.popleft())
+
+    assert len(batch_tokens) == needed_tokens
+    assert batch_tokens == [1, 2, 3, 4, 5, 6, 7]
+    assert len(token_buffer) == 1  # One token remaining
 
 
-def test_distributed_rank_sharding():
-    """Test how data is distributed across ranks."""
-    total_shards = 8
-    world_size = 4
-
-    # Each rank gets every world_size'th shard
-    for rank in range(world_size):
-        shards = list(range(rank, total_shards, world_size))
-        assert len(shards) == total_shards // world_size
+def test_dataloader_split_validation():
+    """Test that invalid split values raise an error."""
+    # The function is a generator, so the assertion only runs when next() is called
+    with pytest.raises(AssertionError, match="split must be"):
+        loader = tokenizing_distributed_data_loader(B=2, T=4, split="invalid")
+        next(loader)  # This triggers the function body to execute
 
 
-def test_sequence_packing():
-    """Test packing tokens into sequences."""
-    # Simulate the reshape operation in dataloader
-    batch_size = 2
-    seq_len = 4
-
-    # Flat token sequence
-    tokens = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-
-    # Pack into batch
-    batch = tokens.view(batch_size, seq_len)
-
-    assert batch.shape == (batch_size, seq_len)
-    assert batch[0, 0] == 0
-    assert batch[0, -1] == 3
-    assert batch[1, 0] == 4
-    assert batch[1, -1] == 7
+def test_dataloader_bos_token_prepending():
+    """Test that BOS tokens are properly prepended."""
+    with patch('nanochat.dataloader.get_dist_info') as mock_dist:
+        with patch('nanochat.dataloader.parquets_iter_batched') as mock_parquet:
+            with patch('nanochat.dataloader.get_tokenizer') as mock_tok:
+                with patch('torch.empty') as mock_empty:
+                    mock_dist.return_value = (False, 0, 0, 1)
+                    mock_parquet.return_value = iter([["doc1"]])
+
+                    mock_tokenizer = Mock()
+                    bos_token = 255
+                    mock_tokenizer.get_bos_token_id.return_value = bos_token
+                    mock_tokenizer.encode.return_value = [[bos_token, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
+                    mock_tok.return_value = mock_tokenizer
+
+                    # Mock the scratch buffer
+                    mock_empty.return_value = torch.zeros(7, dtype=torch.int64)
+
+                    with patch('torch.Tensor.to', return_value=torch.zeros((2, 3))):
+                        loader = tokenizing_distributed_data_loader(B=2, T=3, split="train")
+                        next(loader)
+
+                    # Verify tokenizer.encode was called with prepend=bos_token
+                    mock_tokenizer.encode.assert_called()
+                    call_kwargs = mock_tokenizer.encode.call_args[1]
+                    assert 'prepend' in call_kwargs
+                    assert call_kwargs['prepend'] == bos_token
 
 
-def test_input_target_alignment():
-    """Test that inputs and targets are properly aligned."""
-    seq_len = 10
-    tokens = list(range(20))
+def test_dataloader_needed_tokens_calculation():
+    """Test that the dataloader calculates needed tokens correctly."""
+    B, T = 4, 16
+    needed_tokens = B * T + 1
 
-    # Inputs: tokens[:-1]
-    # Targets: tokens[1:]
-    inputs = tokens[:seq_len]
-    targets = tokens[1:seq_len + 1]
+    # We need B*T tokens for inputs, plus 1 for the last target
+    assert needed_tokens == 65
 
-    # Each target should be the next token after corresponding input
-    for i in range(seq_len):
-        assert targets[i] == inputs[i] + 1
-
-
-def test_bos_token_prepending():
-    """Test BOS token prepending logic."""
-    # Simulate what tokenizer does with prepend
-    bos_token = 255
-    text_tokens = [10, 20, 30, 40]
+    # The scratch buffer should be exactly this size
+    scratch = torch.empty(needed_tokens, dtype=torch.int64)
+    assert scratch.shape == (65,)
 
-    # With prepend
-    tokens_with_bos = [bos_token] + text_tokens
-
-    assert tokens_with_bos[0] == bos_token
-    assert len(tokens_with_bos) == len(text_tokens) + 1
+    # After slicing, inputs should be B*T and targets should be B*T
+    inputs = scratch[:-1]
+    targets = scratch[1:]
+    assert len(inputs) == B * T
+    assert len(targets) == B * T