mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-11 02:10:13 +00:00
fix: validate dataset download counts
This commit is contained in:
parent
dc54a1a307
commit
a7eb6155ab
|
|
@ -132,6 +132,21 @@ def download_single_file(index):
|
|||
|
||||
return False
|
||||
|
||||
def get_shard_indices_to_download(num_files):
|
||||
"""Return train shard indices plus the always-required validation shard."""
|
||||
if num_files < -1:
|
||||
raise ValueError("--num-files must be -1 or a non-negative integer")
|
||||
num_train_shards = MAX_SHARD if num_files == -1 else min(num_files, MAX_SHARD)
|
||||
ids_to_download = list(range(num_train_shards))
|
||||
ids_to_download.append(MAX_SHARD) # always download the validation shard
|
||||
return ids_to_download
|
||||
|
||||
def validate_num_workers(num_workers):
|
||||
"""Validate the multiprocessing worker count before creating a Pool."""
|
||||
if num_workers < 1:
|
||||
raise ValueError("--num-workers must be at least 1")
|
||||
return num_workers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
|
||||
|
|
@ -144,15 +159,17 @@ if __name__ == "__main__":
|
|||
|
||||
# The way this works is that the user specifies the number of train shards to download via the -n flag.
|
||||
# In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
|
||||
num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
|
||||
ids_to_download = list(range(num_train_shards))
|
||||
ids_to_download.append(MAX_SHARD) # always download the validation shard
|
||||
try:
|
||||
ids_to_download = get_shard_indices_to_download(args.num_files)
|
||||
num_workers = validate_num_workers(args.num_workers)
|
||||
except ValueError as e:
|
||||
parser.error(str(e))
|
||||
|
||||
# Download the shards
|
||||
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
|
||||
print(f"Downloading {len(ids_to_download)} shards using {num_workers} workers...")
|
||||
print(f"Target directory: {DATA_DIR}")
|
||||
print()
|
||||
with Pool(processes=args.num_workers) as pool:
|
||||
with Pool(processes=num_workers) as pool:
|
||||
results = pool.map(download_single_file, ids_to_download)
|
||||
|
||||
# Report results
|
||||
|
|
|
|||
26
tests/test_dataset.py
Normal file
26
tests/test_dataset.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
|
||||
from nanochat.dataset import MAX_SHARD, get_shard_indices_to_download, validate_num_workers
|
||||
|
||||
|
||||
def test_get_shard_indices_to_download_includes_validation_shard():
|
||||
assert get_shard_indices_to_download(0) == [MAX_SHARD]
|
||||
assert get_shard_indices_to_download(2) == [0, 1, MAX_SHARD]
|
||||
|
||||
|
||||
def test_get_shard_indices_to_download_caps_train_shards():
|
||||
indices = get_shard_indices_to_download(MAX_SHARD + 100)
|
||||
assert indices[0] == 0
|
||||
assert indices[-2:] == [MAX_SHARD - 1, MAX_SHARD]
|
||||
assert len(indices) == MAX_SHARD + 1
|
||||
|
||||
|
||||
def test_get_shard_indices_to_download_rejects_negative_counts_except_all():
|
||||
with pytest.raises(ValueError, match="--num-files"):
|
||||
get_shard_indices_to_download(-2)
|
||||
|
||||
|
||||
def test_validate_num_workers_requires_positive_count():
|
||||
assert validate_num_workers(1) == 1
|
||||
with pytest.raises(ValueError, match="--num-workers"):
|
||||
validate_num_workers(0)
|
||||
Loading…
Reference in New Issue
Block a user