adjust logic for downloading validation shard based on number of files

This commit is contained in:
Kartik Vashishta 2026-02-26 04:46:42 +11:00
parent 3a0550ab48
commit 2277da9ff4

View File

@ -143,9 +143,8 @@ if __name__ == "__main__":
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num))
if VAL_SHARD_INDEX not in ids_to_download:
if num <= VAL_SHARD_INDEX:
ids_to_download.append(VAL_SHARD_INDEX)
ids_to_download = sorted(set(ids_to_download))
if args.num_files != -1 and args.num_files <= MAX_SHARD:
print(f"Downloading {len(ids_to_download)} shards ({num} train + 1 val) using {args.num_workers} workers...")