diff --git a/README.md b/README.md index 0a46b99..c599249 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" -- This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files. -Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. +Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. ## Codebase Overview and Data Flow @@ -214,6 +214,7 @@ python -m pytest tests/test_rustbpe.py -v -s │ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF │ └── spellingbee.py # Task teaching model to spell/count letters ├── tests +│ └── test_engine.py │ └── test_rustbpe.py └── uv.lock ``` @@ -231,6 +232,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk. - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project. - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance. +- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat. ## Cite diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py index a67c7a5..e1a772a 100644 --- a/dev/gen_synthetic_data.py +++ b/dev/gen_synthetic_data.py @@ -42,7 +42,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from nanochat.common import get_base_dir -api_key = open("openroutertoken.txt").read().strip() +api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip() url = "https://openrouter.ai/api/v1/chat/completions" headers = { @@ -50,7 +50,7 @@ headers = { "Content-Type": "application/json" } -readme = open("README.md").read().strip() +readme = open("README.md", "r", encoding="utf-8").read().strip() prompt = r""" I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: diff --git a/dev/nanochat.png b/dev/nanochat.png index 84e1b5f..2313d27 100644 Binary files a/dev/nanochat.png and b/dev/nanochat.png differ diff --git a/dev/runcpu.sh b/dev/runcpu.sh index 20d253d..59979d4 100755 --- a/dev/runcpu.sh +++ b/dev/runcpu.sh @@ -47,15 +47,6 @@ source "$HOME/.cargo/env" # Build the Rust-based BPE tokenizer. uv run maturin develop --release --manifest-path rustbpe/Cargo.toml -# Download and set up the evaluation bundle if it's not already cached. -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi - # --- Training and Evaluation Pipeline --- # Reset any previous reports to start fresh. 
python -m nanochat.report reset diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f9302f1..381afd9 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -39,59 +39,36 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) -def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - """ - Saves a checkpoint to the specified directory. - - Args: - checkpoint_dir (str): The directory to save the checkpoint to. - step (int): The current training step. - model_data (dict): The model's state_dict. - optimizer_data (dict): The optimizer's state_dict. - meta_data (dict): A dictionary of metadata to save. - """ - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now - os.makedirs(checkpoint_dir, exist_ok=True) - # Save the model state (parameters) - model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - torch.save(model_data, model_path) - log0(f"Saved model file to: {model_path}") - # Save the optimizer state (useful for SFT or any other fine-tuning) +def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + # Save the model state parameters + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") + torch.save(model_data, model_path) + logger.info(f"Saved model parameters to: {model_path}") + # Save the metadata dict as json + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta_data, f, indent=2) + logger.info(f"Saved metadata to: {meta_path}") + # Note that optimizer state is sharded across ranks, so each rank must save its own. if optimizer_data is not None: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") torch.save(optimizer_data, optimizer_path) - log0(f"Saved optimizer file to: {optimizer_path}") - # Save the metadata dict as json - meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "w") as f: - json.dump(meta_data, f, indent=2) - log0(f"Saved metadata file to: {meta_path}") + logger.info(f"Saved optimizer state to: {optimizer_path}") - -def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): - """ - Loads a checkpoint from the specified directory. - - Args: - checkpoint_dir (str): The directory to load the checkpoint from. - step (int): The training step of the checkpoint to load. - device (str): The device to load the tensors onto. - load_optimizer (bool, optional): Whether to load the optimizer state. Defaults to False. - - Returns: - tuple: A tuple containing the model data, optimizer data, and metadata. 
- """ +def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") model_data = torch.load(model_path, map_location=device) # Load the optimizer state if requested optimizer_data = None if load_optimizer: - optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") + optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt") optimizer_data = torch.load(optimizer_path, map_location=device) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "r") as f: + with open(meta_path, "r", encoding="utf-8") as f: meta_data = json.load(f) return model_data, optimizer_data, meta_data @@ -111,8 +88,14 @@ def build_model(checkpoint_dir, step, device, phase): """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + if device.type in {"cpu", "mps"}: + # Convert bfloat16 tensors to float for CPU inference + model_data = { + k: v.float() if v.dtype == torch.bfloat16 else v + for k, v in model_data.items() + } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. - model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} + model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) diff --git a/nanochat/common.py b/nanochat/common.py index 4291287..5d75608 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -7,10 +7,10 @@ used across various scripts to ensure consistency and reduce code duplication. import os import re import logging -import fcntl import urllib.request import torch import torch.distributed as dist +from filelock import FileLock class ColoredFormatter(logging.Formatter): """ @@ -69,7 +69,7 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir -def download_file_with_lock(url, filename): +def download_file_with_lock(url, filename, postprocess_fn=None): """ Downloads a file from a URL to a local path, using a lock file to prevent concurrent downloads in a distributed setting. 
@@ -88,29 +88,27 @@ def download_file_with_lock(url, filename): if os.path.exists(file_path): return file_path - with open(lock_path, 'w') as lock_file: - + with FileLock(lock_path): # Only a single rank can acquire this lock # All other ranks block until it is released - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + # Recheck after acquiring lock if os.path.exists(file_path): return file_path + # Download the content as bytes print(f"Downloading {url}...") with urllib.request.urlopen(url) as response: - content = response.read().decode('utf-8') + content = response.read() # bytes - with open(file_path, 'w') as f: + # Write to local file + with open(file_path, 'wb') as f: f.write(content) - print(f"Downloaded to {file_path}") - # Clean up the lock file after the lock is released - try: - os.remove(lock_path) - except OSError: - pass # Ignore if already removed by another process + # Run the postprocess function if provided + if postprocess_fn is not None: + postprocess_fn(file_path) return file_path @@ -124,15 +122,15 @@ def print_banner(): """Prints the nanochat ASCII art banner.""" # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ - █████ █████ - ░░███ ░░███ - ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ -░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░ - ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ - ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ - ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████ -░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ -""" + █████ █████ + ░░███ ░░███ + ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ + ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░ + ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ + ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ + ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████ + ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ + """ print0(banner) def is_ddp(): @@ -184,6 +182,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" # Reproducibility + # Note that we set the global seeds here, but most of the code uses explicit rng objects. + # The only place where global rng might be used is nn.Module initialization of the model weights. 
torch.manual_seed(42) if device_type == "cuda": torch.cuda.manual_seed(42) @@ -198,7 +198,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() if ddp and device_type == "cuda": device = torch.device("cuda", ddp_local_rank) - torch.cuda.set_device(device) # make "cuda" default to this device + torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() else: diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index e87bd4d..3271298 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,67 +1,87 @@ from collections import deque import torch +import pyarrow.parquet as pq from nanochat.common import get_dist_info -from nanochat.dataset import parquets_iter_batched +from nanochat.dataset import list_parquet_files from nanochat.tokenizer import get_tokenizer -def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): +def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): """ - Streams text from Parquet files, tokenizes it, and yields training batches. + Stream pretraining text from parquet files, tokenize, yield training batches. - This data loader is designed for large-scale pretraining, where the entire dataset - cannot fit into memory. It streams data from disk, tokenizes it on the fly, and - yields batches of data indefinitely. It also supports distributed training by - sharding the data across multiple devices. + This implementation became a bit more complex because we wish to support approximate resume training. + Instead of turning this into a Class, we opt to return the state_dict with every batch, + and then the caller can pass in a state_dict to resume training from a desired point. + Note that this resumption is atm only *approximate* for simplicity. + We won't repeat the same documents but we might skip a few. + The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume. - Args: - B (int): The batch size. - T (int): The sequence length. - split (str): The data split to use, either "train" or "val". - tokenizer_threads (int, optional): The number of threads for tokenization. - tokenizer_batch_size (int, optional): The number of documents to tokenize at once. - device (str, optional): The device to move the batches to. - - Yields: - tuple: A tuple containing the input and target tensors. + Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm. 
""" assert split in ["train", "val"], "split must be 'train' or 'val'" + + # infinite iterator over document batches (list of text strings) ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + def document_batches(): + parquet_paths = list_parquet_files() + parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] + resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 + resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None + pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) + while True: # iterate infinitely (multi-epoch) + while pq_idx < len(parquet_paths): # iterate over all parquet files + filepath = parquet_paths[pq_idx] + pf = pq.ParquetFile(filepath) + # Start from resume point if resuming on same file, otherwise from DDP rank + # I know this state resumption is a little bit tricky and a little bit hacky... sigh. + if resume_rg_idx is not None: + base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size + base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming + rg_idx = base_idx * ddp_world_size + ddp_rank + resume_rg_idx = None # set to None as we only want to do this a single time + else: + rg_idx = ddp_rank + while rg_idx < pf.num_row_groups: + rg = pf.read_row_group(rg_idx) + batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows + # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows + for i in range(0, len(batch), tokenizer_batch_size): + yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) + rg_idx += ddp_world_size # advance to the next row group (in DDP) + pq_idx += 1 # advance to the next parquet file + batches = document_batches() + + # Now emit batches of tokens. needed_tokens = B * T + 1 # +1 is because we also need the target at the last token # get the tokenizer and the bos token tokenizer = get_tokenizer() bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration token_buffer = deque() # we stream tokens on the right and pop from the left - - # infinite iterator over document batches - def document_batches(): - while True: - # batch will iterate in group size of the parquet files, usually e.g. 1024 rows - for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size): - # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows - for i in range(0, len(batch), tokenizer_batch_size): - yield batch[i:i+tokenizer_batch_size] - batches = document_batches() - - batch_index = 0 while True: # Accumulate enough tokens for one iteration before yielding. 
while len(token_buffer) < needed_tokens: - doc_batch = next(batches) + doc_batch, (pq_idx, rg_idx) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) for tokens in token_lists: token_buffer.extend(tokens) - batch_index += 1 # Move tokens from the deque into the scratch buffer tokens = [token_buffer.popleft() for _ in range(needed_tokens)] - # CUDA supports memory pinning for faster transfers between CPU and GPU: - scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda")) + # CUDA supports memory pinning for asynchronous transfers between CPU and GPU + use_cuda_optimizations = device == "cuda" + scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 # Create the inputs/targets as 1D tensors - inputs_cpu = scratch[:-1].to(dtype=torch.int32) + inputs_cpu = scratch[:-1] targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) + inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) + state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training + yield inputs, targets, state_dict + +def tokenizing_distributed_data_loader(*args, **kwargs): + # helper function that only emits the inputs/targets and not the state_dict + for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets diff --git a/nanochat/engine.py b/nanochat/engine.py index de7306f..35d3ca9 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -19,8 +19,9 @@ import signal import warnings from contextlib import contextmanager from collections import deque -from nanochat.common import compute_init +from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model +from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -41,7 +42,7 @@ def eval_with_timeout(formula, max_time=3): with timeout(max_time, formula): with warnings.catch_warnings(): warnings.simplefilter("ignore", SyntaxWarning) - return eval(formula) + return eval(formula, {"__builtins__": {}}, {}) except Exception as e: signal.alarm(0) # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage @@ -124,9 +125,10 @@ class KVCache: assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)): + # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim if ix in [0, 1, 3, 5]: - # num_layers, batch_size, num_heads, head_dim must match - assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}" + # num_layers, k/v, num_heads, head_dim must match + assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}" elif ix == 2: # batch_size can be expanded assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}" @@ -368,6 +370,9 @@ if __name__ == "__main__": import time # init compute ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() + device_type = 
autodetect_device_type() + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + # load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() @@ -380,10 +385,11 @@ if __name__ == "__main__": torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) - for token in stream: - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() @@ -395,11 +401,12 @@ if __name__ == "__main__": stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() - for token_column, token_masks in stream: - token = token_column[0] # only print out the first row - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + with autocast_ctx: + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5c54e41..78f0167 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -1,11 +1,14 @@ """ -This module implements the GPT (Generative Pre-trained Transformer) model for nanochat. -It features several modern architectural choices for improved performance and efficiency: -- Rotary Positional Embeddings (RoPE) -- QK Norm for attention stabilization -- SwiGLU activation in the MLP -- RMSNorm for normalization -- Multi-Query Attention (MQA) for efficient inference +GPT model (rewrite, a lot simpler) +Notable features: +- rotary embeddings (and no positional embeddings) +- QK norm +- untied weights for token embedding and lm_head +- relu^2 activation in MLP +- norm after token embedding +- no learnable params in rmsnorm +- no bias in linear layers +- Group-Query Attention (GQA) support for more efficient inference """ import math @@ -34,7 +37,7 @@ class GPTConfig: vocab_size: int = 50304 n_layer: int = 12 n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (MQA) + n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 @@ -258,7 +261,7 @@ class GPT(nn.Module): """The forward pass of the model.""" B, T = idx.size() - # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) + # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index cf2f634..277101d 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -10,7 +10,12 @@ import torch.distributed as dist @torch.no_grad() def evaluate_bpb(model, batches, steps, token_bytes): """ - Evaluates the model's performance using the bits-per-byte (BPB) metric. 
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb), + which is a tokenization vocab size-independent metric, meaning you are still comparing + apples:apples if you change the vocab size. The way this works is that instead of just + calculating the average loss as usual, you calculate the sum loss, and independently + also the sum bytes (of all the target tokens), and divide. This normalizes the loss by + the number of bytes that the target tokens represent. Args: model (torch.nn.Module): The language model to evaluate. diff --git a/nanochat/report.py b/nanochat/report.py index e13e26b..bc4f22e 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -179,7 +179,7 @@ Generated: {timestamp} # count dependencies via uv.lock uv_lock_lines = 0 if os.path.exists('uv.lock'): - with open('uv.lock', 'r') as f: + with open('uv.lock', 'r', encoding='utf-8') as f: uv_lock_lines = len(f.readlines()) header += f""" @@ -261,7 +261,7 @@ class Report: slug = slugify(section) file_name = f"{slug}.md" file_path = os.path.join(self.report_dir, file_name) - with open(file_path, "w") as f: + with open(file_path, "w", encoding="utf-8") as f: f.write(f"## {section}\n") f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") for item in data: @@ -296,11 +296,11 @@ class Report: final_metrics = {} # the most important final metrics we'll add as table at the end start_time = None end_time = None - with open(report_file, "w") as out_file: + with open(report_file, "w", encoding="utf-8") as out_file: # write the header first header_file = os.path.join(report_dir, "header.md") if os.path.exists(header_file): - with open(header_file, "r") as f: + with open(header_file, "r", encoding="utf-8") as f: header_content = f.read() out_file.write(header_content) start_time = extract_timestamp(header_content, "Run started:") @@ -317,7 +317,7 @@ class Report: if not os.path.exists(section_file): print(f"Warning: {section_file} does not exist, skipping") continue - with open(section_file, "r") as in_file: + with open(section_file, "r", encoding="utf-8") as in_file: section = in_file.read() # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) if "rl" not in file_name: @@ -401,7 +401,7 @@ class Report: header_file = os.path.join(self.report_dir, "header.md") header = generate_header() start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(header_file, "w") as f: + with open(header_file, "w", encoding="utf-8") as f: f.write(header) f.write(f"Run started: {start_time}\n\n---\n\n") print(f"Reset report and wrote header to {header_file}") diff --git a/pyproject.toml b/pyproject.toml index da674f4..3d03c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ dependencies = [ "datasets>=4.0.0", "fastapi>=0.117.1", "files-to-prompt>=0.6", - "numpy==1.26.4", "psutil>=7.1.0", "regex>=2025.9.1", "setuptools>=80.9.0", diff --git a/run1000.sh b/run1000.sh index 6f454e0..58ee3bc 100644 --- a/run1000.sh +++ b/run1000.sh @@ -19,13 +19,6 @@ python -m nanochat.report reset curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y source "$HOME/.cargo/env" uv run maturin develop --release --manifest-path rustbpe/Cargo.toml -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! 
-d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl # train tokenizer on ~4B characters and kick off download of the rest for pretraining @@ -77,18 +70,22 @@ python -m scripts.tok_eval # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd # start to overfit hard. # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script. -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8 -torchrun --standalone --nproc_per_node=8 -m scripts.base_loss -torchrun --standalone --nproc_per_node=8 -m scripts.base_eval + +# Number of processes/GPUs to use +NPROC_PER_NODE=8 + +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval # midtrain # NOTE: ensure that we use the same device_batch_size here as the base training script. -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid # sft -torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft # generate final report python -m nanochat.report generate diff --git a/scripts/base_eval.py b/scripts/base_eval.py index cc54a26..f4f9810 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -1,58 +1,76 @@ """ -This script evaluates a base model on the CORE (Comprehensive Overall Language Evaluation) -metric. It can evaluate either a local nanochat model or a Hugging Face model. +Evaluate the CORE metric for a given model. -The CORE benchmark provides a holistic assessment of a model's capabilities. This -script iterates through the tasks defined in `core.yaml`, evaluates the model on each, -and reports the accuracy and a "centered" score (normalized by a random baseline). +Run on a single GPU: +python -m scripts.base_eval -Usage: -- Evaluate a local nanochat model: `python scripts/base_eval.py` -- Evaluate a Hugging Face model: `python scripts/base_eval.py --hf-path ` -- Distributed evaluation: `torchrun --nproc_per_node= scripts/base_eval.py` +Run with torchrun on e.g. 8 GPUs: +torchrun --nproc_per_node=8 -m scripts.base_eval + +The script will print the CORE metric to the console. 
""" import os -import sys +import csv import time import json -import random import yaml +import shutil +import random +import zipfile +import tempfile from contextlib import nullcontext -import pandas as pd import torch -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task # ----------------------------------------------------------------------------- -# nanoChat specific function dealing with I/O etc. +# nanochat specific function dealing with I/O etc. + +# ~162MB of data needed to evaluate the CORE metric +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + +def place_eval_bundle(file_path): + # here file_path is the path to the eval_bundle.zip file + # we need to unzip it and place it in the base directory + base_dir = get_base_dir() + eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(tmpdir) + extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle") + shutil.move(extracted_bundle_dir, eval_bundle_dir) + print0(f"Placed eval_bundle directory at {eval_bundle_dir}") def evaluate_model(model, tokenizer, device, max_per_task=-1): """ - Evaluates a model on the CORE benchmark. - - Args: - model: The model to evaluate. - tokenizer: The tokenizer to use. - device (str): The device to run the evaluation on. - max_per_task (int, optional): Max examples per task. Defaults to -1 (no limit). - - Returns: - dict: A dictionary containing the evaluation results. + Evaluate a base model on the CORE benchmark. 
+ - max_per_task: crop the data to this many examples per task for testing (-1 = disable) """ # Load config and task metadata base_dir = get_base_dir() eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + # Download the eval bundle to disk (and unzip if needed) + if not os.path.exists(eval_bundle_dir): + download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle) config_path = os.path.join(eval_bundle_dir, "core.yaml") data_base_path = os.path.join(eval_bundle_dir, "eval_data") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") - with open(config_path, 'r') as f: + with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) tasks = config['icl_tasks'] - eval_metadata = pd.read_csv(eval_meta_data) + + # Load random baseline values from eval metadata + random_baselines = {} + with open(eval_meta_data, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row in reader: + task_name = row['Eval Task'] + random_baseline = row['Random baseline'] + random_baselines[task_name] = float(random_baseline) # Evaluate each task results = {} @@ -70,7 +88,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): # Load data for this task data_path = os.path.join(data_base_path, task_meta['dataset_uri']) - with open(data_path, 'r') as f: + with open(data_path, 'r', encoding='utf-8') as f: data = [json.loads(line.strip()) for line in f] # shuffle the data because in many cases it appears ordered but we want @@ -84,8 +102,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): accuracy = evaluate_task(model, tokenizer, data, device, task_meta) results[label] = accuracy - row = eval_metadata[eval_metadata["Eval Task"] == label] - random_baseline = row["Random baseline"].values[0] + random_baseline = random_baselines[label] centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline) centered_results[label] = centered_result end_time = time.time() @@ -168,7 +185,7 @@ def main(): results = out["results"] centered_results = out["centered_results"] core_metric = out["core_metric"] - with open(output_csv_path, 'w') as f: + with open(output_csv_path, 'w', encoding='utf-8', newline='') as f: f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n") for label in results: f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n") @@ -177,7 +194,7 @@ def main(): print0("="*80) print0(f"Model: {model_name}") print0("="*80) - with open(output_csv_path, 'r') as f: + with open(output_csv_path, 'r', encoding='utf-8') as f: print0(f.read()) # Log to report diff --git a/scripts/base_train.py b/scripts/base_train.py index 4005de7..c9131c8 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -27,10 +27,10 @@ import wandb import torch from nanochat.gpt import GPT, GPTConfig -from nanochat.dataloader import tokenizing_distributed_data_loader +from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type from nanochat.tokenizer import get_tokenizer, get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from scripts.base_eval import evaluate_model @@ -59,12 +59,14 @@ grad_clip = 1.0 # gradient clipping value 
(0.0 = disabled) warmup_ratio = 0.0 # ratio of iterations for LR warmup warmdown_ratio = 0.2 # ratio of iterations for LR warmdown final_lr_frac = 0.0 # final LR is this fraction of the initial LR +resume_from_step = -1 # resume training from this step of the optimization (-1 = disable) # Evaluation eval_every = 250 # every how many steps to evaluate the model for val bpb eval_tokens = 20*524288 # number of tokens to evaluate val loss on core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) core_metric_max_per_task = 500 # examples per task in estimating the core metric sample_every = 2000 # every how many steps to sample from the model +save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run) # Output model_tag = "" # optionally override the model tag for the output checkpoint directory name # now allow CLI to override the settings via the configurator lol @@ -110,16 +112,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") + # ----------------------------------------------------------------------------- # Initialize the Model + +# Create a new model with random weights model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) model.to_empty(device=device) model.init_weights() -orig_model = model # original, uncompiled model, for saving raw model state_dict -model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through + +# If we are resuming, overwrite the model parameters with those of the checkpoint +base_dir = get_base_dir() +output_dirname = model_tag if model_tag else f"d{depth}" # e.g. 
d12 +checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) +resuming = resume_from_step != -1 +if resuming: + print0(f"Resuming optimization from step {resume_from_step}") + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank) + model.load_state_dict(model_data, strict=True, assign=True) + del model_data # free up this memory after the copy + +orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe num_params = sum(p.numel() for p in model.parameters()) print0(f"Number of parameters: {num_params:,}") num_flops_per_token = model.estimate_flops() @@ -150,12 +167,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) adamw_optimizer, muon_optimizer = optimizers +if resuming: + for opt, dat in zip(optimizers, optimizer_data): + opt.load_state_dict(dat) + del optimizer_data # free up the memory + +# ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val -base_dir = get_base_dir() tokens_dir = os.path.join(base_dir, "tokenized_data") -train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device) +dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] +train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device) -x, y = next(train_loader) # kick off load of the very first batch of data +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- # Set up hyperparameter schedulers @@ -178,15 +201,25 @@ def get_muon_momentum(it): momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum +# ----------------------------------------------------------------------------- +# Loop state (variables updated by the training loop) + +if not resuming: + step = 0 + min_val_bpb = float("inf") + smooth_train_loss = 0 # EMA of training loss + total_training_time = 0 # total wall-clock time of training +else: + step = meta_data["step"] + loop_state = meta_data["loop_state"] + min_val_bpb = loop_state["min_val_bpb"] + smooth_train_loss = loop_state["smooth_train_loss"] + total_training_time = loop_state["total_training_time"] + # ----------------------------------------------------------------------------- # Training loop -min_val_bpb = float("inf") -smooth_train_loss = 0 # EMA of training loss -ema_beta = 0.9 # EMA decay factor -total_training_time = 0 # total wall-clock time of training -# note that we run +1 steps only so that we can eval and save at the end -for step in range(num_iterations + 1): - last_step = step == num_iterations +while True: + last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the 
val bpb (all ranks participate) @@ -244,25 +277,31 @@ for step in range(num_iterations + 1): print0(tokenizer.decode(sample[0])) model.train() - # save checkpoint at the end of the run (only on master process) - if master_process and last_step: - output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12 - checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) + # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step + if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0): save_checkpoint( checkpoint_dir, step, - orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly - { + orig_model.state_dict(), # model parameters + [opt.state_dict() for opt in optimizers], # optimizer states + { # metadata saved as json "step": step, "val_bpb": val_bpb, # loss at last step "model_config": model_config_kwargs, "user_config": user_config, # inputs to the training script "device_batch_size": device_batch_size, "max_seq_len": max_seq_len, - } + "dataloader_state_dict": dataloader_state_dict, + "loop_state": { # all loop state (other than step) so that we can resume training + "min_val_bpb": min_val_bpb, + "smooth_train_loss": smooth_train_loss, + "total_training_time": total_training_time, + }, + }, + rank=ddp_rank, ) + # termination conditions (TODO: possibly also add loss explosions etc.) if last_step: break @@ -277,10 +316,12 @@ for step in range(num_iterations + 1): train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward - # gradient clipping (TODO possibly experiment with) - if grad_clip > 0.0: - torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip) + x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + # gradient clipping + grad_clip_enabled = grad_clip > 0.0 + if grad_clip_enabled: + grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip) + grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point) # step the optimizers lrm = get_lr_multiplier(step) for opt in optimizers: @@ -298,18 +339,20 @@ for step in range(num_iterations + 1): # ------------------------------------------------------------------------- # logging + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * step / num_iterations - tok_per_sec = int(world_tokens_per_fwdbwd / dt) + tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % if step > 10: total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + print_grad_norm = f" grad norm: 
{grad_norm:.4f} |" if grad_clip_enabled else "" + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") if step % 100 == 0: - wandb_run.log({ + log_data = { "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, @@ -318,7 +361,13 @@ for step in range(num_iterations + 1): "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, - }) + } + if grad_clip_enabled: + log_data["train/grad_norm"] = grad_norm + wandb_run.log(log_data) + + # state update + step += 1 # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index 7d87a4e..a1589e7 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -1,7 +1,7 @@ """ -This script evaluates a trained chat model on various downstream tasks, such as -MMLU, GSM8K, and HumanEval. It supports both generative and categorical evaluation -modes, depending on the task. +Evaluate the Chat model. +All the generic code lives here, and all the evaluation-specific +code lives in nanochat directory and is imported from here. The script can be run in both single-GPU and distributed (DDP) modes. diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 7578918..b846f4c 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -194,7 +194,7 @@ for step in range(num_iterations): }) model.train() - # evlauate accuracy of the multiple choice tasks (which are quick to run) + # evaluate accuracy of the multiple choice tasks (which are quick to run) if last_step or (step > 0 and step % eval_metrics_every == 0): model.eval() metrics = {} diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 9091ea8..3e471ea 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -224,7 +224,7 @@ app.add_middleware( async def root(): """Serve the chat UI.""" ui_html_path = os.path.join("nanochat", "ui.html") - with open(ui_html_path, "r") as f: + with open(ui_html_path, "r", encoding="utf-8") as f: html_content = f.read() # Replace the API_URL to use the same origin html_content = html_content.replace( diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 8c02343..5c0a436 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -271,7 +271,7 @@ while True: smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA pct_done = 100 * progress - tok_per_sec = int(world_tokens_per_fwdbwd / dt) + tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % diff --git a/speedrun.sh b/speedrun.sh index 35dd39e..7955ec5 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -73,15 +73,6 @@ python -m scripts.tok_eval # ----------------------------------------------------------------------------- # Base model (pretraining) -# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) -EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip -if [ ! 
-d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then - curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL - unzip -q eval_bundle.zip - rm eval_bundle.zip - mv eval_bundle $NANOCHAT_BASE_DIR -fi - # The d20 model is 561M parameters. # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. @@ -91,12 +82,15 @@ fi echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID +# Number of processes/GPUs to use +NPROC_PER_NODE=8 + # pretrain the d20 model -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples -torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss # evaluate the model on CORE tasks -torchrun --standalone --nproc_per_node=8 -m scripts.base_eval +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval # ----------------------------------------------------------------------------- # Midtraining (teach the model conversation special tokens, tool use, multiple choice) @@ -106,15 +100,15 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl # run midtraining and eval the model -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid # ----------------------------------------------------------------------------- # Supervised Finetuning (domain adaptation to each sequence all by itself per row) # train sft and re-eval right away (should see a small bump) -torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" 
@@ -127,9 +121,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft # (optional) # run reinforcement learning -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN # eval the RL model only on GSM8K -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K # ----------------------------------------------------------------------------- # Generate the full report by putting together all the sections diff --git a/tasks/customjson.py b/tasks/customjson.py index 6f80bd1..e8446ca 100644 --- a/tasks/customjson.py +++ b/tasks/customjson.py @@ -42,7 +42,7 @@ class CustomJSON(Task): print("-" * 80) else: - with open(filepath, 'r') as f: + with open(filepath, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: # skip empty lines diff --git a/tasks/spellingbee.py b/tasks/spellingbee.py index 5691577..8f3e56e 100644 --- a/tasks/spellingbee.py +++ b/tasks/spellingbee.py @@ -143,7 +143,7 @@ class SpellingBee(Task): # Download the word list if it's not already cached. filename = WORD_LIST_URL.split("/")[-1] word_list_path = download_file_with_lock(WORD_LIST_URL, filename) - with open(word_list_path) as f: + with open(word_list_path, 'r', encoding='utf-8') as f: words = [line.strip() for line in f] self.words = words @@ -285,7 +285,7 @@ class SimpleSpelling(Task): self.split = split filename = WORD_LIST_URL.split("/")[-1] word_list_path = download_file_with_lock(WORD_LIST_URL, filename) - with open(word_list_path) as f: + with open(word_list_path, 'r', encoding='utf-8') as f: words = [line.strip() for line in f] rng = random.Random(42) rng.shuffle(words) # Use a different word order than SpellingBee for variety. 
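For reference, a minimal sketch of how the reworked download helper is meant to be called now that the eval bundle (like the word list above) goes through nanochat.common.download_file_with_lock instead of the removed shell snippets. The signature follows the patched helper; the URL, filename and the unzip hook below are illustrative assumptions only, not real nanochat assets:

import zipfile
from nanochat.common import get_base_dir, download_file_with_lock

def unzip_bundle(zip_path):
    # hypothetical postprocess hook: invoked once, by whichever rank acquired the
    # file lock, right after the downloaded bytes have been written to disk
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(get_base_dir())

# downloads the file if it is not already cached, then runs the postprocess hook
bundle_path = download_file_with_lock(
    "https://example.com/bundle.zip",  # placeholder URL
    "bundle.zip",                      # local filename passed through to the helper
    postprocess_fn=unzip_bundle,
)
print(f"bundle cached at {bundle_path}")

This mirrors how scripts/base_eval.py now fetches eval_bundle.zip via place_eval_bundle, with the lock ensuring only one rank downloads and unpacks in a distributed run.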
diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index 5f95721..aca67fc 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -455,13 +455,13 @@ def enwik8_path(): @pytest.fixture(scope="module") def enwik8_small(enwik8_path): """Fixture providing 100KB of enwik8 for quick tests.""" - with open(enwik8_path, "r") as f: + with open(enwik8_path, "r", encoding="utf-8") as f: return f.read(100_000) @pytest.fixture(scope="module") def enwik8_large(enwik8_path): """Fixture providing 10MB of enwik8 for performance tests.""" - with open(enwik8_path, "r") as f: + with open(enwik8_path, "r", encoding="utf-8") as f: return f.read(10**7) def time_function(func, *args, **kwargs): diff --git a/uv.lock b/uv.lock index f01bba3..4e9b0bd 100644 --- a/uv.lock +++ b/uv.lock @@ -311,7 +311,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -777,7 +777,6 @@ dependencies = [ { name = "datasets" }, { name = "fastapi" }, { name = "files-to-prompt" }, - { name = "numpy" }, { name = "psutil" }, { name = "regex" }, { name = "setuptools" }, @@ -811,7 +810,6 @@ requires-dist = [ { name = "datasets", specifier = ">=4.0.0" }, { name = "fastapi", specifier = ">=0.117.1" }, { name = "files-to-prompt", specifier = ">=0.6" }, - { name = "numpy", specifier = "==1.26.4" }, { name = "psutil", specifier = ">=7.1.0" }, { name = "regex", specifier = ">=2025.9.1" }, { name = "setuptools", specifier = ">=80.9.0" }, @@ -951,7 +949,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -964,7 +962,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = 
"2025-03-07T01:44:56.873Z" }, @@ -996,9 +994,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1011,7 +1009,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -1955,7 +1953,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" }, + { name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },