nanochat/tests/test_dataset.py
2026-05-09 12:00:37 +08:00

27 lines
914 B
Python

import pytest
from nanochat.dataset import MAX_SHARD, get_shard_indices_to_download, validate_num_workers
def test_get_shard_indices_to_download_includes_validation_shard():
assert get_shard_indices_to_download(0) == [MAX_SHARD]
assert get_shard_indices_to_download(2) == [0, 1, MAX_SHARD]
def test_get_shard_indices_to_download_caps_train_shards():
indices = get_shard_indices_to_download(MAX_SHARD + 100)
assert indices[0] == 0
assert indices[-2:] == [MAX_SHARD - 1, MAX_SHARD]
assert len(indices) == MAX_SHARD + 1
def test_get_shard_indices_to_download_rejects_negative_counts_except_all():
with pytest.raises(ValueError, match="--num-files"):
get_shard_indices_to_download(-2)
def test_validate_num_workers_requires_positive_count():
assert validate_num_workers(1) == 1
with pytest.raises(ValueError, match="--num-workers"):
validate_num_workers(0)