diff --git a/.gitignore b/.gitignore index 4a87b23..e45c5e7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__/ rustbpe/target/ dev-ignore/ report.md -eval_bundle/ \ No newline at end of file +eval_bundle/ +logs/ diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..2a6faf6 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`. import os import argparse import time +import logging import requests +from tqdm import tqdm import pyarrow.parquet as pq +from functools import partial from multiprocessing import Pool -from nanochat.common import get_base_dir +from nanochat.common import get_base_dir, setup_file_logger # ----------------------------------------------------------------------------- # The specifics of the current pretraining dataset @@ -27,6 +30,19 @@ base_dir = get_base_dir() DATA_DIR = os.path.join(base_dir, "base_data") os.makedirs(DATA_DIR, exist_ok=True) +# ----------------------------------------------------------------------------- +# Minimal logger setup for DEBUG level +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +log_path = setup_file_logger( + logger_name=__name__, + filename="dataset_download.log", + level=logging.DEBUG, + formatter=logging.Formatter( + "%(asctime)s - %(processName)s - %(levelname)s - %(message)s" + ), +) + # ----------------------------------------------------------------------------- # These functions are useful utilities to other modules, can/should be imported @@ -64,12 +80,12 @@ def download_single_file(index): filename = index_to_filename(index) filepath = os.path.join(DATA_DIR, filename) if os.path.exists(filepath): - print(f"Skipping {filepath} (already exists)") + logger.debug(f"Skipping {filepath} (already exists)") return True # Construct the remote URL for this file url = f"{BASE_URL}/{filename}" - print(f"Downloading {filename}...") + logger.debug(f"Downloading {filename}...") # Download with retries max_attempts = 5 @@ -85,11 +101,11 @@ def download_single_file(index): f.write(chunk) # Move temp file to final location os.rename(temp_path, filepath) - print(f"Successfully downloaded {filename}") + logger.debug(f"Successfully downloaded {filename}") return True except (requests.RequestException, IOError) as e: - print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") + logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") # Clean up any partial files for path in [filepath + f".tmp", filepath]: if os.path.exists(path): @@ -100,10 +116,10 @@ def download_single_file(index): # Try a few times with exponential backoff: 2^attempt seconds if attempt < max_attempts: wait_time = 2 ** attempt - print(f"Waiting {wait_time} seconds before retry...") + logger.debug(f"Waiting {wait_time} seconds before retry...") time.sleep(wait_time) else: - print(f"Failed to download {filename} after {max_attempts} attempts") + logger.debug(f"Failed to download {filename} after {max_attempts} attempts") return False return False @@ -117,12 +133,22 @@ if __name__ == "__main__": num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) ids_to_download = list(range(num)) - print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") - print(f"Target directory: {DATA_DIR}") - print() + logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") + logger.info(f"Dataset target directory: {DATA_DIR}") + logger.info(f"Dataset downloader debug logs will be written to: {log_path}") + + CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8)) + ok_count = 0 with Pool(processes=args.num_workers) as pool: - results = pool.map(download_single_file, ids_to_download) + for ok in tqdm( + pool.imap_unordered( + partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE + ), + total=len(ids_to_download), + desc="all shards", + smoothing=0.1, + ): + ok_count += int(ok) # Report results - successful = sum(1 for success in results if success) - print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}") + logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}") diff --git a/pyproject.toml b/pyproject.toml index 3d03c4b..bc6f4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "torch>=2.8.0", "uvicorn>=0.36.0", "wandb>=0.21.3", + "tqdm>=4.66.0" ] [build-system]