commit bda3b1a986
h3nock, 2025-11-08 03:14:13 -05:00, committed by GitHub
4 changed files with 94 additions and 14 deletions

.gitignore (vendored), 3 changed lines

@@ -4,4 +4,5 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
-eval_bundle/
+eval_bundle/
+logs/

nanochat/common.py

@@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):
 
 def setup_default_logging():
     handler = logging.StreamHandler()
     handler.setLevel(logging.INFO)
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
     logging.basicConfig(
         level=logging.INFO,
         handlers=[handler]
     )
 
+def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
+    clean_name = os.path.basename(filename)
+    if clean_name != filename or not clean_name:
+        raise ValueError(f"Invalid log filename provided: {filename}")
+    if not clean_name.endswith(".log"):
+        clean_name += ".log"
+    logs_dir = get_logs_dir()
+    path = os.path.join(logs_dir, clean_name)
+    handler = logging.FileHandler(path, mode="w")
+    handler.setLevel(level)
+    handler.setFormatter(
+        formatter
+        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger = logging.getLogger(logger_name)
+    logger.addHandler(handler)
+    return path
+
 setup_default_logging()
 logger = logging.getLogger(__name__)
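
Note: a minimal usage sketch of the new setup_file_logger helper (the logger name and filename below are illustrative, not from this commit):

import logging
from nanochat.common import setup_file_logger

log_path = setup_file_logger(
    logger_name="my_module",  # logger that receives the FileHandler
    filename="my_module",     # ".log" is appended automatically
    level=logging.DEBUG,
)
logging.getLogger("my_module").debug("this record is written to %s", log_path)

The helper returns the resolved log path so callers can tell users where the debug log landed, which the dataset downloader below relies on.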
@@ -58,6 +78,38 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir
 
+def get_project_root():
+    # Locate the project root by walking upward from this file
+    _PROJECT_MARKERS = ('.git', 'uv.lock')
+    curr = os.path.dirname(os.path.abspath(__file__))
+    while True:
+        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
+            return curr
+        parent = os.path.dirname(curr)
+        if parent == curr:  # reached filesystem root
+            return None
+        curr = parent
+
+def get_logs_dir():
+    """
+    Resolve the directory where log files should be written:
+    - if $LOG_DIR is set, use that;
+    - else, if a project root is detected, use <project_root>/logs;
+    - else, fall back to <get_base_dir()>/logs.
+    """
+    env = os.environ.get("LOG_DIR")
+    if env:
+        path = os.path.abspath(env)
+        os.makedirs(path, exist_ok=True)
+        return path
+    root = get_project_root()
+    if not root:
+        root = get_base_dir()
+    logs = os.path.join(root, 'logs')
+    os.makedirs(logs, exist_ok=True)
+    return logs
+
 def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.
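
Note: the resolution order of get_logs_dir() can be exercised directly; a small sketch (the override path is illustrative):

import os
from nanochat.common import get_logs_dir

os.environ["LOG_DIR"] = "/tmp/nanochat-logs"  # explicit override wins
print(get_logs_dir())  # -> /tmp/nanochat-logs (created if missing)

del os.environ["LOG_DIR"]
print(get_logs_dir())  # -> <project_root>/logs inside a checkout, else <base_dir>/logs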

nanochat/dataset.py

@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 # -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+# -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
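
Note on how the two log destinations interact (standard logging semantics, not specific to this commit): the module logger is lowered to DEBUG and given a FileHandler at DEBUG, while every record still propagates to the root logger, whose console handler from setup_default_logging() filters at INFO:

logger.debug("shard chatter")  # file only: the INFO console handler drops it
logger.info("progress")        # file and console

That split is why the per-shard messages below move from print() to logger.debug().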
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
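
Note: the write-to-temp-then-rename pattern above is what makes the os.path.exists(filepath) skip check safe: os.rename() is atomic on POSIX when source and destination sit on the same filesystem, so filepath is either absent or complete, never partial. The pattern in isolation (paths illustrative):

import os
tmp = "shard_00000.parquet.tmp"
dst = "shard_00000.parquet"
with open(tmp, "wb") as f:
    f.write(b"...")      # partial state is only ever visible at *.tmp
os.rename(tmp, dst)      # atomic publish of the finished file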
@@ -100,10 +116,10 @@
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
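
Note: with max_attempts = 5 and wait_time = 2 ** attempt, a shard that keeps failing waits 2, 4, 8, and 16 seconds between the five attempts, about 30 seconds of backoff in the worst case:

sum(2 ** a for a in range(1, 5))  # == 30 seconds of total waiting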
@@ -117,12 +133,22 @@ if __name__ == "__main__":
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")
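
Note on the pool change: pool.map() blocks until every result is in, so the old version showed no per-shard progress, while pool.imap_unordered() yields each result as it completes, which is what lets tqdm tick per shard. The chunksize heuristic hands each worker roughly eight batches over the run, trading IPC overhead against load balance; and since partial(download_single_file) binds no extra arguments here, it is equivalent to passing download_single_file directly. The pattern in isolation (the worker function is a stand-in):

from multiprocessing import Pool
from tqdm import tqdm

def work(i):     # stand-in for download_single_file
    return True  # pretend every shard succeeds

if __name__ == "__main__":
    ids = list(range(100))
    chunk = max(1, len(ids) // (4 * 8))  # ~8 chunks per worker
    with Pool(processes=4) as pool:
        results = pool.imap_unordered(work, ids, chunksize=chunk)
        ok = sum(int(r) for r in tqdm(results, total=len(ids), desc="all shards"))
    print(f"{ok}/{len(ids)} succeeded")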

pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]
 
 [build-system]