Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 12:22:18 +00:00)

Merge eecfdbf9f9 into f66a780f68
This commit is contained in: commit d8b86be6de
.gitignore (vendored)
@@ -5,3 +5,4 @@ rustbpe/target/
 dev-ignore/
 report.md
 eval_bundle/
+logs/

nanochat/common.py

@@ -38,12 +38,32 @@ class ColoredFormatter(logging.Formatter):
 
 def setup_default_logging():
     handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
     logging.basicConfig(
         level=logging.INFO,
         handlers=[handler]
     )
 
+def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
+    clean_name = os.path.basename(filename)
+    if clean_name != filename or not clean_name:
+        raise ValueError(f"Invalid log filename provided: {filename}")
+    if not clean_name.endswith(".log"):
+        clean_name += ".log"
+    logs_dir = get_logs_dir()
+    path = os.path.join(logs_dir, clean_name)
+
+    handler = logging.FileHandler(path, mode="w")
+    handler.setLevel(level)
+    handler.setFormatter(
+        formatter
+        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger = logging.getLogger(logger_name)
+    logger.addHandler(handler)
+    return path
+
 setup_default_logging()
 logger = logging.getLogger(__name__)
 
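
As an aside, a minimal usage sketch of the setup_file_logger helper added above (the logger name and log filename below are illustrative, not taken from the diff):

    import logging
    from nanochat.common import setup_file_logger

    # Route this module's DEBUG records to a file under the resolved logs directory;
    # the helper returns the full path of the log file it attached.
    logger = logging.getLogger("example_module")   # illustrative logger name
    logger.setLevel(logging.DEBUG)
    log_path = setup_file_logger(
        logger_name="example_module",
        filename="example.log",                    # ".log" is appended if missing
        level=logging.DEBUG,
    )
    logger.debug("this record goes to %s", log_path)
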
@@ -58,6 +78,38 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir
 
+def get_project_root():
+    # locates the project root by walking upward from this file
+    _PROJECT_MARKERS = ('.git', 'uv.lock')
+    curr = os.path.dirname(os.path.abspath(__file__))
+    while True:
+        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
+            return curr
+        parent = os.path.dirname(curr)
+        if parent == curr: # reached filesystem root
+            return None
+        curr = parent
+
+def get_logs_dir():
+    """
+    Resolves the directory where log files should be written.
+    - if $LOG_DIR is set, use that.
+    - else, if a project root is detected, use <project_root>/logs.
+    - else, fall back to <get_base_dir()>/logs.
+    """
+    env = os.environ.get("LOG_DIR")
+    if env:
+        path = os.path.abspath(env)
+        os.makedirs(path, exist_ok=True)
+        return path
+
+    root = get_project_root()
+    if not root:
+        root = get_base_dir()
+    logs = os.path.join(root, 'logs')
+    os.makedirs(logs, exist_ok=True)
+    return logs
+
 def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.
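
The resolution order documented in get_logs_dir can be exercised directly; a small sketch, assuming the functions as added above (the LOG_DIR path is illustrative):

    import os
    from nanochat.common import get_logs_dir

    # 1) An explicit $LOG_DIR override wins and is created if missing.
    os.environ["LOG_DIR"] = "/tmp/nanochat-logs"   # illustrative path
    print(get_logs_dir())                          # -> /tmp/nanochat-logs

    # 2) Without LOG_DIR: <project_root>/logs if a .git or uv.lock marker is found
    #    while walking up from the package, otherwise <get_base_dir()>/logs.
    del os.environ["LOG_DIR"]
    print(get_logs_dir())
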

FineWeb-Edu dataset download script

@@ -10,11 +10,14 @@ For details of how the dataset was prepared, see `repackage_data_reference.py`.
 import os
 import argparse
 import time
+import logging
 import requests
+from tqdm import tqdm
 import pyarrow.parquet as pq
+from functools import partial
 from multiprocessing import Pool
 
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_file_logger
 
 # -----------------------------------------------------------------------------
 # The specifics of the current pretraining dataset
@@ -27,6 +30,19 @@ base_dir = get_base_dir()
 DATA_DIR = os.path.join(base_dir, "base_data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
+# -----------------------------------------------------------------------------
+# Minimal logger setup for DEBUG level
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+log_path = setup_file_logger(
+    logger_name=__name__,
+    filename="dataset_download.log",
+    level=logging.DEBUG,
+    formatter=logging.Formatter(
+        "%(asctime)s - %(processName)s - %(levelname)s - %(message)s"
+    ),
+)
+
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
 
@@ -64,12 +80,12 @@ def download_single_file(index):
     filename = index_to_filename(index)
     filepath = os.path.join(DATA_DIR, filename)
     if os.path.exists(filepath):
-        print(f"Skipping {filepath} (already exists)")
+        logger.debug(f"Skipping {filepath} (already exists)")
         return True
 
     # Construct the remote URL for this file
     url = f"{BASE_URL}/{filename}"
-    print(f"Downloading {filename}...")
+    logger.debug(f"Downloading {filename}...")
 
     # Download with retries
     max_attempts = 5
@@ -85,11 +101,11 @@ def download_single_file(index):
                     f.write(chunk)
             # Move temp file to final location
             os.rename(temp_path, filepath)
-            print(f"Successfully downloaded {filename}")
+            logger.debug(f"Successfully downloaded {filename}")
             return True
 
         except (requests.RequestException, IOError) as e:
-            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            logger.warning(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
             # Clean up any partial files
             for path in [filepath + f".tmp", filepath]:
                 if os.path.exists(path):
@@ -100,10 +116,10 @@ def download_single_file(index):
             # Try a few times with exponential backoff: 2^attempt seconds
             if attempt < max_attempts:
                 wait_time = 2 ** attempt
-                print(f"Waiting {wait_time} seconds before retry...")
+                logger.debug(f"Waiting {wait_time} seconds before retry...")
                 time.sleep(wait_time)
             else:
-                print(f"Failed to download {filename} after {max_attempts} attempts")
+                logger.debug(f"Failed to download {filename} after {max_attempts} attempts")
                 return False
 
     return False
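
For reference, the backoff schedule above works out to waits of 2, 4, 8, and 16 seconds after attempts 1 through 4, with the fifth failure giving up. A quick sketch of just that schedule:

    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        if attempt < max_attempts:
            print(f"attempt {attempt}: wait {2 ** attempt}s before retrying")  # 2, 4, 8, 16
        else:
            print(f"attempt {attempt}: give up")
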
@@ -113,16 +129,37 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
     parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
     parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+    parser.add_argument(
+        "-f",
+        "--work-share-factor",
+        type=int,
+        default=8,
+        help=(
+            """Controls how each worker's share of shards is subdivided. CHUNK_SIZE is computed as len(ids_to_download) // (num_workers * work_share_factor), so CHUNK_SIZE is the number of tasks a worker pulls per request from the main process. For example, with 240 shards and 4 workers the default value (8) produces 7 shards per request. Setting it to 1 gives a worker its entire share (~60 shards) in one go, with minimal coordination but slow progress updates. Larger work-share-factor values make the main process hand out smaller batches more often, for faster feedback at a small scheduling cost."""
+        ),
+    )
     args = parser.parse_args()
 
     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
     ids_to_download = list(range(num))
-    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    logger.info(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
-    print(f"Target directory: {DATA_DIR}")
+    logger.info(f"Dataset target directory: {DATA_DIR}")
-    print()
+    logger.info(f"Dataset downloader debug logs will be written to: {log_path}")
 
+    # pool.imap_unordered pulls `chunksize` tasks from the main process before asking for more
+    work_share_factor = max(1, args.work_share_factor)
+    CHUNK_SIZE = max(1, len(ids_to_download) // (args.num_workers * work_share_factor))
+    ok_count = 0
     with Pool(processes=args.num_workers) as pool:
-        results = pool.map(download_single_file, ids_to_download)
+        for ok in tqdm(
+            pool.imap_unordered(
+                partial(download_single_file), ids_to_download, chunksize=CHUNK_SIZE
+            ),
+            total=len(ids_to_download),
+            desc="all shards",
+            smoothing=0.1,
+        ):
+            ok_count += int(ok)
 
     # Report results
-    successful = sum(1 for success in results if success)
+    logger.info(f"Done! Downloaded: {ok_count}/{len(ids_to_download)} shards to {DATA_DIR}")
-    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
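
To make the --work-share-factor arithmetic concrete, a small sketch using the example numbers from the help text (240 shards and 4 workers; the values are illustrative):

    # CHUNK_SIZE is how many shard indices a worker pulls from the main process per request.
    num_shards = 240
    num_workers = 4
    print(max(1, num_shards // (num_workers * 8)))   # default factor 8 -> 7 shards per request
    print(max(1, num_shards // (num_workers * 1)))   # factor 1 -> 60, a worker's whole share at once

Larger factors mean smaller, more frequent batches handed to pool.imap_unordered, so the tqdm bar updates faster at a small scheduling cost.
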

pyproject.toml

@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.8.0",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
+    "tqdm>=4.66.0"
 ]
 
 [build-system]