diff --git a/nanochat/dataset.py b/nanochat/dataset.py index b695ebb..2d60f56 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -143,9 +143,8 @@ if __name__ == "__main__": num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) ids_to_download = list(range(num)) - if VAL_SHARD_INDEX not in ids_to_download: + if num <= VAL_SHARD_INDEX: ids_to_download.append(VAL_SHARD_INDEX) - ids_to_download = sorted(set(ids_to_download)) if args.num_files != -1 and args.num_files <= MAX_SHARD: print(f"Downloading {len(ids_to_download)} shards ({num} train + 1 val) using {args.num_workers} workers...")