This commit is contained in:
h3nock 2025-11-08 03:14:13 -05:00 committed by GitHub
commit bda3b1a986
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 94 additions and 14 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ rustbpe/target/
dev-ignore/ dev-ignore/
report.md report.md
eval_bundle/ eval_bundle/
logs/

View File

@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):
def setup_default_logging():
    """
    Configure the root logger with a single colored console handler at INFO.
    """
    line_format = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console = logging.StreamHandler()
    # The threshold is set on the handler as well as on the root logger:
    # records propagated from child loggers are filtered by handler level only,
    # so this keeps DEBUG records from more verbose loggers off the console.
    console.setLevel(logging.INFO)
    console.setFormatter(line_format)
    logging.basicConfig(handlers=[console], level=logging.INFO)
def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
    """
    Attach a file handler to the named logger and return the log file path.

    The file lives in the directory returned by get_logs_dir() and is opened
    in "w" mode, so an existing file of the same name is truncated the first
    time this is called for a given logger/file pair. Calling it again for the
    same pair reuses the handler already attached (updating its level and
    formatter) instead of truncating the file and stacking a second handler
    that would write every record twice.

    Args:
        logger_name: name passed to logging.getLogger().
        filename: bare file name without directory components; ".log" is
            appended when it is missing.
        level: threshold set on the file handler. The logger's own level is
            not modified here.
        formatter: logging.Formatter used for the file. Defaults to a
            ColoredFormatter with the standard pattern.
            NOTE(review): if ColoredFormatter emits ANSI escapes they will end
            up in the log file - confirm that is intended for the default.

    Returns:
        Path of the log file.

    Raises:
        ValueError: if filename is empty or contains directory components.
    """
    clean_name = os.path.basename(filename)
    # Reject anything that is not a plain file name so callers cannot escape
    # the logs directory (e.g. "../x.log" or "/tmp/x.log").
    if not clean_name or clean_name != filename:
        raise ValueError(f"Invalid log filename provided: {filename}")
    if not clean_name.endswith(".log"):
        clean_name += ".log"

    path = os.path.join(get_logs_dir(), clean_name)
    formatter = formatter or ColoredFormatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    target_logger = logging.getLogger(logger_name)

    # FileHandler stores its target as an absolute path in baseFilename.
    resolved = os.path.abspath(path)
    for existing in target_logger.handlers:
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == resolved:
            existing.setLevel(level)
            existing.setFormatter(formatter)
            return path

    handler = logging.FileHandler(path, mode="w")
    handler.setLevel(level)
    handler.setFormatter(formatter)
    target_logger.addHandler(handler)
    return path
setup_default_logging() setup_default_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +78,38 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True) os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir return nanochat_dir
def get_project_root():
    """
    Find the project root by walking upward from this file's directory.

    Returns the first directory that contains a '.git' or 'uv.lock' entry,
    or None when the filesystem root is reached without finding one.
    """
    markers = ('.git', 'uv.lock')
    candidate = os.path.dirname(os.path.abspath(__file__))
    previous = None
    # dirname() of the filesystem root is the root itself, which ends the walk
    while candidate != previous:
        for marker in markers:
            if os.path.exists(os.path.join(candidate, marker)):
                return candidate
        previous, candidate = candidate, os.path.dirname(candidate)
    return None
def get_logs_dir():
    """
    Return the directory where log files should be written, creating it if needed.

    Resolution order:
    - if $LOG_DIR is set (and non-empty), use its absolute path;
    - else, if a project root is detected, use <project_root>/logs;
    - else, fall back to <get_base_dir()>/logs.
    """
    override = os.environ.get("LOG_DIR")
    if override:
        logs_dir = os.path.abspath(override)
    else:
        parent = get_project_root() or get_base_dir()
        logs_dir = os.path.join(parent, 'logs')
    os.makedirs(logs_dir, exist_ok=True)
    return logs_dir
