Merge branch 'master' into master

This commit is contained in:
Andrej 2025-11-04 16:35:02 -08:00 committed by GitHub
commit 3a2ae631c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 45 additions and 36 deletions

View File

@@ -37,7 +37,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt").read().strip() api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions" url = "https://openrouter.ai/api/v1/chat/completions"
headers = { headers = {
@@ -45,7 +45,7 @@ headers = {
"Content-Type": "application/json" "Content-Type": "application/json"
} }
readme = open("README.md").read().strip() readme = open("README.md", "r", encoding="utf-8").read().strip()
prompt = r""" prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:

View File

@@ -34,7 +34,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data)
log0(f"Saved optimizer file to: {optimizer_path}") log0(f"Saved optimizer file to: {optimizer_path}")
# Save the metadata dict as json # Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w") as f: with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
log0(f"Saved metadata file to: {meta_path}") log0(f"Saved metadata file to: {meta_path}")
@@ -50,7 +50,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
optimizer_data = torch.load(optimizer_path, map_location=device) optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata # Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r") as f: with open(meta_path, "r", encoding="utf-8") as f:
meta_data = json.load(f) meta_data = json.load(f)
return model_data, optimizer_data, meta_data return model_data, optimizer_data, meta_data
@@ -65,7 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
""" """
assert phase in ["train", "eval"], f"Invalid phase: {phase}" assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type == "cpu": if device.type in {"cpu", "mps"}:
# Convert bfloat16 tensors to float for CPU inference # Convert bfloat16 tensors to float for CPU inference
model_data = { model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v k: v.float() if v.dtype == torch.bfloat16 else v

View File

@@ -71,8 +71,10 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
return file_path return file_path
with FileLock(lock_path): with FileLock(lock_path):
# Only a single rank can acquire this lock
# All other ranks block until it is released
# Recheck after acquiring lock (another process may have downloaded it) # Recheck after acquiring lock
if os.path.exists(file_path): if os.path.exists(file_path):
return file_path return file_path

View File

@@ -170,7 +170,7 @@ Generated: {timestamp}
# count dependencies via uv.lock # count dependencies via uv.lock
uv_lock_lines = 0 uv_lock_lines = 0
if os.path.exists('uv.lock'): if os.path.exists('uv.lock'):
with open('uv.lock', 'r') as f: with open('uv.lock', 'r', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines()) uv_lock_lines = len(f.readlines())
header += f""" header += f"""
@@ -241,7 +241,7 @@ class Report:
slug = slugify(section) slug = slugify(section)
file_name = f"{slug}.md" file_name = f"{slug}.md"
file_path = os.path.join(self.report_dir, file_name) file_path = os.path.join(self.report_dir, file_name)
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(f"## {section}\n") f.write(f"## {section}\n")
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
for item in data: for item in data:
@@ -272,11 +272,11 @@ class Report:
final_metrics = {} # the most important final metrics we'll add as table at the end final_metrics = {} # the most important final metrics we'll add as table at the end
start_time = None start_time = None
end_time = None end_time = None
with open(report_file, "w") as out_file: with open(report_file, "w", encoding="utf-8") as out_file:
# write the header first # write the header first
header_file = os.path.join(report_dir, "header.md") header_file = os.path.join(report_dir, "header.md")
if os.path.exists(header_file): if os.path.exists(header_file):
with open(header_file, "r") as f: with open(header_file, "r", encoding="utf-8") as f:
header_content = f.read() header_content = f.read()
out_file.write(header_content) out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:") start_time = extract_timestamp(header_content, "Run started:")
@@ -293,7 +293,7 @@ class Report:
if not os.path.exists(section_file): if not os.path.exists(section_file):
print(f"Warning: {section_file} does not exist, skipping") print(f"Warning: {section_file} does not exist, skipping")
continue continue
with open(section_file, "r") as in_file: with open(section_file, "r", encoding="utf-8") as in_file:
section = in_file.read() section = in_file.read()
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time) # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
if "rl" not in file_name: if "rl" not in file_name:
@@ -373,7 +373,7 @@ class Report:
header_file = os.path.join(self.report_dir, "header.md") header_file = os.path.join(self.report_dir, "header.md")
header = generate_header() header = generate_header()
start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(header_file, "w") as f: with open(header_file, "w", encoding="utf-8") as f:
f.write(header) f.write(header)
f.write(f"Run started: {start_time}\n\n---\n\n") f.write(f"Run started: {start_time}\n\n---\n\n")
print(f"Reset report and wrote header to {header_file}") print(f"Reset report and wrote header to {header_file}")

View File

@@ -70,18 +70,22 @@ python -m scripts.tok_eval
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard. # start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script. # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss # Number of processes/GPUs to use
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain # midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script. # NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft # sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# generate final report # generate final report
python -m nanochat.report generate python -m nanochat.report generate

View File

@@ -59,7 +59,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
config_path = os.path.join(eval_bundle_dir, "core.yaml") config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data") data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, 'r') as f: with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
tasks = config['icl_tasks'] tasks = config['icl_tasks']
@@ -193,7 +193,7 @@ def main():
print0("="*80) print0("="*80)
print0(f"Model: {model_name}") print0(f"Model: {model_name}")
print0("="*80) print0("="*80)
with open(output_csv_path, 'r') as f: with open(output_csv_path, 'r', encoding='utf-8') as f:
print0(f.read()) print0(f.read())
# Log to report # Log to report

View File

@@ -243,7 +243,7 @@ app.add_middleware(
async def root(): async def root():
"""Serve the chat UI.""" """Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html") ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r") as f: with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read() html_content = f.read()
# Replace the API_URL to use the same origin # Replace the API_URL to use the same origin
html_content = html_content.replace( html_content = html_content.replace(

View File

@@ -82,12 +82,15 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..." echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
NPROC_PER_NODE=8
# pretrain the d20 model # pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples # evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks # evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice) # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
@@ -97,15 +100,15 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model # run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row) # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump) # train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively # chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?" # python -m scripts.chat_cli -p "Why is the sky blue?"
@@ -118,9 +121,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# (optional) # (optional)
# run reinforcement learning # run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K # eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections # Generate the full report by putting together all the sections

View File

@@ -32,7 +32,7 @@ class CustomJSON(Task):
print("-" * 80) print("-" * 80)
else: else:
with open(filepath, 'r') as f: with open(filepath, 'r', encoding='utf-8') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: # skip empty lines if not line: # skip empty lines

View File

@@ -119,7 +119,7 @@ class SpellingBee(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f: with open(word_list_path, 'r', encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
self.words = words self.words = words
@@ -238,7 +238,7 @@ class SimpleSpelling(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f: with open(word_list_path, 'r', encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
rng = random.Random(42) rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task rng.shuffle(words) # use a different word order than the SpellingBee task

View File

@@ -455,13 +455,13 @@ def enwik8_path():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_small(enwik8_path): def enwik8_small(enwik8_path):
"""Fixture providing 100KB of enwik8 for quick tests.""" """Fixture providing 100KB of enwik8 for quick tests."""
with open(enwik8_path, "r") as f: with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(100_000) return f.read(100_000)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_large(enwik8_path): def enwik8_large(enwik8_path):
"""Fixture providing 10MB of enwik8 for performance tests.""" """Fixture providing 10MB of enwik8 for performance tests."""
with open(enwik8_path, "r") as f: with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(10**7) return f.read(10**7)
def time_function(func, *args, **kwargs): def time_function(func, *args, **kwargs):