h3nock 2025-11-14 12:55:24 -05:00 committed by GitHub
commit d8b86be6de
4 changed files with 105 additions and 14 deletions

.gitignore vendored

@@ -4,4 +4,5 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
 eval_bundle/
+logs/

nanochat/common.py

@@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):
 def setup_default_logging():
     handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
     logging.basicConfig(
         level=logging.INFO,
         handlers=[handler]
     )
 
+def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
+    clean_name = os.path.basename(filename)
+    if clean_name != filename or not clean_name:
+        raise ValueError(f"Invalid log filename provided: {filename}")
+    if not clean_name.endswith(".log"):
+        clean_name += ".log"
+    logs_dir = get_logs_dir()
+    path = os.path.join(logs_dir, clean_name)
+    handler = logging.FileHandler(path, mode="w")
+    handler.setLevel(level)
+    handler.setFormatter(
+        formatter
+        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger = logging.getLogger(logger_name)
+    logger.addHandler(handler)
+    return path
+
 setup_default_logging()
 logger = logging.getLogger(__name__)
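
For reference, a minimal usage sketch of the new helper (not part of the diff); the logger name and filename are illustrative, and the import assumes the package layout used elsewhere in this commit:

```python
import logging
from nanochat.common import setup_file_logger  # helper added in this commit

# Attach a DEBUG-level file handler to a named logger; ".log" is appended
# automatically, so this writes to <logs dir>/my_module.log and returns that path.
log_path = setup_file_logger("my_module", "my_module")

logger = logging.getLogger("my_module")
logger.setLevel(logging.DEBUG)  # otherwise DEBUG records are dropped before the handler runs
logger.debug("written to %s", log_path)
```
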
@@ -58,6 +78,38 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir
 
+def get_project_root():
+    # Locates the project root by walking upward from this file
+    _PROJECT_MARKERS = ('.git', 'uv.lock')
+    curr = os.path.dirname(os.path.abspath(__file__))
+    while True:
+        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
+            return curr
+        parent = os.path.dirname(curr)
+        if parent == curr:  # reached filesystem root
+            return None
+        curr = parent
+
+def get_logs_dir():
+    """
+    Resolves the directory where log files should be written.
+    - If $LOG_DIR is set, use that.
+    - Else, if a project root is detected, use <project_root>/logs.
+    - Else, fall back to <get_base_dir()>/logs.
+    """
+    env = os.environ.get("LOG_DIR")
+    if env:
+        path = os.path.abspath(env)
+        os.makedirs(path, exist_ok=True)
+        return path
+    root = get_project_root()
+    if not root:
+        root = get_base_dir()
+    logs = os.path.join(root, 'logs')
+    os.makedirs(logs, exist_ok=True)
+    return logs
+
 def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.
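
A small sketch of the resolution order get_logs_dir() implements above (the LOG_DIR value is illustrative):

```python
import os
from nanochat.common import get_logs_dir, get_project_root

# 1) An explicit LOG_DIR environment variable wins.
os.environ["LOG_DIR"] = "/tmp/nanochat-logs"  # illustrative path
print(get_logs_dir())                         # -> /tmp/nanochat-logs (created if missing)

# 2) Otherwise, a detected project root (marked by .git or uv.lock) yields
#    <project_root>/logs; failing that, <get_base_dir()>/logs is used.
del os.environ["LOG_DIR"]
print(get_project_root())                     # -> repo root, or None outside a checkout
print(get_logs_dir())                         # -> <project_root or base dir>/logs
```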

nanochat/dataset.py

@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
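
A clarifying note on how the handlers interact (not part of the diff): setup_default_logging() in common.py installs an INFO-level stream handler on the root logger, while the handler attached here writes DEBUG records to the file. The intended split, sketched:

```python
# DEBUG passes the module logger (setLevel(DEBUG) above) and is written by the
# DEBUG-level file handler, but the root stream handler is set to INFO, so it
# never reaches the console.
logger.debug("file only: dataset_download.log")
# INFO and above are emitted by both the file handler and the console handler.
logger.info("console + file")
```
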
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@ def download_single_file(index):
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
@@ -100,10 +116,10 @@ def download_single_file(index):
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
@@ -113,16 +129,37 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
     parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+    parser.add_argument(
+        "-f",
+        "--work-share-factor",
+        type=int,
+        default=8,
+        help=(
+            """Controls how each worker's share of shards is subdivided. CHUNK_SIZE is computed as len(ids_to_download) // (num_workers * work_share_factor), so CHUNK_SIZE is the number of tasks a worker pulls per request from the main process. For example, with 240 shards and 4 workers the default value (8) produces 7 shards per request. Setting it to 1 gives a worker its entire share (~60 shards) in one go, with minimal coordination but slow progress updates. Larger work-share-factor values make the main process hand out smaller batches more often, for faster feedback at a small scheduling cost."""
+        ),
+    )
     args = parser.parse_args()
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+    # pool.imap_unordered pulls `chunksize` tasks from the main process before asking for more
+    work_share_factor = max(1, args.work_share_factor)
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * work_share_factor))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")

pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]
 
 [build-system]