improve dataset downloader logging and add progress bar

Store debug-level worker output in logs/dataset_download.log while keeping key messages on stdout.
henok3878 2025-11-07 16:06:01 -05:00
parent 2801dc341b
commit c2740d3a82
3 changed files with 42 additions and 14 deletions
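The change routes verbose per-shard messages to a DEBUG-level file handler while keeping summary messages visible on stdout. The setup_file_logger helper imported from nanochat.common is not part of this diff, so the following is only a rough sketch of what such a helper might look like: its signature is inferred from the call site below, and everything else (the logs/ directory, the INFO-level stdout handler) is an assumption.

import logging
import os
import sys

def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None, log_dir="logs"):
    """Hypothetical stand-in for nanochat.common.setup_file_logger: attach a
    DEBUG-level file handler plus an INFO-level stdout handler to the named
    logger, and return the path of the log file."""
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, filename)
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)
    # Full detail (DEBUG and up) goes to logs/<filename> ...
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(level)
    if formatter is not None:
        file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # ... while only key messages (INFO and up) are echoed to stdout.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    return log_path

With a setup along these lines, logger.debug(...) calls land only in the log file, while logger.info(...) calls appear both in the file and on the console.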

.gitignore

@@ -5,3 +5,4 @@ rustbpe/target/
 dev-ignore/
 report.md
 eval_bundle/
+logs/


@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@ def download_single_file(index):
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
@@ -100,10 +116,10 @@ def download_single_file(index):
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
@@ -117,12 +133,22 @@ if __name__ == "__main__":
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")

pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]
 
 [build-system]