Compare commits

...

7 Commits

Author      SHA1        Message                                                           Date
h3nock      ca43cd0e05  Merge eecfdbf9f9 into 9a71d13688                                  2025-11-13 22:59:38 -05:00
henok3878   eecfdbf9f9  feat(dataset): make work share factor configurable with -f flag  2025-11-13 22:59:14 -05:00
henok3878   ed07192724  feat(dataset): make work share factor configurable with -f flag  2025-11-13 22:48:35 -05:00
henok3878   c2740d3a82  improve dataset downloader logging and add progress bar          2025-11-07 16:44:26 -05:00
                        (store debug-level worker output in logs/dataset_download.log
                        while keeping key messages on stdout)
henok3878   2801dc341b  add reusable file logger helper method                            2025-11-07 16:44:26 -05:00
henok3878   d4cc96d749  add get_logs_dir() to resolve log output path                     2025-11-07 16:44:26 -05:00
henok3878   8788ffb3db  add helper to locate project root dynamically                     2025-11-07 16:44:26 -05:00
4 changed files with 105 additions and 14 deletions

.gitignore (vendored)

@@ -4,4 +4,5 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
 eval_bundle/
+logs/

nanochat/common.py

@@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):

 def setup_default_logging():
     handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
     logging.basicConfig(
         level=logging.INFO,
         handlers=[handler]
     )

+def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
+    clean_name = os.path.basename(filename)
+    if clean_name != filename or not clean_name:
+        raise ValueError(f"Invalid log filename provided: {filename}")
+    if not clean_name.endswith(".log"):
+        clean_name += ".log"
+    logs_dir = get_logs_dir()
+    path = os.path.join(logs_dir, clean_name)
+    handler = logging.FileHandler(path, mode="w")
+    handler.setLevel(level)
+    handler.setFormatter(
+        formatter
+        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger = logging.getLogger(logger_name)
+    logger.addHandler(handler)
+    return path
+
 setup_default_logging()
 logger = logging.getLogger(__name__)
@@ -58,6 +78,38 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir

+def get_project_root():
+    # Locate the project root by walking upward from this file.
+    _PROJECT_MARKERS = ('.git', 'uv.lock')
+    curr = os.path.dirname(os.path.abspath(__file__))
+    while True:
+        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
+            return curr
+        parent = os.path.dirname(curr)
+        if parent == curr:  # reached the filesystem root
+            return None
+        curr = parent
+
+def get_logs_dir():
+    """
+    Resolve the directory where log files should be written:
+    - if $LOG_DIR is set, use that;
+    - else, if a project root is detected, use <project_root>/logs;
+    - else, fall back to <get_base_dir()>/logs.
+    """
+    env = os.environ.get("LOG_DIR")
+    if env:
+        path = os.path.abspath(env)
+        os.makedirs(path, exist_ok=True)
+        return path
+    root = get_project_root()
+    if not root:
+        root = get_base_dir()
+    logs = os.path.join(root, 'logs')
+    os.makedirs(logs, exist_ok=True)
+    return logs
+
 def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.

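Taken together, get_project_root(), get_logs_dir(), and setup_file_logger() give callers one predictable place to send file logs. A minimal sketch of how they compose, assuming this branch's nanochat.common is importable; the "demo" logger name and the /tmp path are made up for illustration:

import logging
import os

from nanochat.common import get_logs_dir, setup_file_logger

# $LOG_DIR takes precedence over the detected project root
os.environ["LOG_DIR"] = "/tmp/nanochat-demo-logs"  # hypothetical path
print(get_logs_dir())  # -> /tmp/nanochat-demo-logs

# without $LOG_DIR, resolution falls back to <project_root>/logs, then <base_dir>/logs
del os.environ["LOG_DIR"]
print(get_logs_dir())

# setup_file_logger() appends ".log" when missing and returns the resolved path
demo_logger = logging.getLogger("demo")
demo_logger.setLevel(logging.DEBUG)  # mirror what the dataset downloader does below
log_path = setup_file_logger(logger_name="demo", filename="demo")
demo_logger.debug("this record goes to %s", log_path)
print(log_path)  # e.g. <logs_dir>/demo.log

Note that the file handler carries its own level, so DEBUG records reach the file while the console handler installed by setup_default_logging() still filters at INFO.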
nanochat/dataset.py

@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger

 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)

+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True

     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")

     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@ def download_single_file(index):
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
@@ -100,10 +116,10 @@ def download_single_file(index):
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False

     return False
@@ -113,16 +129,37 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
     parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1 = all shards)")
     parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+    parser.add_argument(
+        "-f",
+        "--work-share-factor",
+        type=int,
+        default=8,
+        help=(
+            "Controls how each worker's share of shards is subdivided. CHUNK_SIZE is "
+            "computed as len(ids_to_download) // (num_workers * work_share_factor), "
+            "i.e. the number of tasks a worker pulls per request from the main process. "
+            "For example, with 240 shards and 4 workers the default (8) produces 7 "
+            "shards per request. Setting it to 1 gives a worker its entire share "
+            "(~60 shards) in one go, with minimal coordination but slow progress "
+            "updates; larger values make the main process hand out smaller batches "
+            "more often, for faster feedback at a small scheduling cost."
+        ),
+    )
     args = parser.parse_args()

     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+
+    # pool.imap_unordered pulls `chunksize` tasks from the main process before asking for more
+    work_share_factor = max(1, args.work_share_factor)
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * work_share_factor))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)

     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")

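The -f math is easy to sanity-check outside the downloader. A standalone sketch using the same numbers as the help text, where fake_download is a hypothetical stand-in for download_single_file: 240 shards with 4 workers and work_share_factor=8 gives chunksize 240 // 32 = 7, so imap_unordered hands each worker 7 tasks per request and tqdm ticks as each result arrives:

from multiprocessing import Pool

from tqdm import tqdm

def fake_download(index):
    # stand-in for download_single_file; pretend every shard succeeds
    return True

if __name__ == "__main__":
    ids = list(range(240))
    num_workers, work_share_factor = 4, 8
    chunk_size = max(1, len(ids) // (num_workers * work_share_factor))  # 240 // 32 = 7
    ok_count = 0
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(fake_download, ids, chunksize=chunk_size)
        for ok in tqdm(results, total=len(ids), desc="all shards", smoothing=0.1):
            ok_count += int(ok)
    print(f"{ok_count}/{len(ids)} ok (chunksize={chunk_size})")

Compared to the old pool.map, imap_unordered yields results as they complete rather than after all of them finish, which is what lets the progress bar move smoothly.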
pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]

 [build-system]