mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00

commit 52e85aaf80
Merge branch 'master' into fix/typo
@@ -113,7 +113,7 @@ files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --

 This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.

-Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
+Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.

 ## Tests
@@ -22,13 +22,6 @@ fi
 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
 source "$HOME/.cargo/env"
 uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
-EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
-if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
-  curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
-  unzip -q eval_bundle.zip
-  rm eval_bundle.zip
-  mv eval_bundle $NANOCHAT_BASE_DIR
-fi

 # wipe the report
 python -m nanochat.report reset
@@ -58,7 +58,7 @@ def get_base_dir():
     os.makedirs(nanochat_dir, exist_ok=True)
     return nanochat_dir

-def download_file_with_lock(url, filename):
+def download_file_with_lock(url, filename, postprocess_fn=None):
     """
     Downloads a file from a URL to a local path in the base directory.
     Uses a lock file to prevent concurrent downloads among multiple ranks.
@@ -76,18 +76,24 @@ def download_file_with_lock(url, filename):
         # All other ranks block until it is released
         fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

+        # Recheck after acquiring lock (another process may have downloaded it)
         if os.path.exists(file_path):
             return file_path

+        # Download the content as bytes
         print(f"Downloading {url}...")
         with urllib.request.urlopen(url) as response:
-            content = response.read().decode('utf-8')
+            content = response.read() # bytes

-        with open(file_path, 'w') as f:
+        # Write to local file
+        with open(file_path, 'wb') as f:
             f.write(content)

         print(f"Downloaded to {file_path}")

+        # Run the postprocess function if provided
+        if postprocess_fn is not None:
+            postprocess_fn(file_path)

     # Clean up the lock file after the lock is released
     try:
         os.remove(lock_path)
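Note: the new postprocess_fn hook runs before the lock file is cleaned up, and only on the rank that actually performed the download; ranks that find the file already present return early and skip it. A minimal usage sketch (the unpack helper, target directory, and URL are hypothetical, for illustration only):

import zipfile

def unpack(zip_path):
    # hypothetical example: extract the downloaded archive into a local directory
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall("unpacked")

# path = download_file_with_lock("https://example.com/bundle.zip", "bundle.zip", postprocess_fn=unpack)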
@@ -37,7 +37,7 @@ def eval_with_timeout(formula, max_time=3):
         with timeout(max_time, formula):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore", SyntaxWarning)
-                return eval(formula)
+                return eval(formula, {"__builtins__": {}}, {})
     except Exception as e:
         signal.alarm(0)
         # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
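Note: passing an empty __builtins__ mapping as the globals (plus an empty locals dict) strips the default builtins from the evaluated formula, so plain arithmetic still works while names like abs, open, or __import__ no longer resolve. This blocks casual misuse of the calculator tool rather than providing a watertight sandbox. A quick standalone illustration (not from the repo):

print(eval("1 + 2 * 3", {"__builtins__": {}}, {}))  # 7
eval("abs(-1)", {"__builtins__": {}}, {})           # raises NameError: name 'abs' is not defined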
@@ -109,7 +109,7 @@ class KVCache:
         for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
             if ix in [0, 1, 3, 5]:
                 # num_layers, batch_size, num_heads, head_dim must match
-                assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
+                assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
             elif ix == 2:
                 # batch_size can be expanded
                 assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
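Note: judging from the indices asserted above, kv_shape presumably follows a (num_layers, 2, batch_size, num_heads, seq_len, head_dim) layout, with index 1 the K/V pair axis and index 2 the expandable batch axis; the new message reports which fixed dim actually mismatched instead of always saying "Batch". Under that layout assumption, the batch broadcast looks like:

import torch

kv = torch.zeros(12, 2, 1, 8, 128, 64)          # hypothetical cache built with batch_size 1
kv_expanded = kv.expand(-1, -1, 4, -1, -1, -1)  # broadcast view to batch_size 4, no copy
print(kv_expanded.shape)                        # torch.Size([12, 2, 4, 8, 128, 64])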
@@ -8,7 +8,6 @@ dependencies = [
     "datasets>=4.0.0",
     "fastapi>=0.117.1",
     "files-to-prompt>=0.6",
-    "numpy==1.26.4",
     "psutil>=7.1.0",
     "regex>=2025.9.1",
     "setuptools>=80.9.0",
@@ -19,13 +19,6 @@ python -m nanochat.report reset
 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
 source "$HOME/.cargo/env"
 uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
-EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
-if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
-  curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
-  unzip -q eval_bundle.zip
-  rm eval_bundle.zip
-  mv eval_bundle $NANOCHAT_BASE_DIR
-fi
 curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

 # train tokenizer on ~4B characters and kick off download of the rest for pretraining
@@ -2,48 +2,75 @@
 Evaluate the CORE metric for a given model.

 Run on a single GPU:
-python base_eval.py
+python -m scripts.base_eval

 Run with torchrun on e.g. 8 GPUs:
-torchrun --nproc_per_node=8 base_eval.py
+torchrun --nproc_per_node=8 -m scripts.base_eval

 The script will print the CORE metric to the console.
 """
 import os
-import sys
+import csv
 import time
 import json
-import random
 import yaml
+import shutil
+import random
+import zipfile
+import tempfile
 from contextlib import nullcontext

-import pandas as pd
 import torch

-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task

 # -----------------------------------------------------------------------------
-# nanoChat specific function dealing with I/O etc.
+# nanochat specific function dealing with I/O etc.

+# ~162MB of data needed to evaluate the CORE metric
+EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
+
+def place_eval_bundle(file_path):
+    # here file_path is the path to the eval_bundle.zip file
+    # we need to unzip it and place it in the base directory
+    base_dir = get_base_dir()
+    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with zipfile.ZipFile(file_path, 'r') as zip_ref:
+            zip_ref.extractall(tmpdir)
+        extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
+        shutil.move(extracted_bundle_dir, eval_bundle_dir)
+    print0(f"Placed eval_bundle directory at {eval_bundle_dir}")

 def evaluate_model(model, tokenizer, device, max_per_task=-1):
     """
     Evaluate a base model on the CORE benchmark.
     - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
-    TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
     """
     # Load config and task metadata
     base_dir = get_base_dir()
     eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
+    # Download the eval bundle to disk (and unzip if needed)
+    if not os.path.exists(eval_bundle_dir):
+        download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
     config_path = os.path.join(eval_bundle_dir, "core.yaml")
     data_base_path = os.path.join(eval_bundle_dir, "eval_data")
     eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
     with open(config_path, 'r') as f:
         config = yaml.safe_load(f)
     tasks = config['icl_tasks']
-    eval_metadata = pd.read_csv(eval_meta_data)
+    # Load random baseline values from eval metadata
+    random_baselines = {}
+    with open(eval_meta_data, 'r', encoding='utf-8') as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            task_name = row['Eval Task']
+            random_baseline = row['Random baseline']
+            random_baselines[task_name] = float(random_baseline)

     # Evaluate each task
     results = {}
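Note: replacing pd.read_csv with the standard-library csv module is what lets pandas (and the pinned numpy) drop out of the dependency list; csv.DictReader yields one dict per row, keyed by the header row. A standalone sketch (header names as in the diff, rows made up):

import csv, io

fake_csv = "Eval Task,Random baseline\nhellaswag,25.0\nsquad,0.0\n"
baselines = {row["Eval Task"]: float(row["Random baseline"])
             for row in csv.DictReader(io.StringIO(fake_csv))}
print(baselines)  # {'hellaswag': 25.0, 'squad': 0.0}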
@ -75,8 +102,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||||
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
||||||
|
|
||||||
results[label] = accuracy
|
results[label] = accuracy
|
||||||
row = eval_metadata[eval_metadata["Eval Task"] == label]
|
random_baseline = random_baselines[label]
|
||||||
random_baseline = row["Random baseline"].values[0]
|
|
||||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||||
centered_results[label] = centered_result
|
centered_results[label] = centered_result
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
|
||||||
|
|
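Note: the centering maps random-chance accuracy to 0 and a perfect score to 1; the 0.01 factors are there because the CSV stores the baseline as a percentage. Worked example with illustrative numbers:

accuracy = 0.40
random_baseline = 25.0  # percent, e.g. a 4-way multiple-choice task
centered = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
print(round(centered, 3))  # 0.2, i.e. (0.40 - 0.25) / 0.75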
@@ -294,7 +294,7 @@ for step in range(num_iterations + 1):
         smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
         debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
         pct_done = 100 * step / num_iterations
-        tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+        tok_per_sec = int(total_batch_size / dt)
         flops_per_sec = num_flops_per_token * total_batch_size / dt
         promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
         mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
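Note: the debias line is the standard Adam-style bias correction: the EMA starts at zero, so early values are shrunk toward 0, and dividing by 1 - ema_beta**(step + 1) undoes exactly that. A self-contained check with a hypothetical constant loss:

ema_beta = 0.9
smooth = 0.0
for step in range(3):
    loss = 2.0  # pretend the true loss is flat at 2.0
    smooth = ema_beta * smooth + (1 - ema_beta) * loss
    debiased = smooth / (1 - ema_beta ** (step + 1))
    print(step, round(smooth, 3), round(debiased, 3))
# raw EMA reads 0.2, 0.38, 0.542 while the debiased value is 2.0 at every step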
@@ -268,7 +268,7 @@ while True:
         smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
         debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
         pct_done = 100 * progress
-        tok_per_sec = int(world_tokens_per_fwdbwd / dt)
+        tok_per_sec = int(total_batch_size / dt)
         flops_per_sec = num_flops_per_token * total_batch_size / dt
         promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
         mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
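Note: the same tok_per_sec fix lands in both training loops. With gradient accumulation, one optimizer step (one dt) consumes total_batch_size tokens spread over several forward/backward passes, while world_tokens_per_fwdbwd counts only a single pass, so the old number understated throughput by the accumulation factor; the change also makes tok_per_sec consistent with the flops_per_sec line just below it, which already used total_batch_size. An illustration with hypothetical numbers:

world_tokens_per_fwdbwd = 8 * 32 * 2048                        # ranks * device batch * seq len
grad_accum_steps = 4                                           # hypothetical
total_batch_size = world_tokens_per_fwdbwd * grad_accum_steps  # tokens per optimizer step
dt = 2.0                                                       # measured seconds per optimizer step
print(int(world_tokens_per_fwdbwd / dt))  # 262144  -- old value, 4x too low here
print(int(total_batch_size / dt))         # 1048576 -- actual tokens/sec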
@@ -73,15 +73,6 @@ python -m scripts.tok_eval
 # -----------------------------------------------------------------------------
 # Base model (pretraining)

-# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
-EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
-if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
-  curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
-  unzip -q eval_bundle.zip
-  rm eval_bundle.zip
-  mv eval_bundle $NANOCHAT_BASE_DIR
-fi
-
 # The d20 model is 561M parameters.
 # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
 # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
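Note: the arithmetic in the surviving comments checks out:

params = 561e6
tokens = 20 * params   # Chinchilla-optimal: ~11.2B tokens
chars = tokens * 4.8   # at ~4.8 chars per token: ~54B characters
print(f"{tokens:.3g} tokens, {chars:.3g} chars")  # 1.12e+10 tokens, 5.39e+10 chars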