Merge branch 'master' into refactor-vertex-ai-pipelines

This commit is contained in:
javasoup 2025-12-01 20:07:43 -05:00 committed by GitHub
commit 2adcc95c4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 204 additions and 101 deletions

View File

@ -184,6 +184,7 @@ python -m pytest tests/test_rustbpe.py -v -s
│ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF │ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF
│ └── spellingbee.py # Task teaching model to spell/count letters │ └── spellingbee.py # Task teaching model to spell/count letters
├── tests ├── tests
│ └── test_engine.py
│ └── test_rustbpe.py │ └── test_rustbpe.py
└── uv.lock └── uv.lock
``` ```
@ -201,6 +202,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t
- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
- Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
- Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.
## Cite ## Cite

View File

@ -37,7 +37,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt").read().strip() api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions" url = "https://openrouter.ai/api/v1/chat/completions"
headers = { headers = {
@ -45,7 +45,7 @@ headers = {
"Content-Type": "application/json" "Content-Type": "application/json"
} }
readme = open("README.md").read().strip() readme = open("README.md", "r", encoding="utf-8").read().strip()
prompt = r""" prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:

View File

@ -119,7 +119,7 @@ def build_model(checkpoint_dir, step, device, phase):
""" """
assert phase in ["train", "eval"], f"Invalid phase: {phase}" assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type == "cpu": if device.type in {"cpu", "mps"}:
# Convert bfloat16 tensors to float for CPU inference # Convert bfloat16 tensors to float for CPU inference
model_data = { model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v k: v.float() if v.dtype == torch.bfloat16 else v

View File

@ -5,10 +5,10 @@ Common utilities for nanochat.
import os import os
import re import re
import logging import logging
import fcntl
import urllib.request import urllib.request
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from filelock import FileLock
class ColoredFormatter(logging.Formatter): class ColoredFormatter(logging.Formatter):
"""Custom formatter that adds colors to log messages.""" """Custom formatter that adds colors to log messages."""
@ -71,13 +71,11 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
if os.path.exists(file_path): if os.path.exists(file_path):
return file_path return file_path
with open(lock_path, 'w') as lock_file: with FileLock(lock_path):
# Only a single rank can acquire this lock # Only a single rank can acquire this lock
# All other ranks block until it is released # All other ranks block until it is released
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
# Recheck after acquiring lock (another process may have downloaded it) # Recheck after acquiring lock
if os.path.exists(file_path): if os.path.exists(file_path):
return file_path return file_path
@ -95,12 +93,6 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
if postprocess_fn is not None: if postprocess_fn is not None:
postprocess_fn(file_path) postprocess_fn(file_path)
# Clean up the lock file after the lock is released
try:
os.remove(lock_path)
except OSError:
pass # Ignore if already removed by another process
return file_path return file_path
def print0(s="",**kwargs): def print0(s="",**kwargs):
@ -169,6 +161,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility # Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
# The only place where global rng might be used is nn.Module initialization of the model weights.
torch.manual_seed(42) torch.manual_seed(42)
if device_type == "cuda": if device_type == "cuda":
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)

View File

@ -1,49 +1,87 @@
from collections import deque from collections import deque
import torch import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""Stream pretraining text from parquet files, tokenize, yield training batches.""" """
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation became a bit more complex because we wish to support approximate resume training.
Instead of turning this into a Class, we opt to return the state_dict with every batch,
and then the caller can pass in a state_dict to resume training from a desired point.
Note that this resumption is atm only *approximate* for simplicity.
We won't repeat the same documents but we might skip a few.
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'" assert split in ["train", "val"], "split must be 'train' or 'val'"
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token # get the tokenizer and the bos token
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id() bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration # scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left token_buffer = deque() # we stream tokens on the right and pop from the left
# infinite iterator over document batches
def document_batches():
while True:
# batch will iterate in group size of the parquet files, usually e.g. 1024 rows
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
# for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size]
batches = document_batches()
batch_index = 0
while True: while True:
# Accumulate enough tokens for one iteration before yielding. # Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens: while len(token_buffer) < needed_tokens:
doc_batch = next(batches) doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists: for tokens in token_lists:
token_buffer.extend(tokens) token_buffer.extend(tokens)
batch_index += 1
# Move tokens from the deque into the scratch buffer # Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)] tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for faster transfers between CPU and GPU: # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda")) use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors # Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32) inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:] targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async # Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets yield inputs, targets

View File

@ -17,8 +17,9 @@ import signal
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from collections import deque from collections import deque
from nanochat.common import compute_init from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Calculator tool helpers # Calculator tool helpers
@ -107,8 +108,9 @@ class KVCache:
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
# ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
if ix in [0, 1, 3, 5]: if ix in [0, 1, 3, 5]:
# num_layers, batch_size, num_heads, head_dim must match # num_layers, k/v, num_heads, head_dim must match
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}" assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
elif ix == 2: elif ix == 2:
# batch_size can be expanded # batch_size can be expanded
@ -327,6 +329,9 @@ if __name__ == "__main__":
import time import time
# init compute # init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer # load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval") model, tokenizer, meta = load_model("base", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id() bos_token_id = tokenizer.get_bos_token_id()
@ -339,10 +344,11 @@ if __name__ == "__main__":
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs) stream = model.generate(prompt_tokens, **kwargs)
for token in stream: with autocast_ctx:
generated_tokens.append(token) for token in stream:
chunk = tokenizer.decode([token]) generated_tokens.append(token)
print(chunk, end="", flush=True) chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print() print()
torch.cuda.synchronize() torch.cuda.synchronize()
t1 = time.time() t1 = time.time()
@ -354,11 +360,12 @@ if __name__ == "__main__":
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for token_column, token_masks in stream: with autocast_ctx:
token = token_column[0] # only print out the first row for token_column, token_masks in stream:
generated_tokens.append(token) token = token_column[0] # only print out the first row
chunk = tokenizer.decode([token]) generated_tokens.append(token)
print(chunk, end="", flush=True) chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print() print()
torch.cuda.synchronize() torch.cuda.synchronize()
t1 = time.time() t1 = time.time()

View File

@ -8,7 +8,7 @@ Notable features:
- norm after token embedding - norm after token embedding
- no learnable params in rmsnorm - no learnable params in rmsnorm
- no bias in linear layers - no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference - Group-Query Attention (GQA) support for more efficient inference
""" """
import math import math
@ -29,7 +29,7 @@ class GPTConfig:
vocab_size: int = 50304 vocab_size: int = 50304
n_layer: int = 12 n_layer: int = 12
n_head: int = 6 # number of query heads n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA) n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768 n_embd: int = 768
@ -244,7 +244,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size() B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"

View File

@ -9,9 +9,9 @@ import torch.distributed as dist
def evaluate_bpb(model, batches, steps, token_bytes): def evaluate_bpb(model, batches, steps, token_bytes):
""" """
Instead of the naive 'mean loss', this function returns the bits per byte (bpb), Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
which is a tokenization vocab size-indepedent metric, meaning you are still comparing which is a tokenization vocab size-independent metric, meaning you are still comparing
apples:apples if you change the vocab size. The way this works is that instead of just apples:apples if you change the vocab size. The way this works is that instead of just
calculating the average loss as usual, you calculate the sum loss, and indepependently calculating the average loss as usual, you calculate the sum loss, and independently
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
the number of bytes that the target tokens represent. the number of bytes that the target tokens represent.

View File

@ -170,7 +170,7 @@ Generated: {timestamp}
# count dependencies via uv.lock # count dependencies via uv.lock
uv_lock_lines = 0 uv_lock_lines = 0
if os.path.exists('uv.lock'): if os.path.exists('uv.lock'):
with open('uv.lock', 'r') as f: with open('uv.lock', 'r', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines()) uv_lock_lines = len(f.readlines())
header += f""" header += f"""
@ -241,7 +241,7 @@ class Report:
slug = slugify(section) slug = slugify(section)
file_name = f"{slug}.md" file_name = f"{slug}.md"
file_path = os.path.join(self.report_dir, file_name) file_path = os.path.join(self.report_dir, file_name)
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(f"## {section}\n") f.write(f"## {section}\n")
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
for item in data: for item in data:
@ -272,11 +272,11 @@ class Report:
final_metrics = {} # the most important final metrics we'll add as table at the end final_metrics = {} # the most important final metrics we'll add as table at the end
start_time = None start_time = None
end_time = None end_time = None
with open(report_file, "w") as out_file: with open(report_file, "w", encoding="utf-8") as out_file:
# write the header first # write the header first
header_file = os.path.join(report_dir, "header.md") header_file = os.path.join(report_dir, "header.md")
if os.path.exists(header_file): if os.path.exists(header_file):
with open(header_file, "r") as f: with open(header_file, "r", encoding="utf-8") as f:
header_content = f.read() header_content = f.read()
out_file.write(header_content) out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:") start_time = extract_timestamp(header_content, "Run started:")
@ -293,7 +293,7 @@ class Report:
if not os.path.exists(section_file): if not os.path.exists(section_file):
print(f"Warning: {section_file} does not exist, skipping") print(f"Warning: {section_file} does not exist, skipping")
continue continue
with open(section_file, "r") as in_file: with open(section_file, "r", encoding="utf-8") as in_file:
section = in_file.read() section = in_file.read()
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time) # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
if "rl" not in file_name: if "rl" not in file_name:
@ -373,7 +373,7 @@ class Report:
header_file = os.path.join(self.report_dir, "header.md") header_file = os.path.join(self.report_dir, "header.md")
header = generate_header() header = generate_header()
start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(header_file, "w") as f: with open(header_file, "w", encoding="utf-8") as f:
f.write(header) f.write(header)
f.write(f"Run started: {start_time}\n\n---\n\n") f.write(f"Run started: {start_time}\n\n---\n\n")
print(f"Reset report and wrote header to {header_file}") print(f"Reset report and wrote header to {header_file}")

View File

@ -70,18 +70,22 @@ python -m scripts.tok_eval
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard. # start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script. # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss # Number of processes/GPUs to use
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain # midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script. # NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft # sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# generate final report # generate final report
python -m nanochat.report generate python -m nanochat.report generate

View File

@ -75,7 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
config_path = os.path.join(eval_bundle_dir, "core.yaml") config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data") data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, 'r') as f: with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
tasks = config['icl_tasks'] tasks = config['icl_tasks']
@ -209,7 +209,7 @@ def main():
print0("="*80) print0("="*80)
print0(f"Model: {model_name}") print0(f"Model: {model_name}")
print0("="*80) print0("="*80)
with open(output_csv_path, 'r') as f: with open(output_csv_path, 'r', encoding='utf-8') as f:
print0(f.read()) print0(f.read())
# Log to report # Log to report

View File

@ -23,7 +23,7 @@ from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_experiment_logger from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_experiment_logger
from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine from nanochat.engine import Engine
from scripts.base_eval import evaluate_model from scripts.base_eval import evaluate_model
@ -54,12 +54,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
# Evaluation # Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# Output # Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol # now allow CLI to override the settings via the configurator lol
@ -114,16 +116,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the Model # Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"): with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs) model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config) model = GPT(model_config)
model.to_empty(device=device) model.to_empty(device=device)
model.init_weights() model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through # If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
if resuming:
print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters()) num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}") print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops() num_flops_per_token = model.estimate_flops()
@ -173,12 +190,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers adamw_optimizer, muon_optimizer = optimizers
if resuming:
for opt, dat in zip(optimizers, optimizer_data):
opt.load_state_dict(dat)
del optimizer_data # free up the memory
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val # Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data") tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device) dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device) build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Set up hyperparameter schedulers # Set up hyperparameter schedulers
@ -201,6 +224,21 @@ def get_muon_momentum(it):
momentum = (1 - frac) * 0.85 + frac * 0.95 momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum return momentum
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
if not resuming:
step = 0
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
total_training_time = 0 # total wall-clock time of training
else:
step = meta_data["step"]
loop_state = meta_data["loop_state"]
min_val_bpb = loop_state["min_val_bpb"]
smooth_train_loss = loop_state["smooth_train_loss"]
total_training_time = loop_state["total_training_time"]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Training loop # Training loop
min_val_bpb = float("inf") min_val_bpb = float("inf")
@ -284,16 +322,23 @@ for step in range(start_step, num_iterations + 1):
save_checkpoint( save_checkpoint(
checkpoint_dir, checkpoint_dir,
step, step,
orig_model.state_dict(), orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly [opt.state_dict() for opt in optimizers], # optimizer states
{ { # metadata saved as json
"step": step, "step": step,
"val_bpb": val_bpb, # loss at last step "val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs, "model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script "user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size, "device_batch_size": device_batch_size,
"max_seq_len": max_seq_len, "max_seq_len": max_seq_len,
} "dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
) )
# Periodic checkpointing (every 1000 steps) # Periodic checkpointing (every 1000 steps)
@ -335,10 +380,12 @@ for step in range(start_step, num_iterations + 1):
train_loss = loss.detach() # for logging train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping (TODO possibly experiment with) # gradient clipping
if grad_clip > 0.0: grad_clip_enabled = grad_clip > 0.0
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip) if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers # step the optimizers
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
for opt in optimizers: for opt in optimizers:
@ -356,6 +403,7 @@ for step in range(start_step, num_iterations + 1):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# logging # logging
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations pct_done = 100 * step / num_iterations
@ -365,9 +413,10 @@ for step in range(start_step, num_iterations + 1):
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10: if step > 10:
total_training_time += dt # only count the time after the first 10 steps total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 100 == 0: if step % 100 == 0:
wandb_run.log({ log_data = {
"step": step, "step": step,
"total_training_flops": flops_so_far, "total_training_flops": flops_so_far,
"total_training_time": total_training_time, "total_training_time": total_training_time,
@ -376,7 +425,13 @@ for step in range(start_step, num_iterations + 1):
"train/dt": dt, "train/dt": dt,
"train/tok_per_sec": tok_per_sec, "train/tok_per_sec": tok_per_sec,
"train/mfu": mfu, "train/mfu": mfu,
}) }
if grad_clip_enabled:
log_data["train/grad_norm"] = grad_norm
wandb_run.log(log_data)
# state update
step += 1
# print a few more stats # print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")

View File

@ -1,6 +1,6 @@
""" """
Evaluate the Chat model. Evaluate the Chat model.
All the generic code lives here, and all the evlauation-specific All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here. code lives in nanochat directory and is imported from here.
Example runs: Example runs:

View File

@ -203,7 +203,7 @@ for step in range(num_iterations):
}) })
model.train() model.train()
# evlauate accuracy of the multiple choice tasks (which are quick to run) # evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0): if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval() model.eval()
metrics = {} metrics = {}

View File

@ -243,7 +243,7 @@ app.add_middleware(
async def root(): async def root():
"""Serve the chat UI.""" """Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html") ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r") as f: with open(ui_html_path, "r", encoding="utf-8") as f:
html_content = f.read() html_content = f.read()
# Replace the API_URL to use the same origin # Replace the API_URL to use the same origin
html_content = html_content.replace( html_content = html_content.replace(

View File

@ -82,12 +82,15 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..." echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
NPROC_PER_NODE=8
# pretrain the d20 model # pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples # evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks # evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice) # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
@ -97,15 +100,15 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model # run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row) # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump) # train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively # chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?" # python -m scripts.chat_cli -p "Why is the sky blue?"
@ -118,9 +121,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# (optional) # (optional)
# run reinforcement learning # run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K # eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections # Generate the full report by putting together all the sections

View File

@ -32,7 +32,7 @@ class CustomJSON(Task):
print("-" * 80) print("-" * 80)
else: else:
with open(filepath, 'r') as f: with open(filepath, 'r', encoding='utf-8') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: # skip empty lines if not line: # skip empty lines

View File

@ -119,7 +119,7 @@ class SpellingBee(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f: with open(word_list_path, 'r', encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
self.words = words self.words = words
@ -238,7 +238,7 @@ class SimpleSpelling(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path) as f: with open(word_list_path, 'r', encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
rng = random.Random(42) rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task rng.shuffle(words) # use a different word order than the SpellingBee task

View File

@ -455,13 +455,13 @@ def enwik8_path():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_small(enwik8_path): def enwik8_small(enwik8_path):
"""Fixture providing 100KB of enwik8 for quick tests.""" """Fixture providing 100KB of enwik8 for quick tests."""
with open(enwik8_path, "r") as f: with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(100_000) return f.read(100_000)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_large(enwik8_path): def enwik8_large(enwik8_path):
"""Fixture providing 10MB of enwik8 for performance tests.""" """Fixture providing 10MB of enwik8 for performance tests."""
with open(enwik8_path, "r") as f: with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(10**7) return f.read(10**7)
def time_function(func, *args, **kwargs): def time_function(func, *args, **kwargs):