diff --git a/nanochat/dataset.py b/nanochat/dataset.py
index 2a6faf6..17ffa9c 100644
--- a/nanochat/dataset.py
+++ b/nanochat/dataset.py
@@ -129,6 +129,15 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
     parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
     parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+    parser.add_argument(
+        "-f",
+        "--work-share-factor",
+        type=int,
+        default=8,
+        help=(
+            """Controls how each worker's share of shards is subdivided. CHUNK_SIZE is computed as len(ids_to_download) // (num_workers * work_share_factor), so it is the number of tasks a worker pulls per request from the main process. For example, with 240 shards and 4 workers the default value (8) produces 7 shards per request. Setting it to 1 gives a worker its entire share (~60 shards) in one go, with minimal coordination overhead but slow progress updates. Larger work-share-factor values make the main process hand out smaller batches more often, giving faster feedback at a small scheduling cost."""
+        ),
+    )
     args = parser.parse_args()
 
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
@@ -137,7 +146,9 @@ if __name__ == "__main__":
     logger.info(f"Dataset target directory: {DATA_DIR}")
     logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
 
-    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
+    # pool.imap_unordered pulls `chunksize` tasks from the main process before asking for more
+    work_share_factor = max(1, args.work_share_factor)
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * work_share_factor))
     ok_count = 0
     with Pool(processes=args.num_workers) as pool:
         for ok in tqdm(
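
For reviewers, a minimal standalone sketch (not part of the patch) of the scheduling behavior the new comment describes: multiprocessing.Pool.imap_unordered hands each worker `chunksize` tasks per request. The `download` stub and the 240-shard ID list here are hypothetical stand-ins; only the chunk-size arithmetic mirrors the patched CHUNK_SIZE line.

# Sketch: how work_share_factor translates into per-request batch size.
# The real downloader fetches shards; this stub just echoes the shard ID.
from multiprocessing import Pool

def download(shard_id):
    # Stand-in for the real per-shard download function.
    return shard_id

if __name__ == "__main__":
    num_workers = 4
    work_share_factor = 8            # the new flag's default
    ids_to_download = list(range(240))  # hypothetical: 240 shards

    # 240 // (4 * 8) = 7 shards per worker request, matching the
    # worked example in the --work-share-factor help text.
    chunk_size = max(1, len(ids_to_download) // (num_workers * work_share_factor))

    ok_count = 0
    with Pool(processes=num_workers) as pool:
        # Each worker pulls `chunksize` tasks at a time, so larger chunks
        # mean less coordination but coarser progress reporting.
        for ok in pool.imap_unordered(download, ids_to_download, chunksize=chunk_size):
            ok_count += 1
    print(f"completed {ok_count} tasks with chunksize={chunk_size}")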