diff --git a/dev/runcpu.sh b/dev/runcpu.sh index 469e51d..ffacefa 100755 --- a/dev/runcpu.sh +++ b/dev/runcpu.sh @@ -22,13 +22,6 @@ fi curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y source "$HOME/.cargo/env" uv run maturin develop --release --manifest-path rustbpe/Cargo.toml -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi # wipe the report python -m nanochat.report reset diff --git a/nanochat/common.py b/nanochat/common.py index a5a6d2e..8272378 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -58,7 +58,7 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir -def download_file_with_lock(url, filename): +def download_file_with_lock(url, filename, postprocess_fn=None): """ Downloads a file from a URL to a local path in the base directory. Uses a lock file to prevent concurrent downloads among multiple ranks. @@ -76,18 +76,24 @@ def download_file_with_lock(url, filename): # All other ranks block until it is released fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + # Recheck after acquiring lock (another process may have downloaded it) if os.path.exists(file_path): return file_path + # Download the content as bytes print(f"Downloading {url}...") with urllib.request.urlopen(url) as response: - content = response.read().decode('utf-8') + content = response.read() # bytes - with open(file_path, 'w') as f: + # Write to local file + with open(file_path, 'wb') as f: f.write(content) - print(f"Downloaded to {file_path}") + # Run the postprocess function if provided + if postprocess_fn is not None: + postprocess_fn(file_path) + # Clean up the lock file after the lock is released try: os.remove(lock_path) diff --git a/run1000.sh b/run1000.sh index 6f454e0..e0bc4c4 100644 --- a/run1000.sh +++ b/run1000.sh @@ -19,13 +19,6 @@ python -m nanochat.report reset curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y source "$HOME/.cargo/env" uv run maturin develop --release --manifest-path rustbpe/Cargo.toml -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl # train tokenizer on ~4B characters and kick off download of the rest for pretraining diff --git a/scripts/base_eval.py b/scripts/base_eval.py index c488c8a..21f7bac 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -2,10 +2,10 @@ Evaluate the CORE metric for a given model. Run on a single GPU: -python base_eval.py +python -m scripts.base_eval Run with torchrun on e.g. 8 GPUs: -torchrun --nproc_per_node=8 base_eval.py +torchrun --nproc_per_node=8 -m scripts.base_eval The script will print the CORE metric to the console. """ @@ -13,13 +13,16 @@ import os import csv import time import json -import random import yaml +import shutil +import random +import zipfile +import tempfile from contextlib import nullcontext import torch -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task @@ -27,6 +30,21 @@ from nanochat.core_eval import evaluate_task # ----------------------------------------------------------------------------- # nanochat specific function dealing with I/O etc. +# ~162MB of data needed to evaluate the CORE metric +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + +def place_eval_bundle(file_path): + # here file_path is the path to the eval_bundle.zip file + # we need to unzip it and place it in the base directory + base_dir = get_base_dir() + eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(tmpdir) + extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle") + shutil.move(extracted_bundle_dir, eval_bundle_dir) + print0(f"Placed eval_bundle directory at {eval_bundle_dir}") + def evaluate_model(model, tokenizer, device, max_per_task=-1): """ Evaluate a base model on the CORE benchmark. @@ -35,6 +53,9 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): # Load config and task metadata base_dir = get_base_dir() eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + # Download the eval bundle to disk (and unzip if needed) + if not os.path.exists(eval_bundle_dir): + download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle) config_path = os.path.join(eval_bundle_dir, "core.yaml") data_base_path = os.path.join(eval_bundle_dir, "eval_data") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") diff --git a/speedrun.sh b/speedrun.sh index 35dd39e..32c8870 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -73,15 +73,6 @@ python -m scripts.tok_eval # ----------------------------------------------------------------------------- # Base model (pretraining) -# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi - # The d20 model is 561M parameters. # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.