diff --git a/.gitignore b/.gitignore
index 4a87b23..e45c5e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
-eval_bundle/
\ No newline at end of file
+eval_bundle/
+logs/
diff --git a/nanochat/common.py b/nanochat/common.py
index d4a9828..bd0c75c 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):
 
 def setup_default_logging():
     handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
     logging.basicConfig(
         level=logging.INFO,
         handlers=[handler]
     )
 
+def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
+    clean_name = os.path.basename(filename)
+    if clean_name != filename or not clean_name:
+        raise ValueError(f"Invalid log filename provided: {filename}")
+    if not clean_name.endswith(".log"):
+        clean_name += ".log"
+    logs_dir = get_logs_dir()
+    path = os.path.join(logs_dir, clean_name)
+
+    handler = logging.FileHandler(path, mode="w")
+    handler.setLevel(level)
+    handler.setFormatter(
+        formatter
+        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger = logging.getLogger(logger_name)
+    logger.addHandler(handler)
+    return path
+
 setup_default_logging()
 logger = logging.getLogger(__name__)
 
@@ -58,6 +78,38 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir
 
+def get_project_root():
+    # locates the project root by walking upward from this file
+    _PROJECT_MARKERS = ('.git', 'uv.lock')
+    curr = os.path.dirname(os.path.abspath(__file__))
+    while True:
+        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
+            return curr
+        parent = os.path.dirname(curr)
+        if parent == curr:  # reached filesystem root
+            return None
+        curr = parent
+
+def get_logs_dir():
+    """
+    Resolves the directory where log files should be written:
+    - if $LOG_DIR is set, use that.
+    - else, if a project root is detected, use <project root>/logs.
+    - else, fall back to <base dir>/logs.
+    """
+    env = os.environ.get("LOG_DIR")
+    if env:
+        path = os.path.abspath(env)
+        os.makedirs(path, exist_ok=True)
+        return path
+
+    root = get_project_root()
+    if not root:
+        root = get_base_dir()
+    logs = os.path.join(root, 'logs')
+    os.makedirs(logs, exist_ok=True)
+    return logs
+
 def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.
diff --git a/nanochat/dataset.py b/nanochat/dataset.py
index 602daed..2a6faf6 100644
--- a/nanochat/dataset.py
+++ b/nanochat/dataset.py
@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
 
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
 
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@
                         f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
 
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
@@ -100,10 +116,10 @@
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.error(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
@@ -117,12 +133,22 @@
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
 
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")
diff --git a/pyproject.toml b/pyproject.toml
index 3d03c4b..bc6f4a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0",
 ]
 
 [build-system]
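
Usage note: a minimal sketch (not part of the diff) of how another script could opt into the new file logger from nanochat/common.py, mirroring the dataset.py setup above. The "my_tool" logger name is hypothetical.

    import logging

    from nanochat.common import setup_file_logger

    logger = logging.getLogger("my_tool")
    logger.setLevel(logging.DEBUG)  # let DEBUG records reach the file handler

    # Writes to $LOG_DIR/my_tool.log if LOG_DIR is set, otherwise <project root>/logs/
    # (or <base dir>/logs/ as a last resort); a missing ".log" suffix is appended.
    log_path = setup_file_logger(
        logger_name="my_tool",
        filename="my_tool",  # must be a bare filename; paths raise ValueError
        level=logging.DEBUG,
    )
    logger.debug("written to %s only", log_path)  # console handler is INFO-level
    logger.info("also shown on the console via the root handler")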
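
For completeness, a small sketch of the resolution order implemented by get_logs_dir(); the /tmp path is an arbitrary example.

    import os

    from nanochat.common import get_logs_dir

    os.environ["LOG_DIR"] = "/tmp/nanochat-logs"  # explicit override wins
    print(get_logs_dir())  # -> /tmp/nanochat-logs (created if missing)

    del os.environ["LOG_DIR"]
    # Without LOG_DIR: <project root>/logs when a .git or uv.lock marker is found
    # walking up from nanochat/common.py, otherwise <base dir>/logs.
    print(get_logs_dir())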