diff --git a/nanochat/dataset.py b/nanochat/dataset.py index fffe722e..4c94e1f4 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -132,6 +132,21 @@ def download_single_file(index): return False +def get_shard_indices_to_download(num_files): + """Return train shard indices plus the always-required validation shard.""" + if num_files < -1: + raise ValueError("--num-files must be -1 or a non-negative integer") + num_train_shards = MAX_SHARD if num_files == -1 else min(num_files, MAX_SHARD) + ids_to_download = list(range(num_train_shards)) + ids_to_download.append(MAX_SHARD) # always download the validation shard + return ids_to_download + +def validate_num_workers(num_workers): + """Validate the multiprocessing worker count before creating a Pool.""" + if num_workers < 1: + raise ValueError("--num-workers must be at least 1") + return num_workers + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download pretraining dataset shards") @@ -144,15 +159,17 @@ if __name__ == "__main__": # The way this works is that the user specifies the number of train shards to download via the -n flag. # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard. - num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD) - ids_to_download = list(range(num_train_shards)) - ids_to_download.append(MAX_SHARD) # always download the validation shard + try: + ids_to_download = get_shard_indices_to_download(args.num_files) + num_workers = validate_num_workers(args.num_workers) + except ValueError as e: + parser.error(str(e)) # Download the shards - print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") + print(f"Downloading {len(ids_to_download)} shards using {num_workers} workers...") print(f"Target directory: {DATA_DIR}") print() - with Pool(processes=args.num_workers) as pool: + with Pool(processes=num_workers) as pool: results = pool.map(download_single_file, ids_to_download) # Report results diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..07d5526a --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,26 @@ +import pytest + +from nanochat.dataset import MAX_SHARD, get_shard_indices_to_download, validate_num_workers + + +def test_get_shard_indices_to_download_includes_validation_shard(): + assert get_shard_indices_to_download(0) == [MAX_SHARD] + assert get_shard_indices_to_download(2) == [0, 1, MAX_SHARD] + + +def test_get_shard_indices_to_download_caps_train_shards(): + indices = get_shard_indices_to_download(MAX_SHARD + 100) + assert indices[0] == 0 + assert indices[-2:] == [MAX_SHARD - 1, MAX_SHARD] + assert len(indices) == MAX_SHARD + 1 + + +def test_get_shard_indices_to_download_rejects_negative_counts_except_all(): + with pytest.raises(ValueError, match="--num-files"): + get_shard_indices_to_download(-2) + + +def test_validate_num_workers_requires_positive_count(): + assert validate_num_workers(1) == 1 + with pytest.raises(ValueError, match="--num-workers"): + validate_num_workers(0)