mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-14 11:47:34 +00:00
27 lines
914 B
Python
27 lines
914 B
Python
import pytest
|
|
|
|
from nanochat.dataset import MAX_SHARD, get_shard_indices_to_download, validate_num_workers
|
|
|
|
|
|
def test_get_shard_indices_to_download_includes_validation_shard():
|
|
assert get_shard_indices_to_download(0) == [MAX_SHARD]
|
|
assert get_shard_indices_to_download(2) == [0, 1, MAX_SHARD]
|
|
|
|
|
|
def test_get_shard_indices_to_download_caps_train_shards():
|
|
indices = get_shard_indices_to_download(MAX_SHARD + 100)
|
|
assert indices[0] == 0
|
|
assert indices[-2:] == [MAX_SHARD - 1, MAX_SHARD]
|
|
assert len(indices) == MAX_SHARD + 1
|
|
|
|
|
|
def test_get_shard_indices_to_download_rejects_negative_counts_except_all():
|
|
with pytest.raises(ValueError, match="--num-files"):
|
|
get_shard_indices_to_download(-2)
|
|
|
|
|
|
def test_validate_num_workers_requires_positive_count():
|
|
assert validate_num_workers(1) == 1
|
|
with pytest.raises(ValueError, match="--num-workers"):
|
|
validate_num_workers(0)
|