Mirror of https://github.com/karpathy/nanochat.git

improve dataset downloader logging and add progress bar
store debug-level worker output in logs/dataset_download.log while keeping key messages on stdout

parent 2801dc341b
commit c2740d3a82

.gitignore (vendored)

@@ -4,4 +4,5 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
-eval_bundle/
+eval_bundle/
+logs/

nanochat/dataset.py

@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
 
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset

@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
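
The setup_file_logger helper is imported from nanochat.common, but its body is outside this diff. Judging from the call site above and the commit message (debug detail to logs/dataset_download.log, key messages still on stdout), a minimal sketch could look like the following; the signature and the returned log path are implied by the diff, while the handler arrangement and the logs/ location are assumptions (the logs/ directory does match the new .gitignore entry):

# Hedged sketch of what nanochat.common.setup_file_logger might do; not the
# actual implementation, which this diff does not show.
import logging
import os
import sys

def setup_file_logger(logger_name: str, filename: str, level: int,
                      formatter: logging.Formatter) -> str:
    logs_dir = "logs"  # assumption: matches the logs/ entry added to .gitignore
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, filename)

    logger = logging.getLogger(logger_name)

    file_handler = logging.FileHandler(log_path)  # full DEBUG chatter goes to the file
    file_handler.setLevel(level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler(sys.stdout)  # key messages stay on stdout
    stream_handler.setLevel(logging.INFO)  # assumption: INFO and above count as "key"
    logger.addHandler(stream_handler)

    return log_path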

@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5

@@ -85,11 +101,11 @@ def download_single_file(index):
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
 
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):

@@ -100,10 +116,10 @@ def download_single_file(index):
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
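
A quick worked check of the retry schedule above: wait_time = 2 ** attempt with max_attempts = 5 gives waits of 2, 4, 8, and 16 seconds after attempts 1 through 4; the fifth failure returns False immediately, with no final sleep.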

@@ -117,12 +133,22 @@ if __name__ == "__main__":
 
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")
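
The switch from pool.map to pool.imap_unordered is what makes the progress bar useful: map blocks until every result is ready, so tqdm would jump from 0 to 100% at the end, whereas imap_unordered yields each result as soon as a worker finishes it. (partial(download_single_file) binds no extra arguments here, so it is equivalent to passing download_single_file directly.) A self-contained sketch of the same pattern, using a stand-in worker function since the real downloader needs the dataset host:

# Minimal, illustrative sketch of the imap_unordered + tqdm pattern used above;
# do_work is a hypothetical stand-in for download_single_file.
from multiprocessing import Pool
from tqdm import tqdm

def do_work(i):
    return i % 3 != 0  # pretend a third of the items "fail"

if __name__ == "__main__":
    items = list(range(100))
    num_workers = 4
    chunk = max(1, len(items) // (num_workers * 8))  # same chunksize heuristic as CHUNK_SIZE
    ok_count = 0
    with Pool(processes=num_workers) as pool:
        # Results arrive in completion order, letting tqdm tick per item.
        for ok in tqdm(pool.imap_unordered(do_work, items, chunksize=chunk),
                       total=len(items), desc="all shards", smoothing=0.1):
            ok_count += int(ok)
    print(f"{ok_count}/{len(items)} succeeded")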

pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]
 
 [build-system]