fix: validate dataset download counts

This commit is contained in:
陈家名 2026-05-09 12:00:37 +08:00
parent dc54a1a307
commit a7eb6155ab
2 changed files with 48 additions and 5 deletions

View File

@ -132,6 +132,21 @@ def download_single_file(index):
return False
def get_shard_indices_to_download(num_files):
    """Build the list of shard indices to fetch for a requested train-shard count.

    Args:
        num_files: Number of train shards requested. ``-1`` selects
            ``MAX_SHARD`` train shards; any larger request is capped at
            ``MAX_SHARD``.

    Returns:
        A list holding the train shard indices ``0 .. count - 1`` followed by
        ``MAX_SHARD``, the index of the validation shard, which is always
        included.

    Raises:
        ValueError: If ``num_files`` is less than ``-1``.
    """
    if num_files < -1:
        raise ValueError("--num-files must be -1 or a non-negative integer")

    if num_files == -1:
        # -1 is the "download everything" sentinel.
        train_shard_count = MAX_SHARD
    else:
        # Never request more train shards than exist.
        train_shard_count = min(num_files, MAX_SHARD)

    # The validation shard is pinned to index MAX_SHARD and always comes last.
    return [*range(train_shard_count), MAX_SHARD]
def validate_num_workers(num_workers):
    """Check that a worker count is usable before a Pool is created with it.

    Args:
        num_workers: Requested number of worker processes.

    Returns:
        ``num_workers`` unchanged when it is valid.

    Raises:
        ValueError: If ``num_workers`` is less than 1.
    """
    too_few_workers = num_workers < 1
    if not too_few_workers:
        return num_workers
    raise ValueError("--num-workers must be at least 1")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
@ -144,15 +159,17 @@ if __name__ == "__main__":
# The way this works is that the user specifies the number of train shards to download via the -n flag.
# In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
ids_to_download = list(range(num_train_shards))
ids_to_download.append(MAX_SHARD) # always download the validation shard
try:
ids_to_download = get_shard_indices_to_download(args.num_files)
num_workers = validate_num_workers(args.num_workers)
except ValueError as e:
parser.error(str(e))
# Download the shards
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
print(f"Downloading {len(ids_to_download)} shards using {num_workers} workers...")
print(f"Target directory: {DATA_DIR}")
print()
with Pool(processes=args.num_workers) as pool:
with Pool(processes=num_workers) as pool:
results = pool.map(download_single_file, ids_to_download)
# Report results

26
tests/test_dataset.py Normal file
View File

@ -0,0 +1,26 @@
import pytest
from nanochat.dataset import MAX_SHARD, get_shard_indices_to_download, validate_num_workers
def test_get_shard_indices_to_download_includes_validation_shard():
    """The validation shard index is appended even when no train shards are requested."""
    validation_only = get_shard_indices_to_download(0)
    assert validation_only == [MAX_SHARD]

    two_train_shards = get_shard_indices_to_download(2)
    assert two_train_shards == [0, 1, MAX_SHARD]
def test_get_shard_indices_to_download_caps_train_shards():
    """A request larger than MAX_SHARD yields exactly MAX_SHARD train shards plus validation."""
    oversized_request = MAX_SHARD + 100
    capped = get_shard_indices_to_download(oversized_request)

    assert capped[0] == 0
    # Last train shard, then the validation shard.
    assert capped[-2:] == [MAX_SHARD - 1, MAX_SHARD]
    assert len(capped) == MAX_SHARD + 1
def test_get_shard_indices_to_download_rejects_negative_counts_except_all():
    """Counts below -1 are rejected, while the -1 sentinel is accepted."""
    # Any negative count other than -1 must raise, not only the value just below it.
    for invalid_count in (-2, -100):
        with pytest.raises(ValueError, match="--num-files"):
            get_shard_indices_to_download(invalid_count)

    # The "except all" half of the test name: -1 must not raise and must
    # return every train shard index followed by the validation shard.
    assert get_shard_indices_to_download(-1) == list(range(MAX_SHARD + 1))
def test_validate_num_workers_requires_positive_count():
    """A worker count of 1 is returned as-is; a count of 0 is rejected."""
    accepted = validate_num_workers(1)
    assert accepted == 1

    with pytest.raises(ValueError, match="--num-workers"):
        validate_num_workers(0)