def download_file_with_lock(url, filename, postprocess_fn=None): def download_file_with_lock(url, filename, postprocess_fn=None):
""" """
Downloads a file from a URL to a local path in the base directory. Downloads a file from a URL to a local path in the base directory.

View File

@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
import os import os
import argparse import argparse
import time import time
import logging
import requests import requests
from tqdm import tqdm
import pyarrow.parquet as pq import pyarrow.parquet as pq
from functools import partial
from multiprocessing import Pool from multiprocessing import Pool
from nanochat.common import get_base_dir from nanochat.common import get_base_dir, setup_file_logger
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset # The specifics of the current pretraining dataset
@ -27,6 +30,19 @@ base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data") DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True) os.makedirs(DATA_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# File logging for the downloader: the module logger accepts DEBUG and above,
# and a file handler writes those records to dataset_download.log.
# processName is included in the format so records can be attributed to the
# process that emitted them.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
log_path = setup_file_logger(
    __name__,
    "dataset_download.log",
    formatter=logging.Formatter("%(asctime)s - %(processName)s - %(levelname)s - %(message)s"),
    level=logging.DEBUG,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported # These functions are useful utilities to other modules, can/should be imported
@ -64,12 +80,12 @@ def download_single_file(index):
filename = index_to_filename(index) filename = index_to_filename(index)
filepath = os.path.join(DATA_DIR, filename) filepath = os.path.join(DATA_DIR, filename)
if os.path.exists(filepath): if os.path.exists(filepath):
print(f"Skipping {filepath} (already exists)") logger.debug(f"Skipping {filepath} (already exists)")
return True return True
# Construct the remote URL for this file # Construct the remote URL for this file
url = f"{BASE_URL}/{filename}" url = f"{BASE_URL}/{filename}"
print(f"Downloading {filename}...") logger.debug(f"Downloading {filename}...")
# Download with retries # Download with retries
max_attempts = 5 max_attempts = 5
@ -85,11 +101,11 @@ def download_single_file(index):
f.write(chunk) f.write(chunk)
# Move temp file to final location # Move temp file to final location
os.rename(temp_path, filepath) os.rename(temp_path, filepath)
print(f"Successfully downloaded {filename}") logger.debug(f"Successfully downloaded {filename}")
return True return True
except (requests.RequestException, IOError) as e: except (requests.RequestException, IOError) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files # Clean up any partial files
for path in [filepath + f".tmp", filepath]: for path in [filepath + f".tmp", filepath]:
if os.path.exists(path): if os.path.exists(path):
@ -100,10 +116,10 @@ def download_single_file(index):
# Try a few times with exponential backoff: 2^attempt seconds # Try a few times with exponential backoff: 2^attempt seconds
if attempt < max_attempts: if attempt < max_attempts:
wait_time = 2 ** attempt wait_time = 2 ** attempt
print(f"Waiting {wait_time} seconds before retry...") logger.debug(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time) time.sleep(wait_time)
else: else:
print(f"Failed to download {filename} after {max_attempts} attempts") logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
return False return False
return False return False
@ -117,12 +133,22 @@ if __name__ == "__main__":
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num)) ids_to_download = list(range(num))
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...") logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
print(f"Target directory: {DATA_DIR}") logger.info(f"Dataset target directory: {DATA_DIR}")
print() logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * 8))
ok_count = 0
with Pool(processes=args.num_workers) as pool: with Pool(processes=args.num_workers) as pool:
results = pool.map(download_single_file, ids_to_download) for ok in tqdm(
pool.imap_unordered(
partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
),
total=len(ids_to_download),
desc="all shards",
smoothing=0.1,
):
ok_count += int(ok)
# Report results # Report results
successful = sum(1 for success in results if success) logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")
print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")

View File

@ -16,6 +16,7 @@ dependencies = [
"torch>=2.8.0", "torch>=2.8.0",
"uvicorn>=0.36.0", "uvicorn>=0.36.0",
"wandb>=0.21.3", "wandb>=0.21.3",
"tqdm>=4.66.0"
] ]
[build-system] [build-system]