diff --git a/README.md b/README.md index f13dba0..0a46b99 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,36 @@ This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` fo Alternatively, I recommend using [DeepWiki](https://deepwiki.com/) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off. +## Codebase Overview and Data Flow + +This repository is structured to provide a clear, end-to-end pipeline for building a conversational LLM. The process flows from data preparation and tokenization to model training and evaluation. Below is a high-level overview of the key components and their roles in this pipeline. For more in-depth information, please refer to the extensive in-code documentation that has been added throughout the repository. + +### Key Components + +- **`nanochat/`**: This is the core library of the project. It contains all the essential components for building and training the model, including the GPT architecture (`gpt.py`), the distributed AdamW optimizer (`adamw.py`), the data loader (`dataloader.py`), and the BPE tokenizer wrapper (`tokenizer.py`). Each module is documented with its specific purpose and implementation details. + +- **`rustbpe/`**: A custom, high-performance Byte-Pair Encoding (BPE) tokenizer written in Rust. This component is crucial for efficiently converting raw text into tokens that the model can process. The `README.md` within this directory explains the rationale for using Rust and provides a detailed guide to the BPE algorithm. + +- **`scripts/`**: This directory contains the scripts that drive the entire training and evaluation pipeline. The scripts are organized by their function, from training the tokenizer (`tok_train.py`) and the base model (`base_train.py`) to fine-tuning (`chat_sft.py`) and evaluation (`chat_eval.py`). Each script is documented to explain its role and usage. + +- **`tasks/`**: This directory defines the various tasks used for training and evaluating the model. Each file corresponds to a specific dataset or capability, such as grade-school math (`gsm8k.py`) or code generation (`humaneval.py`). The in-code documentation details what each task evaluates and how it is implemented. + +- **`dev/`**: Contains development-related scripts and resources. This includes a script for generating synthetic data to customize the model's identity (`gen_synthetic_data.py`) and a reference script for preparing the pre-training data (`repackage_data_reference.py`). The `runcpu.sh` script provides a documented example of how to run the entire pipeline on a CPU for testing and demonstration purposes. + +### End-to-End Data Flow + +1. **Data Preparation**: The process begins with preparing the training data. For pre-training, a large corpus of text is downloaded and repackaged into efficient shards (as documented in `dev/repackage_data_reference.py`). For fine-tuning, various task-specific datasets are used, as defined in the `tasks/` directory. + +2. **Tokenization**: A BPE tokenizer is trained on the prepared data using the script `scripts/tok_train.py`. This tokenizer, implemented in Rust for performance, converts the text data into a sequence of integer tokens. + +3. **Base Model Training**: The core language model is trained from scratch on the tokenized pre-training data using `scripts/base_train.py`. This phase is the most computationally intensive and is where the model learns general language patterns. + +4. 
**Mid-training and Fine-tuning**: After pre-training, the model undergoes "mid-training" (`scripts/mid_train.py`) and supervised fine-tuning (`scripts/chat_sft.py`) on a mixture of conversational and task-specific data. This is where the model learns to be a helpful assistant. + +5. **Evaluation and Inference**: Throughout the training process, the model's performance is evaluated on various benchmarks defined in `tasks/`. Once training is complete, you can interact with your custom LLM through a command-line interface (`scripts/chat_cli.py`) or a web UI (`scripts/chat_web.py`). + +This structured approach, combined with the detailed in-code documentation, is designed to make the entire process of building a custom LLM transparent, hackable, and educational. + ## Tests I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as: diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py index 13e5f55..a67c7a5 100644 --- a/dev/gen_synthetic_data.py +++ b/dev/gen_synthetic_data.py @@ -1,32 +1,37 @@ +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +# # +# Synthetic Data Generation for LLM Customization # +# # +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# """ -Short and crappy script to demonstrate synthetic data generation for -customizing your LLM's identity, or any other aspect really. +This script demonstrates how to generate synthetic data to customize an LLM's identity or other behaviors. -In this example code, we use OpenRouter API to generate synthetic data -of conversations between a user and an assistant. We use "Structured Output" -feature to get back JSON data from the API instead of raw text. The conversations -are saved simply to a .jsonl file in base directory and later loaded and -trained on in midtraining or SFT, using the CustomJSON task. +Overview: +The script uses the OpenRouter API to create conversational data between a user and an assistant. +It leverages the "Structured Output" feature to receive JSON data directly, which is more reliable +than parsing raw text. The generated conversations are saved to a `.jsonl` file in the project's +base directory. This data can then be used for mid-training or supervised fine-tuning (SFT) +with the `CustomJSON` task. -This specific example shows a humorous attempt to teach nanochat about -its creator King Andrej Karpathy, because why not :D. Note two things about the -prompt: +Example Use Case: +This particular example humorously teaches the `nanochat` model about its creator, +"King Andrej Karpathy." -1. We are instructing the LLM how to handle various situations (e.g. foreign language), - simply in English. You can infuse any style or behavior in this way. -2. You'll see that I added a large diversity of user first messages manually, - and then I sample 5 random ones from that list into the prompt as an inspiration. - This is really important to do because DIVERSITY CONTROL is key. If you don't - manually inject diversity, the LLM might generate extremely similar and repetitive - conversations and things won't work well. Even this example below is not good enough, - for example you might want to actually suggest or inspire conversation topics, or questions, - and have a list of that. Basically, this is the KEY creative part to get right. 
Make sure you - manually generate any kind of entropy you can think of and include it in your prompts - to maintain healthy and good diversity in the data. +Key Concepts in the Prompt Design: +1. **Behavioral Instruction:** The prompt instructs the LLM on how to handle specific scenarios, + such as responding to questions in a foreign language. This is a powerful way to infuse a + desired style or behavior into the model. +2. **Diversity Control:** A diverse list of initial user messages is provided. The script + randomly samples from this list to inspire varied conversations. This is crucial for + preventing the model from generating repetitive data. Ensuring high diversity in the + synthetic data is a key creative and technical challenge for successful customization. -NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo. - (obviously you can tune this arbitrarily to your liking) -NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 +Prerequisites: +- An OpenRouter API key must be saved in a file named `openroutertoken.txt` in the root + directory of this repository. +- For more background, see the discussion at: https://github.com/karpathy/nanochat/discussions/139 """ import requests import json diff --git a/dev/repackage_data_reference.py b/dev/repackage_data_reference.py index 32980a8..f67f5f3 100644 --- a/dev/repackage_data_reference.py +++ b/dev/repackage_data_reference.py @@ -1,17 +1,36 @@ +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +# # +# Dataset Preparation Reference: FineWebEdu-100B # +# # +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# """ -Repackage the FinewebEdu-100B dataset into shards: +This script serves as a reference and documentation for the preparation of the +`FinewebEdu-100B` dataset. -- each shard is ~100MB in size (after zstd compression) -- parquets are written with row group size of 1000 -- shuffle the dataset +**NOTE: This file is not intended to be executed during the project's runtime.** -This will be uploaded to HuggingFace for hosting. -The big deal is that our DataLoader will be able to stream -the data and cache it along the way on disk, decreasing the -training latency. +Purpose of this Script: +The primary goal of this script is to transform the raw `FinewebEdu-100B` dataset into a more +efficient format for large-scale model training. The key steps are: -NOTE: This file is meant only as reference/documentation of the -dataset preparation and it is not used during the project runtime. +1. **Shuffling:** The entire dataset is shuffled to ensure that the data is presented to the + model in a random order, which is crucial for effective training. + +2. **Repackaging into Shards:** The shuffled dataset is broken down into smaller chunks, or "shards." + - Each shard is saved as a Parquet file. + - The target size for each compressed shard is approximately 100MB. + - This sharding strategy is vital for performance. It allows the DataLoader to stream the + dataset from a source (like the Hugging Face Hub) and cache it locally. This "just-in-time" + data loading significantly reduces training latency, as the model doesn't have to wait for + the entire massive dataset to be downloaded. + +3. 
**Uploading to Hugging Face:** After processing, the shards are uploaded to the Hugging Face Hub, + making them easily accessible for training runs. + +This preparation process is a critical step in enabling efficient and scalable training +for the nanochat project. """ import os import time diff --git a/dev/runcpu.sh b/dev/runcpu.sh index 469e51d..20d253d 100755 --- a/dev/runcpu.sh +++ b/dev/runcpu.sh @@ -1,27 +1,53 @@ #!/bin/bash +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +# # +# CPU Demonstration and Test Run # +# # +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +# This script provides an example run for exercising the project's code paths +# on a CPU or MPS (for Apple Silicon Macs). +# +# To run this script: +# bash dev/runcpu.sh +# +# IMPORTANT NOTE: +# Training Large Language Models (LLMs) is computationally intensive and requires +# significant GPU resources and budget. This script is intended as an educational +# tool and a demonstration. It allows you to verify that the code runs, but it +# will not produce a high-quality model. It is intentionally placed in the `dev/` +# directory to emphasize its development and testing purpose. -# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks) -# Run as: -# bash dev/cpu_demo_run.sh - -# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook. -# Think of this run as educational/fun demo, not something you should expect to work well. -# This is also why I hide this script away in dev/ - -# all the setup stuff +# --- Environment Setup --- +# Set the number of threads for OpenMP to 1 to avoid potential conflicts. export OMP_NUM_THREADS=1 +# Define and create the base directory for nanochat data and caches. export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" mkdir -p $NANOCHAT_BASE_DIR + +# Install 'uv', a fast Python package installer, if it's not already present. command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh +# Create a virtual environment if it doesn't exist. [ -d ".venv" ] || uv venv +# Sync dependencies, including the 'cpu' extra for CPU-only environments. uv sync --extra cpu +# Activate the virtual environment. source .venv/bin/activate + +# Set a dummy Weights & Biases run name if not already defined. if [ -z "$WANDB_RUN" ]; then WANDB_RUN=dummy fi + +# Install Rust if it's not already installed. curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +# Add Rust's package manager, Cargo, to the PATH. source "$HOME/.cargo/env" +# Build the Rust-based BPE tokenizer. uv run maturin develop --release --manifest-path rustbpe/Cargo.toml + +# Download and set up the evaluation bundle if it's not already cached. EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL @@ -30,17 +56,23 @@ if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then mv eval_bundle $NANOCHAT_BASE_DIR fi -# wipe the report +# --- Training and Evaluation Pipeline --- +# Reset any previous reports to start fresh. python -m nanochat.report reset -# train tokenizer on ~1B characters +# --- Tokenizer Training --- +# Download and prepare the dataset for tokenizer training. 
python -m nanochat.dataset -n 4 +# Train the tokenizer on approximately 1 billion characters of text. python -m scripts.tok_train --max_chars=1000000000 +# Evaluate the trained tokenizer. python -m scripts.tok_eval -# train a very small 4 layer model on the CPU -# each optimization step processes a single sequence of 1024 tokens -# we only run 50 steps of optimization (bump this to get better results) +# --- Base Model Training --- +# Train a very small, 4-layer model on the CPU. +# Note: This is a minimal run for demonstration purposes. +# - Each optimization step processes a single sequence of 1024 tokens. +# - The run consists of only 50 optimization steps. python -m scripts.base_train \ --depth=4 \ --max_seq_len=1024 \ @@ -52,10 +84,13 @@ python -m scripts.base_train \ --core_metric_max_per_task=12 \ --sample_every=50 \ --num_iterations=50 +# Evaluate the loss of the base model. python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 +# Perform a more comprehensive evaluation of the base model. python -m scripts.base_eval --max-per-task=16 -# midtraining +# --- Mid-training --- +# Continue training the model on a mixture of tasks. python -m scripts.mid_train \ --max_seq_len=1024 \ --device_batch_size=1 \ @@ -63,11 +98,11 @@ python -m scripts.mid_train \ --eval_tokens=4096 \ --total_batch_size=1024 \ --num_iterations=100 -# eval results will be terrible, this is just to execute the code paths. -# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems +# Evaluate the mid-trained model. Results are expected to be poor due to the minimal training. python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 -# SFT +# --- Supervised Fine-Tuning (SFT) --- +# Fine-tune the model on a dataset of conversations. python -m scripts.chat_sft \ --device_batch_size=1 \ --target_examples_per_step=4 \ @@ -75,10 +110,15 @@ python -m scripts.chat_sft \ --eval_steps=4 \ --eval_metrics_max_problems=16 -# Chat CLI +# --- Interactive Chat (Optional) --- +# Uncomment the following lines to interact with the model via the command line or a web interface. +# +# # Command-Line Interface # python -m scripts.chat_cli -p "Why is the sky blue?" - -# Chat Web +# +# # Web-based Interface # python -m scripts.chat_web +# --- Reporting --- +# Generate a final report summarizing the run. python -m nanochat.report generate diff --git a/nanochat/adamw.py b/nanochat/adamw.py index db591de..7304a37 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -1,6 +1,14 @@ """ -Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. -Not a general optimizer! But works for our specific use. +This file implements a distributed version of the AdamW optimizer, tailored for nanochat. +AdamW is a variant of the Adam optimizer that decouples weight decay from the gradient update. +This distributed implementation is inspired by ZeRO-2, sharding optimizer states and gradients +across multiple devices to reduce memory consumption, enabling the training of larger models. + +A standard, non-distributed AdamW optimizer in PyTorch would be used as follows: +import torch +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) +optimizer.step() +This distributed version achieves a similar outcome but coordinates across multiple processes. """ import torch import torch.distributed as dist @@ -9,8 +17,19 @@ from torch import Tensor class DistAdamW(torch.optim.Optimizer): """ - Distributed AdamW optimizer. - In the style of ZeRO-2, i.e. 
sharded optimizer states and gradient reduction + Distributed AdamW optimizer in the style of ZeRO-2. + + This optimizer shards optimizer states (e.g., moments) and gradients across multiple + devices. During the `step`, it performs a `reduce_scatter` to average gradients, + updates its portion of the parameters, and then uses `all_gather` to ensure all + devices have the updated parameters. + + Args: + param_groups: An iterable of parameter groups to optimize. + lr (float): The learning rate. + betas (tuple[float, float]): Coefficients for running averages of gradient and its square. + eps (float): Term for numerical stability. + weight_decay (float): Weight decay coefficient. """ def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) @@ -19,11 +38,15 @@ class DistAdamW(torch.optim.Optimizer): @torch.compile @torch.no_grad() def step(self): + """Performs a single optimization step.""" rank = dist.get_rank() world_size = dist.get_world_size() reduce_scatter_futures: list[torch.Future] = [] all_reduce_futures: list[torch.Future] = [] grad_slices = [] + + # 1. Asynchronously reduce-scatter gradients. + # Each device receives a slice of the averaged gradient. for group in self.param_groups: params: list[Tensor] = group["params"] for base_i in range(len(params)): @@ -40,6 +63,7 @@ class DistAdamW(torch.optim.Optimizer): wd = group['weight_decay'] params = group['params'] for base in range(len(params)): + # 2. Wait for the gradient slice and update parameters. reduce_scatter_futures[idx].wait() p = params[base] rank_size = p.shape[0] // world_size @@ -47,30 +71,36 @@ class DistAdamW(torch.optim.Optimizer): lr = group['lr'] * getattr(p, "lr_mul", 1.0) state = self.state[p] g_slice = grad_slices[idx] - # State init + + # State initialization if not state: state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) state['exp_avg'] = torch.zeros_like(p_slice) state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] state['step'] += 1 t = state['step'] - # weight decay + + # Apply weight decay if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages + p_slice.mul_(1 - lr * wd * getattr(p, "wd_mul", 1.0)) + + # AdamW updates exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections + bias1 = 1 - beta1 ** t bias2 = 1 - beta2 ** t - # compute step + denom = exp_avg_sq.sqrt().add_(eps) step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) + + p_slice.addcdiv_(exp_avg, denom, value=-step_size) + idx += 1 + # 3. Asynchronously gather updated parameter slices. all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + + # 4. Wait for all gather operations to complete. torch.futures.collect_all(all_reduce_futures).wait() diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..f9302f1 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -1,5 +1,23 @@ """ -Utilities for saving and loading model/optim/state checkpoints. 
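To make the reduce-scatter / all-gather pattern behind `DistAdamW.step` concrete, here is a minimal sketch of a ZeRO-2 style update for a single parameter. It is illustrative only: plain SGD stands in for the AdamW math, `torch.distributed` is assumed to be initialized, and dim 0 of the parameter is assumed to be divisible by the world size.

```python
import torch
import torch.distributed as dist

def sharded_update_sketch(p: torch.Tensor, grad: torch.Tensor, lr: float = 1e-3):
    """Illustrative ZeRO-2 style step for one parameter (SGD stands in for AdamW)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    rank_size = p.shape[0] // world_size
    # 1) each rank receives the average of its own slice of the gradient
    g_slice = torch.empty_like(grad[:rank_size])
    dist.reduce_scatter_tensor(g_slice, grad, op=dist.ReduceOp.AVG)
    # 2) update only the local slice of the parameter
    p_slice = p[rank * rank_size:(rank + 1) * rank_size]
    p_slice.add_(g_slice, alpha=-lr)
    # 3) gather the updated slices so every rank holds the full parameter again
    dist.all_gather_into_tensor(p, p_slice)
```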
+This module provides utilities for saving and loading model, optimizer, and training state +checkpoints. It is essential for resuming training and for deploying models for inference. + +A typical use case involves: +1. Calling `save_checkpoint` periodically during training. +2. Calling `load_checkpoint` to resume training or for inference. +3. Using `build_model` to reconstruct a model from a checkpoint. + +Python equivalent for basic checkpointing: +import torch +# Saving +torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), +}, "checkpoint.pt") +# Loading +checkpoint = torch.load("checkpoint.pt") +model.load_state_dict(checkpoint['model_state_dict']) +optimizer.load_state_dict(checkpoint['optimizer_state_dict']) """ import os import re @@ -17,10 +35,21 @@ from nanochat.common import setup_default_logging setup_default_logging() logger = logging.getLogger(__name__) def log0(message): + """Logs a message only on the main process (rank 0).""" if int(os.environ.get('RANK', 0)) == 0: logger.info(message) def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): + """ + Saves a checkpoint to the specified directory. + + Args: + checkpoint_dir (str): The directory to save the checkpoint to. + step (int): The current training step. + model_data (dict): The model's state_dict. + optimizer_data (dict): The optimizer's state_dict. + meta_data (dict): A dictionary of metadata to save. + """ assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) @@ -40,6 +69,18 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data) def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): + """ + Loads a checkpoint from the specified directory. + + Args: + checkpoint_dir (str): The directory to load the checkpoint from. + step (int): The training step of the checkpoint to load. + device (str): The device to load the tensors onto. + load_optimizer (bool, optional): Whether to load the optimizer state. Defaults to False. + + Returns: + tuple: A tuple containing the model data, optimizer data, and metadata. + """ # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") model_data = torch.load(model_path, map_location=device) @@ -57,11 +98,16 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): def build_model(checkpoint_dir, step, device, phase): """ - A bunch of repetitive code to build a model from a given checkpoint. + Builds a model from a given checkpoint. + + Args: + checkpoint_dir (str): The directory of the checkpoint. + step (int): The training step of the checkpoint. + device (str): The device to build the model on. + phase (str): The phase, either "train" or "eval". + Returns: - - base model - uncompiled, not wrapped in DDP - - tokenizer - - meta data saved during base model training + tuple: A tuple containing the model, tokenizer, and metadata. """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) @@ -89,6 +135,15 @@ def build_model(checkpoint_dir, step, device, phase): def find_largest_model(checkpoint_dir): + """ + Finds the largest model in a checkpoint directory, assuming a "d" naming convention. + + Args: + checkpoint_dir (str): The directory to search for models. + + Returns: + str: The tag of the largest model found. 
+ """ # attempt to guess the model tag: take the biggest model available model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] if not model_tags: @@ -109,6 +164,15 @@ def find_largest_model(checkpoint_dir): def find_last_step(checkpoint_dir): + """ + Finds the last training step in a checkpoint directory. + + Args: + checkpoint_dir (str): The directory to search for checkpoints. + + Returns: + int: The last training step found. + """ # Look into checkpoint_dir and find model_.pt with the highest step checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) if not checkpoint_files: @@ -120,6 +184,19 @@ def find_last_step(checkpoint_dir): # convenience functions that take into account nanochat's directory structure def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): + """ + Loads a model from a directory, automatically detecting the model tag and step if not provided. + + Args: + checkpoints_dir (str): The directory containing model checkpoints. + device (str): The device to load the model on. + phase (str): The phase, either "train" or "eval". + model_tag (str, optional): The model tag to load. Defaults to None. + step (int, optional): The step to load. Defaults to None. + + Returns: + tuple: A tuple containing the model, tokenizer, and metadata. + """ if model_tag is None: # guess the model tag by defaulting to the largest model model_tag = find_largest_model(checkpoints_dir) @@ -135,6 +212,17 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non return model, tokenizer, meta_data def load_model(source, *args, **kwargs): + """ + Loads a model from a specific source directory within the nanochat project. + + Args: + source (str): The source of the model, one of "base", "mid", "sft", or "rl". + *args: Positional arguments to pass to `load_model_from_dir`. + **kwargs: Keyword arguments to pass to `load_model_from_dir`. + + Returns: + tuple: A tuple containing the model, tokenizer, and metadata. + """ model_dir = { "base": "base_checkpoints", "mid": "mid_checkpoints", diff --git a/nanochat/common.py b/nanochat/common.py index a5a6d2e..4291287 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -1,5 +1,7 @@ """ -Common utilities for nanochat. +This module provides common utilities for the nanochat project, including logging, +file handling, distributed training setup, and device management. These functions are +used across various scripts to ensure consistency and reduce code duplication. """ import os @@ -11,7 +13,10 @@ import torch import torch.distributed as dist class ColoredFormatter(logging.Formatter): - """Custom formatter that adds colors to log messages.""" + """ + A custom logging formatter that adds ANSI color codes to log messages for + improved readability in the console. + """ # ANSI color codes COLORS = { 'DEBUG': '\033[36m', # Cyan @@ -37,6 +42,7 @@ class ColoredFormatter(logging.Formatter): return message def setup_default_logging(): + """Sets up the default logging configuration with the colored formatter.""" handler = logging.StreamHandler() handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) logging.basicConfig( @@ -48,6 +54,11 @@ setup_default_logging() logger = logging.getLogger(__name__) def get_base_dir(): + """ + Returns the base directory for nanochat data, creating it if it doesn't exist. 
+ The directory is determined by the NANOCHAT_BASE_DIR environment variable, or + defaults to ~/.cache/nanochat. + """ # co-locate nanochat intermediates with other cached data in ~/.cache (by default) if os.environ.get("NANOCHAT_BASE_DIR"): nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR") @@ -60,8 +71,15 @@ def get_base_dir(): def download_file_with_lock(url, filename): """ - Downloads a file from a URL to a local path in the base directory. - Uses a lock file to prevent concurrent downloads among multiple ranks. + Downloads a file from a URL to a local path, using a lock file to prevent + concurrent downloads in a distributed setting. + + Args: + url (str): The URL of the file to download. + filename (str): The name of the file to save locally. + + Returns: + str: The path to the downloaded file. """ base_dir = get_base_dir() file_path = os.path.join(base_dir, filename) @@ -97,11 +115,13 @@ def download_file_with_lock(url, filename): return file_path def print0(s="",**kwargs): + """Prints a message only on the main process (rank 0).""" ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs) def print_banner(): + """Prints the nanochat ASCII art banner.""" # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ █████ █████ @@ -116,10 +136,12 @@ def print_banner(): print0(banner) def is_ddp(): + """Checks if the current process is running in a DDP environment.""" # TODO is there a proper way return int(os.environ.get('RANK', -1)) != -1 def get_dist_info(): + """Returns information about the DDP environment.""" if is_ddp(): assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) ddp_rank = int(os.environ['RANK']) @@ -130,6 +152,9 @@ def get_dist_info(): return False, 0, 0, 1 def autodetect_device_type(): + """ + Autodetects the best available device (CUDA, MPS, or CPU) and returns it. + """ # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU if torch.cuda.is_available(): device_type = "cuda" @@ -141,7 +166,16 @@ def autodetect_device_type(): return device_type def compute_init(device_type="cuda"): # cuda|cpu|mps - """Basic initialization that we keep doing over and over, so make common.""" + """ + Initializes the compute environment, including reproducibility settings, + precision, and DDP setup. + + Args: + device_type (str): The device type to use ("cuda", "mps", or "cpu"). + + Returns: + tuple: A tuple containing DDP info and the device. + """ assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" if device_type == "cuda": @@ -176,12 +210,15 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device def compute_cleanup(): - """Companion function to compute_init, to clean things up before script exit""" + """Cleans up the DDP process group.""" if is_ddp(): dist.destroy_process_group() class DummyWandb: - """Useful if we wish to not use wandb but have all the same signatures""" + """ + A dummy wandb class for environments where wandb is not used. It provides + the same method signatures as the real wandb object, but does nothing. + """ def __init__(self): pass def log(self, *args, **kwargs): diff --git a/nanochat/configurator.py b/nanochat/configurator.py index ec1b76d..e0bb4e4 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -1,17 +1,22 @@ """ -Poor Man's Configurator. Probably a terrible idea. 
Example usage: +This script provides a simple, unconventional configuration management system for nanochat. +It is not a standard Python module but is instead executed directly using `exec`, +allowing it to modify the global scope of the calling script. This design choice +prioritizes simplicity and avoids the complexity of more formal configuration +libraries. + +The configurator supports two types of overrides: +1. **Configuration files:** A Python file can be provided as a command-line + argument. The configurator will execute this file, which can be used to set + default configuration values. +2. **Command-line arguments:** Key-value pairs in the format `--key=value` can + be provided to override specific settings. The script attempts + to infer the correct data type for the value (e.g., int, float, bool). + +Example usage: $ python train.py config/override_file.py --batch_size=32 -this will first run config/override_file.py, then override batch_size to 32 -The code in this file will be run as follows from e.g. train.py: ->>> exec(open('configurator.py').read()) - -So it's not a Python module, it's just shuttling this code away from train.py -The code in this script then overrides the globals() - -I know people are not going to love this, I just really dislike configuration -complexity and having to prepend config. to every single variable. If someone -comes up with a better simple Python solution I am all ears. +A more conventional approach would use a library like `argparse` or `hydra`. """ import os @@ -19,13 +24,15 @@ import sys from ast import literal_eval def print0(s="",**kwargs): + """Prints a message only on the main process (rank 0).""" ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs) +# Parse command-line arguments for configuration for arg in sys.argv[1:]: if '=' not in arg: - # assume it's the name of a config file + # If the argument does not contain '=', it is assumed to be a config file. assert not arg.startswith('--') config_file = arg print0(f"Overriding config with {config_file}:") @@ -33,23 +40,23 @@ for arg in sys.argv[1:]: print0(f.read()) exec(open(config_file).read()) else: - # assume it's a --key=value argument + # If the argument contains '=', it is assumed to be a key-value override. assert arg.startswith('--') - key, val = arg.split('=') + key, val = arg.split('=', 1) key = key[2:] if key in globals(): try: - # attempt to eval it it (e.g. if bool, number, or etc) + # Attempt to evaluate the value to infer its type (e.g., int, bool). attempt = literal_eval(val) except (SyntaxError, ValueError): - # if that goes wrong, just use the string + # If evaluation fails, treat the value as a string. attempt = val - # ensure the types match ok + # Ensure that the overridden value has the same type as the default. if globals()[key] is not None: attempt_type = type(attempt) default_type = type(globals()[key]) - assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}" - # cross fingers + assert attempt_type == default_type, f"Type mismatch for key '{key}': expected {default_type}, got {attempt_type}" + # Update the global variable with the new value. print0(f"Overriding: {key} = {attempt}") globals()[key] = attempt else: diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index f3c9a9f..9fc1390 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -1,9 +1,14 @@ """ -Functions for evaluating the CORE metric, as described in the DCLM paper. 
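As a concrete illustration of the configurator mechanism described above, the calling side looks roughly like this (a sketch; the variable names and the path to `configurator.py` are examples, not the actual defaults):

```python
# train.py (sketch): plain module-level globals act as the configuration
batch_size = 16        # default, may be overridden
learning_rate = 3e-4   # default, may be overridden

# hand control to the configurator, which can overwrite the globals above
# from an optional config-file argument and/or --key=value flags
exec(open('nanochat/configurator.py').read())

print(f"running with batch_size={batch_size}, learning_rate={learning_rate}")
```

Invoked as `python train.py --batch_size=32`, the configurator parses the flag with `literal_eval` and replaces the global `batch_size` with the integer 32.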
-https://arxiv.org/abs/2406.11794 +This module implements the evaluation of the CORE (Comprehensive Overall Language Evaluation) +metric, as described in the DCLM paper (https://arxiv.org/abs/2406.11794). It provides +a standardized way to assess the performance of language models on a variety of tasks. -TODOs: -- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. +The evaluation process involves: +1. Rendering prompts for each task using Jinja2 templates. +2. Tokenizing the prompts and preparing them for batch processing. +3. Forwarding the inputs through the model to get predictions. +4. Calculating the accuracy for each task. +5. Aggregating the results to compute the final CORE score. """ import random @@ -15,7 +20,17 @@ import torch.distributed as dist # Prompt rendering utilities def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): - """Render complete prompts for a multiple choice question""" + """ + Renders prompts for a multiple-choice question. + + Args: + item (dict): The data item containing the query, choices, and gold answer. + continuation_delimiter (str): The delimiter to separate context and continuation. + fewshot_examples (list, optional): A list of few-shot examples. Defaults to None. + + Returns: + list: A list of rendered prompts, one for each choice. + """ template_str = """ {%- for example in fewshot_examples -%} {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }} @@ -34,7 +49,17 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): - """Render complete prompts for a schema question""" + """ + Renders prompts for a schema-based question. + + Args: + item (dict): The data item for the schema task. + continuation_delimiter (str): The delimiter. + fewshot_examples (list, optional): Few-shot examples. Defaults to None. + + Returns: + list: A list of rendered prompts. + """ template_str = """ {%- for example in fewshot_examples -%} {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }} @@ -55,9 +80,15 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): """ - Render complete prompt for a language modeling task. - Notice that we manually trim the context in the template, - which in some datasets seems to have trailing whitespace (which we don't want). + Renders prompts for a language modeling task. + + Args: + item (dict): The data item for the language modeling task. + continuation_delimiter (str): The delimiter. + fewshot_examples (list, optional): Few-shot examples. Defaults to None. + + Returns: + list: A list containing two prompts: one with and one without the continuation. """ template_str = """ {%- for example in fewshot_examples -%} @@ -85,8 +116,14 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): def find_common_length(token_sequences, direction='left'): """ - Find the length of the common prefix or suffix across token sequences - - direction: 'left' for prefix, 'right' for suffix + Finds the length of the common prefix or suffix in a list of token sequences. + + Args: + token_sequences (list): A list of token sequences. + direction (str): 'left' for prefix, 'right' for suffix. + + Returns: + int: The length of the common part. 
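A self-contained sketch of the common-prefix/suffix computation described here (illustrative; the in-repo helper may differ in details):

```python
def common_length(token_sequences, direction="left"):
    """Length of the shared prefix ('left') or suffix ('right') across sequences."""
    min_len = min(len(seq) for seq in token_sequences)
    positions = range(min_len) if direction == "left" else range(-1, -min_len - 1, -1)
    n = 0
    for pos in positions:
        if len({seq[pos] for seq in token_sequences}) != 1:
            break
        n += 1
    return n

assert common_length([[1, 2, 3, 9], [1, 2, 3, 7]]) == 3                # shared prefix
assert common_length([[5, 8, 9], [7, 8, 9]], direction="right") == 2   # shared suffix
```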
""" min_len = min(len(seq) for seq in token_sequences) indices = { @@ -102,7 +139,16 @@ def find_common_length(token_sequences, direction='left'): def stack_sequences(tokens, pad_token_id): - """Stack up a list of token sequences, pad to longest on the right""" + """ + Stacks a list of token sequences into a padded tensor. + + Args: + tokens (list): A list of token sequences. + pad_token_id (int): The ID of the padding token. + + Returns: + torch.Tensor: A padded tensor of token IDs. + """ bsz, seq_len = len(tokens), max(len(x) for x in tokens) input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) for i, x in enumerate(tokens): @@ -111,6 +157,7 @@ def stack_sequences(tokens, pad_token_id): def batch_sequences_mc(tokenizer, prompts): + """Prepares a batch of sequences for a multiple-choice task.""" # In multiple choice, contexts are the same but the continuation is different (common prefix) tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) # figure out the start and end of each continuation @@ -121,6 +168,7 @@ def batch_sequences_mc(tokenizer, prompts): def batch_sequences_schema(tokenizer, prompts): + """Prepares a batch of sequences for a schema task.""" # In schema tasks, contexts vary but continuation is the same (common suffix) tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) # figure out the start and end of each context @@ -131,6 +179,7 @@ def batch_sequences_schema(tokenizer, prompts): def batch_sequences_lm(tokenizer, prompts): + """Prepares a batch of sequences for a language modeling task.""" # In LM tasks, we have two prompts: without and with continuation tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) tokens_without, tokens_with = tokens @@ -144,8 +193,14 @@ def batch_sequences_lm(tokenizer, prompts): @torch.no_grad() def forward_model(model, input_ids): """ - Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. - The last column of losses is set to nan because we don't have autoregressive targets there. + Performs a forward pass through the model and computes losses and predictions. + + Args: + model (torch.nn.Module): The language model. + input_ids (torch.Tensor): The input token IDs. + + Returns: + tuple: A tuple containing the losses and predictions. """ batch_size, seq_len = input_ids.size() outputs = model(input_ids) @@ -166,7 +221,16 @@ def forward_model(model, input_ids): @torch.no_grad() def evaluate_example(idx, model, tokenizer, data, device, task_meta): - """Evaluate a single example, return True if correct, False otherwise""" + """ + Evaluates a single example from a task. + + Args: + idx (int): The index of the example in the dataset. + model, tokenizer, data, device, task_meta: Evaluation parameters. + + Returns: + bool: True if the model's prediction is correct, False otherwise. + """ item = data[idx] task_type = task_meta['task_type'] num_fewshot = task_meta['num_fewshot'] @@ -243,8 +307,13 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta): def evaluate_task(model, tokenizer, data, device, task_meta): """ - This function is responsible for evaluating one task across many examples. - It also handles dispatch to all processes if the script is run with torchrun. + Evaluates a task across all its examples, handling distributed evaluation. + + Args: + model, tokenizer, data, device, task_meta: Evaluation parameters. + + Returns: + float: The mean accuracy for the task. 
""" rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 6c864d3..e87bd4d 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -7,7 +7,25 @@ from nanochat.dataset import parquets_iter_batched from nanochat.tokenizer import get_tokenizer def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"): - """Stream pretraining text from parquet files, tokenize, yield training batches.""" + """ + Streams text from Parquet files, tokenizes it, and yields training batches. + + This data loader is designed for large-scale pretraining, where the entire dataset + cannot fit into memory. It streams data from disk, tokenizes it on the fly, and + yields batches of data indefinitely. It also supports distributed training by + sharding the data across multiple devices. + + Args: + B (int): The batch size. + T (int): The sequence length. + split (str): The data split to use, either "train" or "val". + tokenizer_threads (int, optional): The number of threads for tokenization. + tokenizer_batch_size (int, optional): The number of documents to tokenize at once. + device (str, optional): The device to move the batches to. + + Yields: + tuple: A tuple containing the input and target tensors. + """ assert split in ["train", "val"], "split must be 'train' or 'val'" ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() needed_tokens = B * T + 1 # +1 is because we also need the target at the last token diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..ae0834c 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -1,10 +1,11 @@ """ -The base/pretraining dataset is a set of parquet files. -This file contains utilities for: -- iterating over the parquet files and yielding documents from it -- download the files on demand if they are not on disk +This module provides utilities for managing the pretraining dataset, which consists of +a collection of Parquet files. It handles on-demand downloading of the dataset shards +from a remote URL and provides an iterator for streaming the data efficiently. -For details of how the dataset was prepared, see `repackage_data_reference.py`. +The script can also be run directly to download the dataset shards in parallel. + +For details on how the dataset was created, see `dev/repackage_data_reference.py`. """ import os @@ -31,7 +32,15 @@ os.makedirs(DATA_DIR, exist_ok=True) # These functions are useful utilities to other modules, can/should be imported def list_parquet_files(data_dir=None): - """ Looks into a data dir and returns full paths to all parquet files. """ + """ + Lists all Parquet files in a given directory. + + Args: + data_dir (str, optional): The directory to search. Defaults to DATA_DIR. + + Returns: + list: A sorted list of full paths to the Parquet files. + """ data_dir = DATA_DIR if data_dir is None else data_dir parquet_files = sorted([ f for f in os.listdir(data_dir) @@ -42,9 +51,16 @@ def list_parquet_files(data_dir=None): def parquets_iter_batched(split, start=0, step=1): """ - Iterate through the dataset, in batches of underlying row_groups for efficiency. - - split can be "train" or "val". the last parquet file will be val. - - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size + Iterates through the dataset in batches of row groups for efficiency. + + Args: + split (str): "train" or "val". 
The last Parquet file is used for validation. + start (int, optional): The starting row group index. Defaults to 0. + step (int, optional): The step size for iterating through row groups. + Useful for distributed training. Defaults to 1. + + Yields: + list: A list of texts from a row group. """ assert split in ["train", "val"], "split must be 'train' or 'val'" parquet_paths = list_parquet_files() @@ -58,7 +74,15 @@ def parquets_iter_batched(split, start=0, step=1): # ----------------------------------------------------------------------------- def download_single_file(index): - """ Downloads a single file index, with some backoff """ + """ + Downloads a single dataset shard with retries and exponential backoff. + + Args: + index (int): The index of the shard to download. + + Returns: + bool: True if the download was successful, False otherwise. + """ # Construct the local filepath for this file and skip if it already exists filename = index_to_filename(index) diff --git a/nanochat/engine.py b/nanochat/engine.py index 44ed16b..de7306f 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -1,14 +1,16 @@ """ -Engine for efficient inference of our models. +This module provides an efficient inference engine for nanochat models. It is designed +to be fast and memory-efficient, using a key-value (KV) cache to avoid recomputing +attention for previous tokens. The engine operates on token ID sequences and is +agnostic to the tokenization process. -Everything works around token sequences: -- The user can send token sequences to the engine -- The engine returns the next token +The main components are: +- `Engine`: The main class that orchestrates the generation process. +- `KVCache`: A helper class that manages the KV cache for the GPT model. +- `RowState`: A class to track the state of each individual generation sequence. -Notes: -- The engine knows nothing about tokenization, it's purely token id sequences. - -The whole thing is made as efficient as possible. +The engine also includes a safe `use_calculator` function for evaluating Python +expressions, which is used as a tool by the model. """ import torch @@ -24,6 +26,7 @@ from nanochat.checkpoint_manager import load_model # Calculator tool helpers @contextmanager def timeout(duration, formula): + """A context manager to enforce a timeout on a block of code.""" def timeout_handler(signum, frame): raise Exception(f"'{formula}': timed out after {duration} seconds") @@ -33,6 +36,7 @@ def timeout(duration, formula): signal.alarm(0) def eval_with_timeout(formula, max_time=3): + """Evaluates a Python expression with a timeout.""" try: with timeout(max_time, formula): with warnings.catch_warnings(): @@ -45,8 +49,13 @@ def eval_with_timeout(formula, max_time=3): def use_calculator(expr): """ - Evaluate a Python expression safely. - Supports both math expressions and string operations like .count() + Safely evaluates a Python expression. Supports both math and string operations. + + Args: + expr (str): The expression to evaluate. + + Returns: + The result of the evaluation, or None if it fails or is unsafe. """ # Remove commas from numbers expr = expr.replace(",", "") @@ -81,8 +90,14 @@ def use_calculator(expr): # ----------------------------------------------------------------------------- class KVCache: """ - Works hand-in-hand with the GPT model to maintain the KV cache. - Note that the .pos advances automatically after the last layer of the Transformer inserts. + Manages the Key-Value cache for the GPT model's attention layers. 
+ + Args: + batch_size (int): The batch size. + num_heads (int): The number of attention heads. + seq_len (int): The sequence length. + head_dim (int): The dimension of each attention head. + num_layers (int): The number of layers in the Transformer. """ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers): @@ -92,16 +107,18 @@ class KVCache: self.pos = 0 # current position in time in the cache def reset(self): + """Resets the cache position.""" self.pos = 0 def get_pos(self): + """Returns the current cache position.""" return self.pos def prefill(self, other): """ - Prefill given another KV cache. Optionally expand along batch dim. - This is used when we do batch 1 prefill and then want to generate - multiple samples in parallel from there. + Prefills the cache with the contents of another KV cache. This is useful + for batch generation, where a single prompt is used to generate multiple + samples. """ # 1) validate the shapes assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" @@ -125,6 +142,7 @@ class KVCache: self.pos = other.pos def insert_kv(self, layer_idx, k, v): + """Inserts a key-value pair into the cache.""" # Lazy initialize the cache here because we need to know the dtype/device if self.kv_cache is None: self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device) @@ -155,7 +173,18 @@ class KVCache: # ----------------------------------------------------------------------------- @torch.inference_mode() def sample_next_token(logits, rng, temperature=1.0, top_k=None): - """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1).""" + """ + Samples the next token from the model's logits. + + Args: + logits (torch.Tensor): The logits from the model. + rng (torch.Generator): A random number generator for reproducibility. + temperature (float, optional): The sampling temperature. Defaults to 1.0. + top_k (int, optional): The number of top-k tokens to consider. Defaults to None. + + Returns: + torch.Tensor: The sampled token IDs. + """ assert temperature >= 0.0, "temperature must be non-negative" if temperature == 0.0: return torch.argmax(logits, dim=-1, keepdim=True) @@ -174,6 +203,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None): # ----------------------------------------------------------------------------- class RowState: + """Tracks the state of a single row during generation.""" # Per-row state tracking during generation def __init__(self, current_tokens=None): self.current_tokens = current_tokens or [] # Current token sequence for this row @@ -183,6 +213,13 @@ class RowState: self.completed = False # Whether this row has completed generation class Engine: + """ + The main inference engine. + + Args: + model (torch.nn.Module): The GPT model. + tokenizer: The tokenizer. + """ def __init__(self, model, tokenizer): self.model = model @@ -190,7 +227,13 @@ class Engine: @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): - """Same as generate, but does single prefill and then clones the KV cache.""" + """ + Generates token sequences in a streaming fashion. + + Yields: + tuple: A tuple containing the token column and a mask indicating + whether the tokens were sampled or forced. 
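For reference, the temperature and top-k behavior documented for `sample_next_token` can be sketched in standalone form (a hedged approximation; the engine's actual implementation may differ slightly):

```python
import torch

def sample_next(logits, temperature=1.0, top_k=None, generator=None):
    """Sample one next-token id per row from (B, vocab_size) logits."""
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)   # greedy decoding
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)            # keep only the k best logits
        probs = torch.softmax(vals / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=generator)
        return idx.gather(-1, choice)                        # map back to vocabulary ids
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)
```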
+ """ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints" device = self.model.get_device() rng = torch.Generator(device=device) @@ -296,9 +339,7 @@ class Engine: def generate_batch(self, tokens, num_samples=1, **kwargs): """ - Non-streaming batch generation that just returns the final token sequences. - Returns a list of token sequences (list of lists of ints). - Terminal tokens (assistant_end, bos) are not included in the results. + A non-streaming version of `generate` that returns the final token sequences. """ assistant_end = self.tokenizer.encode_special("<|assistant_end|>") bos = self.tokenizer.get_bos_token_id() diff --git a/nanochat/execution.py b/nanochat/execution.py index 6f50c74..82fee93 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -1,24 +1,17 @@ """ -Sandboxed execution utilities for running Python code that comes out of an LLM. -Adapted from OpenAI HumanEval code: -https://github.com/openai/human-eval/blob/master/human_eval/execution.py +This module provides utilities for executing Python code generated by a language model in a +sandboxed environment. It is adapted from the OpenAI HumanEval project and aims to provide +a safe way to run untrusted code. -What is covered: -- Each execution runs in its own process (can be killed if it hangs or crashes) -- Execution is limited by a timeout to stop infinite loops -- Memory limits are enforced by default (256MB) -- stdout and stderr are captured and returned -- Code runs in a temporary directory that is deleted afterwards -- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen) +The sandbox provides the following protections: +- **Process Isolation:** Each execution runs in a separate process. +- **Timeout:** A time limit is enforced to prevent infinite loops. +- **Memory Limits:** The memory usage of the executed code is restricted. +- **Filesystem Isolation:** Code is executed in a temporary directory. +- **Function Disabling:** Potentially dangerous functions are disabled. -What is not covered: -- Not a true security sandbox -- Network access is not blocked (e.g. sockets could be opened) -- Python's dynamic features (e.g. ctypes) could bypass restrictions -- No kernel-level isolation (no seccomp, no containers, no virtualization) - -Overall this sandbox is good for evaluation of generated code and protects against -accidental destructive behavior, but it is not safe against malicious adversarial code. +**Disclaimer:** This is not a true security sandbox and should not be used for +running malicious or untrusted code in a production environment. """ import contextlib @@ -36,7 +29,17 @@ from typing import Optional @dataclass class ExecutionResult: - """Result of executing Python code in a sandbox.""" + """ + Represents the result of a sandboxed code execution. + + Attributes: + success (bool): Whether the execution was successful. + stdout (str): The captured standard output. + stderr (str): The captured standard error. + error (Optional[str]): Any error message, if an exception occurred. + timeout (bool): Whether the execution timed out. + memory_exceeded (bool): Whether the execution exceeded the memory limit. 
+ """ success: bool stdout: str stderr: str @@ -63,6 +66,7 @@ class ExecutionResult: @contextlib.contextmanager def time_limit(seconds: float): + """A context manager to enforce a time limit on a block of code.""" def signal_handler(signum, frame): raise TimeoutException("Timed out!") @@ -76,7 +80,7 @@ def time_limit(seconds: float): @contextlib.contextmanager def capture_io(): - """Capture stdout and stderr, and disable stdin.""" + """Captures stdout and stderr, and blocks stdin.""" stdout_capture = io.StringIO() stderr_capture = io.StringIO() stdin_block = WriteOnlyStringIO() @@ -88,6 +92,7 @@ def capture_io(): @contextlib.contextmanager def create_tempdir(): + """Creates a temporary directory and changes the current working directory to it.""" with tempfile.TemporaryDirectory() as dirname: with chdir(dirname): yield dirname @@ -98,7 +103,7 @@ class TimeoutException(Exception): class WriteOnlyStringIO(io.StringIO): - """StringIO that throws an exception when it's read from""" + """A StringIO that raises an IOError when read from.""" def read(self, *args, **kwargs): raise IOError @@ -120,6 +125,7 @@ class redirect_stdin(contextlib._RedirectStream): # type: ignore @contextlib.contextmanager def chdir(root): + """A context manager to temporarily change the current directory.""" if root == ".": yield return @@ -133,15 +139,8 @@ def chdir(root): def reliability_guard(maximum_memory_bytes: Optional[int] = None): """ - This disables various destructive functions and prevents the generated code - from interfering with the test (e.g. fork bomb, killing other processes, - removing filesystem files, etc.) - - WARNING - This function is NOT a security sandbox. Untrusted code, including, model- - generated code, should not be blindly executed outside of one. See the - Codex paper for more information about OpenAI's code sandbox, and proceed - with caution. + Disables dangerous functions and sets resource limits to protect against + accidental destructive behavior. """ if platform.uname().system != "Darwin": @@ -212,7 +211,10 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict): - """Execute code in a subprocess with safety guards. Results are written to result_dict.""" + """ + The target function for the subprocess. It executes the code in a sandboxed + environment and stores the result in a shared dictionary. + """ with create_tempdir(): # These system calls are needed when cleaning up tempdir. @@ -289,22 +291,16 @@ def execute_code( maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default ) -> ExecutionResult: """ - Execute Python code in a sandboxed environment. + Executes Python code in a sandboxed environment using a separate process. Args: - code: Python code to execute as a string - timeout: Maximum execution time in seconds (default: 5.0) - maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable) + code (str): The Python code to execute. + timeout (float, optional): The maximum execution time in seconds. Defaults to 5.0. + maximum_memory_bytes (Optional[int], optional): The memory limit in bytes. + Defaults to 256MB. Returns: - ExecutionResult with success status, stdout/stderr, and error information - - Example: - >>> result = execute_code("print('hello world')") - >>> result.success - True - >>> result.stdout - 'hello world\\n' + ExecutionResult: An object containing the results of the execution. 
""" manager = multiprocessing.Manager() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..5c54e41 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -1,14 +1,11 @@ """ -GPT model (rewrite, a lot simpler) -Notable features: -- rotary embeddings (and no positional embeddings) -- QK norm -- untied weights for token embedding and lm_head -- relu^2 activation in MLP -- norm after token embedding -- no learnable params in rmsnorm -- no bias in linear layers -- Multi-Query Attention (MQA) support for more efficient inference +This module implements the GPT (Generative Pre-trained Transformer) model for nanochat. +It features several modern architectural choices for improved performance and efficiency: +- Rotary Positional Embeddings (RoPE) +- QK Norm for attention stabilization +- SwiGLU activation in the MLP +- RMSNorm for normalization +- Multi-Query Attention (MQA) for efficient inference """ import math @@ -25,6 +22,14 @@ from nanochat.adamw import DistAdamW @dataclass class GPTConfig: + """ + Configuration for the GPT model. + + Attributes: + sequence_len (int): The maximum sequence length. + vocab_size (int): The size of the vocabulary. + n_layer (int): The number of transformer layers. + """ sequence_len: int = 1024 vocab_size: int = 50304 n_layer: int = 12 @@ -34,11 +39,13 @@ class GPTConfig: def norm(x): + """A functional RMSNorm without learnable parameters.""" # Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),)) def apply_rotary_emb(x, cos, sin): + """Applies rotary positional embeddings to the input tensor.""" assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves @@ -49,6 +56,7 @@ def apply_rotary_emb(x, cos, sin): return out class CausalSelfAttention(nn.Module): + """The causal self-attention mechanism.""" def __init__(self, config, layer_idx): super().__init__() self.layer_idx = layer_idx @@ -111,6 +119,7 @@ class CausalSelfAttention(nn.Module): class MLP(nn.Module): + """The multi-layer perceptron block.""" def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) @@ -124,6 +133,7 @@ class MLP(nn.Module): class Block(nn.Module): + """A single transformer block.""" def __init__(self, config, layer_idx): super().__init__() self.attn = CausalSelfAttention(config, layer_idx) @@ -136,6 +146,7 @@ class Block(nn.Module): class GPT(nn.Module): + """The GPT model.""" def __init__(self, config): super().__init__() self.config = config @@ -155,6 +166,7 @@ class GPT(nn.Module): self.register_buffer("sin", sin, persistent=False) def init_weights(self): + """Initializes the model weights.""" self.apply(self._init_weights) # zero out classifier weights torch.nn.init.zeros_(self.lm_head.weight) @@ -211,6 +223,7 @@ class GPT(nn.Module): return num_flops_per_token def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): + """Sets up the optimizers for the model.""" model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # Separate out all parameters into 3 groups (matrix, embedding, lm_head) @@ -242,6 +255,7 @@ class GPT(nn.Module): return optimizers def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): + """The forward pass of the model.""" B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) @@ -277,12 +291,7 @@ class GPT(nn.Module): 
@torch.inference_mode() def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): - """ - Naive autoregressive streaming inference. - To make it super simple, let's assume: - - batch size is 1 - - ids and the yielded tokens are simple Python lists and ints - """ + """A naive, streaming inference implementation.""" assert isinstance(tokens, list) device = self.get_device() rng = None diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 6fcbea3..cf2f634 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -1,5 +1,7 @@ """ -A number of functions that help with evaluating a base model. +This module provides functions for evaluating a base language model, with a focus on +the bits-per-byte (BPB) metric. BPB is a more robust alternative to cross-entropy +loss, as it is independent of the tokenizer's vocabulary size. """ import math import torch @@ -8,21 +10,18 @@ import torch.distributed as dist @torch.no_grad() def evaluate_bpb(model, batches, steps, token_bytes): """ - Instead of the naive 'mean loss', this function returns the bits per byte (bpb), - which is a tokenization vocab size-indepedent metric, meaning you are still comparing - apples:apples if you change the vocab size. The way this works is that instead of just - calculating the average loss as usual, you calculate the sum loss, and indepependently - also the sum bytes (of all the target tokens), and divide. This normalizes the loss by - the number of bytes that the target tokens represent. + Evaluates the model's performance using the bits-per-byte (BPB) metric. - The added complexity is so that: - 1) All "normal" tokens are normalized by the length of the token in bytes - 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. - 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. + Args: + model (torch.nn.Module): The language model to evaluate. + batches (iterable): An iterator that yields batches of (x, y) tensors. + steps (int): The number of evaluation steps to perform. + token_bytes (torch.Tensor): a 1D tensor of shape (vocab_size,) where each + entry is the number of bytes for the corresponding token ID, or 0 if + the token should be ignored (e.g., special tokens). - In addition to evaluate_loss, we need the token_bytes tensor: - It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for - each token id, or 0 if the token is to not be counted (e.g. special tokens). + Returns: + float: The calculated bits-per-byte value. """ # record the losses total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device()) diff --git a/nanochat/muon.py b/nanochat/muon.py index d916103..52c302f 100644 --- a/nanochat/muon.py +++ b/nanochat/muon.py @@ -1,6 +1,15 @@ """ -Muon optimizer from Keller et al. -Also a lot of borrowing of ideas from modded-nanogpt. +This module implements the Muon optimizer, a novel optimization algorithm that combines +SGD with momentum and an orthogonalization step. The key idea is to replace the +gradient update with the nearest orthogonal matrix, which can help to stabilize +training and improve performance, especially for matrix-like parameters. + +The orthogonalization is performed efficiently using the Newton-Schulz iteration. +The module provides both a standard `Muon` optimizer and a `DistMuon` version for +distributed training. 
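+
+Example (a minimal sketch; only 2D "matrix" parameters should be given to Muon,
+as noted in the `Muon` docstring below, and `compute_loss` is a hypothetical stand-in
+for whatever produces the training loss):
+
+    matrix_params = [p for p in model.parameters() if p.ndim == 2]
+    optimizer = Muon(matrix_params, lr=0.02, momentum=0.95)
+    loss = compute_loss(model, batch)
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()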
+ +**Reference:** +- Muon Optimizer: https://kellerjordan.github.io/posts/muon/ """ import torch from torch import Tensor @@ -9,13 +18,15 @@ import torch.distributed as dist @torch.compile def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Computes the zeroth power of a matrix G using a quintic Newton-Schulz iteration, + which effectively orthogonalizes the matrix. + + Args: + G (Tensor): The input matrix. + steps (int): The number of Newton-Schulz iterations. + + Returns: + Tensor: The orthogonalized matrix. """ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) @@ -37,25 +48,17 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: class Muon(torch.optim.Optimizer): """ - Muon - MomentUm Orthogonalized by Newton-schulz + The Muon optimizer (Momentum Orthogonalized by Newton-Schulz). - https://kellerjordan.github.io/posts/muon/ + This optimizer combines SGD with momentum and an orthogonalization step based + on the Newton-Schulz iteration. It is designed for 2D parameters. - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - - Arguments: - lr: The learning rate used by the internal SGD. - momentum: The momentum used by the internal SGD. - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iteration steps to use. + Args: + params: An iterable of parameters to optimize. + lr (float, optional): The learning rate. Defaults to 0.02. + momentum (float, optional): The momentum factor. Defaults to 0.95. + nesterov (bool, optional): Whether to use Nesterov momentum. Defaults to True. + ns_steps (int, optional): The number of Newton-Schulz steps. Defaults to 5. """ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) @@ -85,24 +88,10 @@ class Muon(torch.optim.Optimizer): class DistMuon(torch.optim.Optimizer): """ - Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz, - finally apply aspect-ratio scaled step. 
Performs its own distributed synchronization: - - reduce_scatter(AVG) for gradient averaging - - all_gather to replicate updated weights + A distributed version of the Muon optimizer. - Notes: - * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D - params like embeddings or scalars. - * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen - by block-cyclic assignment below). If you checkpoint optimizer state on a single rank, - consolidate states beforehand. - - Args: - params: iterable of Tensors - lr: learning rate - momentum: momentum coefficient in [0,1) - nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf - ns_steps: number of Newton–Schulz iterations for the orthogonalization + This optimizer handles its own distributed synchronization using `reduce_scatter` + for gradients and `all_gather` for updated weights. """ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5): diff --git a/nanochat/report.py b/nanochat/report.py index d0a65e0..e13e26b 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -1,5 +1,14 @@ """ -Utilities for generating training report cards. More messy code than usual, will fix. +This module provides utilities for generating comprehensive training reports for +nanochat. The reports are designed to be "report cards" for a training run, +capturing key information about the environment, hardware, software, and model +performance at various stages. + +The main components are: +- `Report`: A class that manages the logging of different sections of the report + and generates a final Markdown report. +- Helper functions to gather system information, such as Git status, GPU details, + and package dependencies. 
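+
+Example (a rough sketch of the intended flow; see `get_report` below):
+
+    report = get_report()  # a Report on rank 0, a DummyReport on all other ranks
+    report.log(section="Base model training", data=["some note", {"step": 100, "loss": 1.23}])
+    report.generate()      # combines all logged sections into a single report.md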
""" import os @@ -13,7 +22,7 @@ import psutil import torch def run_command(cmd): - """Run a shell command and return output, or None if it fails.""" + """Executes a shell command and returns its output.""" try: result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5) if result.returncode == 0: @@ -23,7 +32,7 @@ def run_command(cmd): return None def get_git_info(): - """Get current git commit, branch, and dirty status.""" + """Gathers information about the current Git repository.""" info = {} info['commit'] = run_command("git rev-parse --short HEAD") or "unknown" info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown" @@ -39,7 +48,7 @@ def get_git_info(): return info def get_gpu_info(): - """Get GPU information.""" + """Gathers information about the available GPUs.""" if not torch.cuda.is_available(): return {"available": False} @@ -62,7 +71,7 @@ def get_gpu_info(): return info def get_system_info(): - """Get system information.""" + """Gathers general system information.""" info = {} # Basic system info @@ -84,7 +93,7 @@ def get_system_info(): return info def estimate_cost(gpu_info, runtime_hours=None): - """Estimate training cost based on GPU type and runtime.""" + """Estimates the training cost based on GPU type and runtime.""" # Rough pricing, from Lambda Cloud default_rate = 2.0 @@ -115,7 +124,7 @@ def estimate_cost(gpu_info, runtime_hours=None): } def generate_header(): - """Generate the header for a training report.""" + """Generates the header section of the report.""" timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") git_info = get_git_info() @@ -187,7 +196,7 @@ Generated: {timestamp} # ----------------------------------------------------------------------------- def slugify(text): - """Slugify a text string.""" + """Converts a string into a URL-friendly slug.""" return text.lower().replace(" ", "-") # the expected files and their order @@ -208,7 +217,7 @@ EXPECTED_FILES = [ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"] def extract(section, keys): - """simple def to extract a single key from a section""" + """Extracts values for the given keys from a section of text.""" if not isinstance(keys, list): keys = [keys] # convenience out = {} @@ -219,7 +228,7 @@ def extract(section, keys): return out def extract_timestamp(content, prefix): - """Extract timestamp from content with given prefix.""" + """Extracts a timestamp from a string with a given prefix.""" for line in content.split('\n'): if line.startswith(prefix): time_str = line.split(":", 1)[1].strip() @@ -230,14 +239,25 @@ def extract_timestamp(content, prefix): return None class Report: - """Maintains a bunch of logs, generates a final markdown report.""" + """ + Manages the creation and generation of a training report. + + Args: + report_dir (str): The directory to store report files. + """ def __init__(self, report_dir): os.makedirs(report_dir, exist_ok=True) self.report_dir = report_dir def log(self, section, data): - """Log a section of data to the report.""" + """ + Logs a section of data to a Markdown file in the report directory. + + Args: + section (str): The title of the section (e.g., "Base model training"). + data (list): A list of items to log. Items can be strings or dicts. 
+ """ slug = slugify(section) file_name = f"{slug}.md" file_path = os.path.join(self.report_dir, file_name) @@ -265,7 +285,11 @@ class Report: return file_path def generate(self): - """Generate the final report.""" + """ + Generates the final consolidated report by combining all the logged + Markdown files into a single `report.md` file. It also creates a + summary table of the key metrics. + """ report_dir = self.report_dir report_file = os.path.join(report_dir, "report.md") print(f"Generating report to {report_file}") @@ -359,7 +383,11 @@ class Report: return report_file def reset(self): - """Reset the report.""" + """ + Resets the report directory by deleting all the section files and the + main `report.md` file. It then creates a new header file with the + current environment information. + """ # Remove section files for file_name in EXPECTED_FILES: file_path = os.path.join(self.report_dir, file_name) @@ -382,12 +410,22 @@ class Report: # nanochat-specific convenience functions class DummyReport: + """ + A dummy report class that does nothing, for use on non-rank-0 processes in a + distributed setting. This prevents processes other than the master from + writing to the report files. + """ def log(self, *args, **kwargs): pass def reset(self, *args, **kwargs): pass def get_report(): + """ + Returns a `Report` instance on the master process (rank 0) and a `DummyReport` + instance on all other processes. This ensures that only the master process + handles report generation. + """ # just for convenience, only rank 0 logs to report from nanochat.common import get_base_dir, get_dist_info ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 880f854..97d3333 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -1,9 +1,13 @@ """ -BPE Tokenizer in the style of GPT-4. +This module provides Byte-Pair Encoding (BPE) tokenization in the style of GPT-4. +It offers two implementations: +1. **HuggingFace Tokenizer:** A wrapper around the `tokenizers` library. +2. **RustBPE + Tiktoken:** A combination of a custom `rustbpe` tokenizer for + training and `tiktoken` for efficient inference. This is the default and + recommended implementation for nanochat. -Two implementations are available: -1) HuggingFace Tokenizer that can do both training and inference but is really confusing -2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference +The tokenizer handles the conversion of text to token IDs and back, as well as +special tokens for structuring conversations. """ import os @@ -37,19 +41,24 @@ from tokenizers.models import BPE from tokenizers.trainers import BpeTrainer class HuggingFaceTokenizer: - """Light wrapper around HuggingFace Tokenizer for some utilities""" + """ + A wrapper around the Hugging Face `tokenizers` library, providing a consistent + interface for training and using BPE tokenizers. + """ def __init__(self, tokenizer): self.tokenizer = tokenizer @classmethod def from_pretrained(cls, hf_path): + """Loads a tokenizer from a pre-trained Hugging Face model.""" # init from a HuggingFace pretrained tokenizer (e.g. "gpt2") tokenizer = HFTokenizer.from_pretrained(hf_path) return cls(tokenizer) @classmethod def from_directory(cls, tokenizer_dir): + """Loads a tokenizer from a local directory.""" # init from a local directory on disk (e.g. 
"out/tokenizer") tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") tokenizer = HFTokenizer.from_file(tokenizer_path) @@ -57,6 +66,7 @@ class HuggingFaceTokenizer: @classmethod def train_from_iterator(cls, text_iterator, vocab_size): + """Trains a new tokenizer from an iterator of text.""" # train from an iterator of text # Configure the HuggingFace Tokenizer tokenizer = HFTokenizer(BPE( @@ -93,18 +103,21 @@ class HuggingFaceTokenizer: return cls(tokenizer) def get_vocab_size(self): + """Returns the size of the vocabulary.""" return self.tokenizer.get_vocab_size() def get_special_tokens(self): + """Returns a list of special tokens.""" special_tokens_map = self.tokenizer.get_added_tokens_decoder() special_tokens = [w.content for w in special_tokens_map.values()] return special_tokens def id_to_token(self, id): + """Converts a token ID to its string representation.""" return self.tokenizer.id_to_token(id) def _encode_one(self, text, prepend=None, append=None): - # encode a single string + """Encodes a single string.""" # prepend/append can be either a string of a special token or a token id directly. assert isinstance(text, str) ids = [] @@ -118,14 +131,17 @@ class HuggingFaceTokenizer: return ids def encode_special(self, text): + """Encodes a single special token.""" # encode a single special token via exact match return self.tokenizer.token_to_id(text) def get_bos_token_id(self): + """Returns the ID of the beginning-of-sequence token.""" bos = self.encode_special("<|bos|>") return bos def encode(self, text, *args, **kwargs): + """Encodes a string or a list of strings.""" if isinstance(text, str): return self._encode_one(text, *args, **kwargs) elif isinstance(text, list): @@ -134,12 +150,15 @@ class HuggingFaceTokenizer: raise ValueError(f"Invalid input type: {type(text)}") def __call__(self, *args, **kwargs): + """A convenience method to call `encode`.""" return self.encode(*args, **kwargs) def decode(self, ids): + """Decodes a sequence of token IDs into a string.""" return self.tokenizer.decode(ids, skip_special_tokens=False) def save(self, tokenizer_dir): + """Saves the tokenizer to a directory.""" # save the tokenizer to disk os.makedirs(tokenizer_dir, exist_ok=True) tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") @@ -153,7 +172,10 @@ import rustbpe import tiktoken class RustBPETokenizer: - """Light wrapper around tiktoken (for efficient inference) but train with rustbpe""" + """ + A tokenizer that uses `rustbpe` for training and `tiktoken` for efficient + inference. This is the default tokenizer for nanochat. 
+ """ def __init__(self, enc, bos_token): self.enc = enc @@ -161,6 +183,7 @@ class RustBPETokenizer: @classmethod def train_from_iterator(cls, text_iterator, vocab_size): + """Trains a new tokenizer from an iterator of text.""" # 1) train using rustbpe tokenizer = rustbpe.Tokenizer() # the special tokens are inserted later in __init__, we don't train them here @@ -183,6 +206,7 @@ class RustBPETokenizer: @classmethod def from_directory(cls, tokenizer_dir): + """Loads a tokenizer from a local directory.""" pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl") with open(pickle_path, "rb") as f: enc = pickle.load(f) @@ -190,6 +214,7 @@ class RustBPETokenizer: @classmethod def from_pretrained(cls, tiktoken_name): + """Loads a tokenizer from a pre-trained tiktoken model.""" # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py enc = tiktoken.get_encoding(tiktoken_name) # tiktoken calls the special document delimiter token "<|endoftext|>" @@ -199,22 +224,39 @@ class RustBPETokenizer: return cls(enc, "<|endoftext|>") def get_vocab_size(self): + """Returns the size of the vocabulary.""" return self.enc.n_vocab def get_special_tokens(self): + """Returns a set of special tokens.""" return self.enc.special_tokens_set def id_to_token(self, id): + """Converts a token ID to its string representation.""" return self.enc.decode([id]) @lru_cache(maxsize=32) def encode_special(self, text): + """Encodes a single special token.""" return self.enc.encode_single_token(text) def get_bos_token_id(self): + """Returns the ID of the beginning-of-sequence token.""" return self.bos_token_id def encode(self, text, prepend=None, append=None, num_threads=8): + """ + Encodes a string or a list of strings. + + Args: + text (str or list[str]): The text to encode. + prepend (int or str, optional): A token to prepend to the sequence. + append (int or str, optional): A token to append to the sequence. + num_threads (int, optional): The number of threads for batch encoding. + + Returns: + list[int] or list[list[int]]: The encoded token IDs. + """ # text can be either a string or a list of strings if prepend is not None: @@ -242,12 +284,15 @@ class RustBPETokenizer: return ids def __call__(self, *args, **kwargs): + """A convenience method to call `encode`.""" return self.encode(*args, **kwargs) def decode(self, ids): + """Decodes a sequence of token IDs into a string.""" return self.enc.decode(ids) def save(self, tokenizer_dir): + """Saves the tokenizer to a directory.""" # save the encoding object to disk os.makedirs(tokenizer_dir, exist_ok=True) pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl") @@ -257,10 +302,8 @@ class RustBPETokenizer: def render_conversation(self, conversation, max_tokens=2048): """ - Tokenize a single Chat conversation (which we call a "doc" or "document" here). - Returns: - - ids: list[int] is a list of token ids of this rendered conversation - - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on. + Renders a conversation into a sequence of token IDs and a mask for training. + The mask indicates which tokens the assistant should be trained on. """ # ids, masks that we will return and a helper function to help build them up. 
ids, mask = [], [] @@ -342,7 +385,10 @@ class RustBPETokenizer: return ids, mask def visualize_tokenization(self, ids, mask, with_token_id=False): - """Small helper function useful in debugging: visualize the tokenization of render_conversation""" + """ + A helper function for visualizing the tokenization of a conversation, + with colors indicating the training mask. + """ RED = '\033[91m' GREEN = '\033[92m' RESET = '\033[0m' @@ -358,9 +404,8 @@ class RustBPETokenizer: def render_for_completion(self, conversation): """ - Used during Reinforcement Learning. In that setting, we want to - render the conversation priming the Assistant for a completion. - Unlike the Chat SFT case, we don't need to return the mask. + Renders a conversation for completion, priming the assistant to generate a + response. This is used during reinforcement learning. """ # We have some surgery to do: we need to pop the last message (of the Assistant) conversation = copy.deepcopy(conversation) # avoid mutating the original @@ -380,6 +425,7 @@ class RustBPETokenizer: # nanochat-specific convenience functions def get_tokenizer(): + """Returns the default nanochat tokenizer.""" from nanochat.common import get_base_dir base_dir = get_base_dir() tokenizer_dir = os.path.join(base_dir, "tokenizer") @@ -387,6 +433,7 @@ def get_tokenizer(): return RustBPETokenizer.from_directory(tokenizer_dir) def get_token_bytes(device="cpu"): + """Returns a tensor of byte lengths for each token in the vocabulary.""" import torch from nanochat.common import get_base_dir base_dir = get_base_dir() diff --git a/rustbpe/README.md b/rustbpe/README.md index c88636c..fde528f 100644 --- a/rustbpe/README.md +++ b/rustbpe/README.md @@ -1,5 +1,20 @@ -# rustbpe +# RustBPE Tokenizer -> The missing tiktoken training code +This directory contains a high-performance implementation of the Byte Pair Encoding (BPE) algorithm, written in Rust and exposed to Python using PyO3. This tokenizer is used for training the nanochat vocabulary and is designed to be fast, memory-efficient, and parallelized. -A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat. +## What is BPE? + +Byte Pair Encoding is a data compression technique that is widely used in natural language processing for tokenization. It starts with a base vocabulary of individual characters (or bytes) and iteratively merges the most frequent adjacent pairs of tokens into a single new token. 
This process is repeated for a fixed number of merges, resulting in a vocabulary that can represent common words and subwords as single tokens, while still being able to handle rare words and out-of-vocabulary terms. + +## Why Rust? + +While the rest of the nanochat codebase is primarily in Python, the BPE training process is computationally intensive and can be a bottleneck. Rust was chosen for this component for several key reasons: + +- **Performance:** Rust offers performance comparable to C and C++, which is essential for processing large text corpora quickly. +- **Parallelism:** Rust's ownership model and libraries like Rayon make it easy to write safe and efficient parallel code, allowing the tokenizer to take full advantage of multi-core CPUs. +- **Safety:** Rust's strict compiler and borrow checker prevent common programming errors like null pointer dereferences and data races, leading to more robust and reliable code. +- **Interoperability:** With PyO3, it is straightforward to create Python bindings for Rust code, allowing seamless integration with the rest of the nanochat pipeline. + +## Role in nanochat + +The `rustbpe` tokenizer is used in the `tok_train.py` script to train a new vocabulary from the pretraining dataset. The trained tokenizer is then used to create a `tiktoken` encoder, which is used for efficient inference in the rest of the nanochat project. diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs index 273d7f2..bb611ed 100644 --- a/rustbpe/src/lib.rs +++ b/rustbpe/src/lib.rs @@ -1,3 +1,7 @@ +//! This module provides a fast, parallel implementation of the Byte Pair Encoding (BPE) +//! algorithm, specifically tailored to match the GPT-4 style of tokenization. It is +//! written in Rust for performance and safety, and exposed to Python using PyO3. + use std::cmp::Ordering; use std::collections::HashMap as StdHashMap; @@ -158,9 +162,9 @@ fn count_pairs_parallel( impl Tokenizer { - /// Core incremental BPE training given unique words and their counts. - /// `words`: one entry per unique chunk (Vec of token-ids/bytes). - /// `counts`: same length as `words`, count per chunk. + /// The core BPE training algorithm. It iteratively finds the most frequent + /// pair of tokens and merges them into a new token, repeating this process + /// until the desired vocabulary size is reached. fn train_core_incremental(&mut self, mut words: Vec, counts: Vec, vocab_size: u32) { assert!(vocab_size >= 256, "vocab_size must be at least 256"); let num_merges = vocab_size - 256; @@ -259,7 +263,7 @@ impl Tokenizer { /// Public methods for the Tokenizer class that will be exposed to Python. #[pymethods] impl Tokenizer { - /// Create a new Tokenizer + /// Creates a new `Tokenizer`. #[new] pub fn new() -> Self { Self { @@ -269,9 +273,7 @@ impl Tokenizer { } } - /// Train from a streaming iterator (parallel ingestion). - /// We refill a Rust Vec buffer under the GIL, then release the GIL - /// to do the heavy splitting and counting **in parallel** with rayon. + /// Trains the tokenizer from a Python iterator of text. #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))] #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")] pub fn train_from_iterator( @@ -389,12 +391,12 @@ impl Tokenizer { Ok(()) } - /// Return the regex pattern + /// Returns the regex pattern used for tokenization. 
 pub fn get_pattern(&self) -> String {
        self.pattern.clone()
    }

-    /// Return the mergeable ranks (token bytes -> token id / rank)
+    /// Returns the mergeable ranks (token bytes -> token id / rank) for creating a `tiktoken` encoder.
    pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
        let mut mergeable_ranks = Vec::new();
@@ -425,7 +427,7 @@ impl Tokenizer {
        mergeable_ranks
    }

-    /// Encode a string into token IDs
+    /// Encodes a string into a sequence of token IDs.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut all_ids = Vec::new();
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index 8efde4f..cc54a26 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -1,13 +1,15 @@
 """
-Evlauate the CORE metric for a given model.
+This script evaluates a base model on the CORE metric. It can evaluate either a
+local nanochat model or a Hugging Face model.

-Run on a single GPU:
-python base_eval.py
+The CORE benchmark provides a holistic assessment of a model's capabilities. This
+script iterates through the tasks defined in `core.yaml`, evaluates the model on each,
+and reports the accuracy and a "centered" score (normalized by a random baseline).

-Run with torchrun on e.g. 8 GPUs:
-torchrun --nproc_per_node=8 base_eval.py
-
-The script will print the CORE metric to the console.
+Usage:
+- Evaluate a local nanochat model: `python -m scripts.base_eval`
+- Evaluate a Hugging Face model: `python -m scripts.base_eval --hf-path <hf_path>`
+- Distributed evaluation: `torchrun --nproc_per_node=<num_gpus> -m scripts.base_eval`
+
+The script prints the CORE metric to the console.
 """
 import os
 import sys
@@ -30,9 +32,16 @@ from nanochat.core_eval import evaluate_task

 def evaluate_model(model, tokenizer, device, max_per_task=-1):
     """
-    Evaluate a base model on the CORE benchmark.
-    - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
-    TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
+    Evaluates a model on the CORE benchmark.
+
+    Args:
+        model: The model to evaluate.
+        tokenizer: The tokenizer to use.
+        device (str): The device to run the evaluation on.
+        max_per_task (int, optional): Max examples per task. Defaults to -1 (no limit).
+
+    Returns:
+        dict: A dictionary containing the evaluation results.
     """
     # Load config and task metadata
     base_dir = get_base_dir()
@@ -94,7 +103,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
 # HuggingFace loading utilities and light wrappers for a model

 class ModelWrapper:
-    """Lightweight wrapper for a HuggingFace model"""
+    """A lightweight wrapper for Hugging Face models to match the nanochat API."""
     def __init__(self, model, max_seq_len=None):
         self.model = model
         self.max_seq_len = max_seq_len
@@ -105,6 +114,7 @@ class ModelWrapper:
         return logits

 def load_hf_model(hf_path: str, device):
+    """Loads a Hugging Face model and tokenizer."""
     print0(f"Loading model from: {hf_path}")
     # Load the model
     from transformers import AutoModelForCausalLM
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index abcde5f..3a007db 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -1,10 +1,17 @@
 """
-Loads a checkpoint, and:
-- Evaluates the loss on a larger chunk of train/val splits
-- Samples from the model
+This script evaluates the loss of a trained base model and generates samples from it.
+It serves as a quick sanity check to ensure the model has learned sensible
+representations.

-Example run as:
-torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
+The script performs two main functions:
+1. **Loss Evaluation:** It calculates the bits-per-byte (BPB) metric on both the
+   training and validation splits of the dataset.
+2. **Sampling:** The master process generates text samples from a set of predefined
+   prompts to provide a qualitative assessment of the model's capabilities.
+
+Usage:
+- Single GPU: `python -m scripts.base_loss`
+- Distributed evaluation: `torchrun --standalone --nproc_per_node=<num_gpus> -m scripts.base_loss`
 """
 import os
 from contextlib import nullcontext
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 2570a72..4005de7 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -1,14 +1,21 @@
 """
-Train model. Run as:
+This script trains a base GPT model from scratch on the pretraining dataset.
+It is the first stage in the nanochat pipeline, responsible for learning fundamental
+language representations.

-python base_train.py
+The script supports:
+- Distributed training with `torchrun`.
+- Mixed-precision training (`bfloat16`).
+- A composite optimizer (Muon + AdamW).
+- Learning rate and momentum scheduling.
+- Periodic evaluation on validation data and the CORE benchmark.
+- Logging to Weights & Biases.
+- Checkpointing the final model.

-or distributed as:
-
-torchrun --nproc_per_node=8 base_train.py
-
-If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
+Usage:
+- Single GPU: `python -m scripts.base_train`
+- Distributed: `torchrun --nproc_per_node=<num_gpus> -m scripts.base_train`
+- CPU/Macbook (a much smaller model, for testing): `python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20`
 """

 import os
diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index b14843a..88d6ec1 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -1,8 +1,15 @@
 """
-New and upgraded chat mode because a lot of the code has changed since the last one.
+This script provides a command-line interface (CLI) for interacting with a trained
+nanochat model. It allows users to have a text-based conversation with the model
+in the terminal. It is intended to be run on a single GPU.

-Intended to be run single GPU only atm:
-python -m scripts.chat_cli -i mid
+The script can load a model from different training stages (mid-training, SFT, or RL)
+and supports various generation parameters like temperature and top-k sampling.
+
+Usage:
+- Interactive chat: `python -m scripts.chat_cli --source sft`
+- Single prompt: `python -m scripts.chat_cli --prompt "Hello, world!"`
+- Help: `python -m scripts.chat_cli --help`
 """
 import argparse
 import torch
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index 616411d..7d87a4e 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -1,11 +1,14 @@
 """
-Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
-code lives in nanochat directory and is imported from here.
+This script evaluates a trained chat model on various downstream tasks, such as
+MMLU, GSM8K, and HumanEval. It supports both generative and categorical evaluation
+modes, depending on the task.

-Example runs:
-python -m scripts.chat_eval -a ARC-Easy
-torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
+The script can be run in both single-GPU and distributed (DDP) modes.
+
+Usage:
+- Evaluate on a single task: `python -m scripts.chat_eval --source sft -a ARC-Easy`
+- Evaluate on multiple tasks: `python -m scripts.chat_eval -a "ARC-Easy|GSM8K"`
+- Distributed evaluation: `torchrun --nproc_per_node=<num_gpus> -m scripts.chat_eval -- -a ARC-Easy`
 """

 import argparse
@@ -29,7 +32,9 @@ from tasks.spellingbee import SpellingBee

 # Generative evaluation loop (we go one problem at a time, sample, evaluate)
 def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
-
+    """
+    Runs a generative evaluation task, where the model generates a free-form response.
+    """
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     device = model.get_device()
@@ -88,7 +93,10 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
 # batches at a time and just check the logits for correct answer choices.
 def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
-
+    """
+    Runs a categorical evaluation task, where the model chooses from a set of options.
+    This is more efficient than generative evaluation as it can be done in batches.
+    """
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     device = model.get_device()
     bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
@@ -159,6 +167,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems

 def run_chat_eval(task_name, model, tokenizer, engine, batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50, max_problems=None):
+    """Initializes and runs a specific chat evaluation task."""
     # Create the evaluation object
     task_module = {
         'HumanEval': HumanEval,
diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py
index bc78e79..394c1c1 100644
--- a/scripts/chat_rl.py
+++ b/scripts/chat_rl.py
@@ -1,19 +1,16 @@
 """
-Reinforcement learning on GSM8K via "GRPO".
+This script performs reinforcement learning on the GSM8K dataset using a simplified,
+on-policy REINFORCE-like algorithm (a "GRPO"-style setup with no trust region / KL
+regularization to a reference model, no PPO ratio+clip, and a simple (r - mu)
+advantage instead of a z-score).

-I put GRPO in quotes because we actually end up with something a lot
-simpler and more similar to just REINFORCE:
+The training process involves:
+1. Sampling multiple completions for each problem in the GSM8K training set.
+2. Calculating a reward for each completion based on whether it solves the problem.
+3. Computing the policy gradient loss using the calculated advantages.
+4. Updating the model parameters to maximize the expected reward.

-1) Delete trust region, so there is no KL regularization to a reference model
-2) We are on policy, so there's no need for PPO ratio+clip.
-3) We use GAPO style normalization that is token-level, not sequence-level.
-4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
-
-1 GPU:
-python -m scripts.chat_rl
-
-8 GPUs:
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
+Usage:
+- Single GPU: `python -m scripts.chat_rl`
+- Distributed: `torchrun --standalone --nproc_per_node=<num_gpus> -m scripts.chat_rl -- --run=default`
 """

 import os
@@ -77,6 +74,7 @@ print0(f"Calculated number of steps: {num_steps}")

 @torch.no_grad()
 def get_batch():
+    """A generator that yields batches of rollouts for training."""
     assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
 rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
     for example_idx in itertools.cycle(rank_indices):
@@ -149,10 +147,9 @@ def run_gsm8k_eval(task, tokenizer, engine, top_k=50 ):
     """
-    Evaluates GSM8K task and returns a list of records of evaluation outcomes.
-    In a distributed setting, all ranks cooperate but this function will NOT
-    do the reduction across ranks. This is the responsibility of the caller.
-    Because the evaluation can take a while, this function will yield records one by one.
+    Evaluates the model on the GSM8K task and yields evaluation records one by one.
+    In a distributed setting, all ranks cooperate, but this function does not perform
+    the reduction across ranks; that is the responsibility of the caller.
     """
     max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
     for idx in range(ddp_rank, max_examples, ddp_world_size):
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index e6e4565..7578918 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -1,12 +1,13 @@
 """
-Finetune a base model to be a chat model.
-Run on one GPU e.g. for debugging:
+This script performs supervised fine-tuning (SFT) on a base or mid-trained model
+to adapt it for chat-based interactions.

-python -m scripts.chat_sft
+The script trains the model on a mixture of conversational and task-specific datasets,
+using a masked loss function that only penalizes errors in the assistant's turns.

-Or torchrun for training:
-
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
+Usage:
+- Single GPU (e.g. for debugging): `python -m scripts.chat_sft`
+- Distributed training: `torchrun --standalone --nproc_per_node=<num_gpus> -m scripts.chat_sft`
 """

 import os
@@ -96,6 +97,7 @@ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don
 # DataLoader
 def sft_data_generator(dataset, batch_size):
+    """A generator that yields batches of tokenized and collated SFT data."""
     pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
     # prepares a list of tokenized conversations into a batch and yields
     def collate_and_yield(batch):
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index d7479c7..9091ea8 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -1,33 +1,14 @@
 #!/usr/bin/env python3
 """
-Unified web chat server - serves both UI and API from a single FastAPI instance.
+This script launches a FastAPI web server for interacting with a trained nanochat model.
+It serves both a web UI for chatting and an API endpoint for programmatic access from a
+single FastAPI instance.

-Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
-a full copy of the model, and incoming requests are distributed to available workers.
+The server uses a worker pool to manage model instances on multiple GPUs, allowing
+it to handle concurrent requests efficiently.

-Launch examples:
-
-- single available GPU (default)
-python -m scripts.chat_web
-
-- 4 GPUs
-python -m scripts.chat_web --num-gpus 4
-
-To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
-
-Endpoints:
-    GET  /                  - Chat UI
-    POST /chat/completions  - Chat API (streaming only)
-    GET  /health            - Health check with worker pool status
-    GET  /stats             - Worker pool statistics and GPU utilization
-
-Abuse Prevention:
-    - Maximum 500 messages per request
-    - Maximum 8000 characters per message
-    - Maximum 32000 characters total conversation length
-    - Temperature clamped to 0.0-2.0
-    - Top-k clamped to 1-200
-    - Max tokens clamped to 1-4096
+Endpoints:
+    GET  /                  - Chat UI
+    POST /chat/completions  - Chat API (streaming only)
+    GET  /health            - Health check with worker pool status
+    GET  /stats             - Worker pool statistics and GPU utilization
+
+Usage:
+- Single GPU (default): `python -m scripts.chat_web`
+- Multi-GPU: `python -m scripts.chat_web --num-gpus 4`
+
+To chat, open the URL printed in the console (if on a cloud box, use the public IP).
 """

 import argparse
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index eedb262..8c02343 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -1,12 +1,14 @@
 """
-Midtrain the model. Same as pretraining but simpler.
-Run as:
+This script performs "mid-training," a stage of continued pre-training on a mixture
+of conversational and task-specific data. It serves as an intermediate step between
+the initial base model pre-training and the final supervised fine-tuning (SFT).

-python -m scripts.mid_train
+The goal of mid-training is to adapt the base model to the format and style of
+chat conversations and to introduce it to various tasks like math and general knowledge.

-Or torchrun for training:
-
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+Usage:
+- Single GPU: `python -m scripts.mid_train`
+- Distributed: `torchrun --standalone --nproc_per_node=<num_gpus> -m scripts.mid_train -- --device_batch_size=16`
 """

 from collections import deque
@@ -115,6 +117,7 @@ val_dataset = TaskMixture([
 last_step = False # we will toggle this to True when we reach the end of the dataset
 approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
 def mid_data_generator(split):
+    """A generator that yields batches of tokenized data for mid-training."""
     global last_step, approx_progress
     assert split in {"train", "val"}, "split must be 'train' or 'val'"
     dataset = train_dataset if split == "train" else val_dataset
diff --git a/scripts/tok_eval.py b/scripts/tok_eval.py
index 9233d71..40ab738 100644
--- a/scripts/tok_eval.py
+++ b/scripts/tok_eval.py
@@ -1,5 +1,11 @@
 """
-Evaluate compression ratio of the tokenizer.
+This script evaluates the compression ratio of the nanochat tokenizer against
+standard tokenizers like GPT-2 and GPT-4. The compression ratio, defined as the
+number of bytes in the raw text divided by the number of tokens, is a key metric
+for tokenizer efficiency.
+
+Usage:
+- To run the evaluation: `python -m scripts.tok_eval`
 """

 from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
@@ -201,7 +207,7 @@ print(f"GPT-4: {vocab_sizes['gpt4']}")
 print(f"Ours: {vocab_sizes['ours']}")

 def print_comparison(baseline_name, baseline_results, ours_results, all_text):
-    """Print comparison table between baseline tokenizer and ours."""
+    """Prints a formatted comparison table to the console."""
     print(f"\nComparison with {baseline_name}:")
     print("=" * 95)
     print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
diff --git a/scripts/tok_train.py b/scripts/tok_train.py
index c2faf17..94f3987 100644
--- a/scripts/tok_train.py
+++ b/scripts/tok_train.py
@@ -1,6 +1,15 @@
 """
-Train a tokenizer using the HuggingFace Tokenizers library.
-In the style of GPT-4 tokenizer.
+This script trains a Byte Pair Encoding (BPE) tokenizer, in the style of the GPT-4
+tokenizer, from the pretraining dataset.
+
+The training process involves:
+1. Iterating through the text data from the Parquet files.
+2. Using the `rustbpe` library to learn a vocabulary and merge rules. +3. Saving the trained tokenizer to disk. +4. Creating a `token_bytes.pt` file for calculating the bits-per-byte (BPB) metric. + +Usage: +- Default: `python scripts/tok_train.py` +- Custom vocab size: `python scripts/tok_train.py --vocab_size 32768` """ import os import time @@ -27,9 +36,8 @@ print(f"vocab_size: {args.vocab_size:,}") def text_iterator(): """ - 1) Flatten the batches into a single iterator - 2) Crop every document to args.doc_cap characters - 3) Break when we've seen args.max_chars characters + An iterator that yields documents from the pretraining dataset, capped at + the specified character limits. """ nchars = 0 for batch in parquets_iter_batched(split="train"): diff --git a/tasks/arc.py b/tasks/arc.py index 862cca9..553d3d5 100644 --- a/tasks/arc.py +++ b/tasks/arc.py @@ -1,12 +1,23 @@ """ -The ARC dataset from Allen AI. -https://huggingface.co/datasets/allenai/ai2_arc +This module implements the AI2 Reasoning Challenge (ARC) task. The ARC dataset is +a collection of multiple-choice science questions designed to test a model's +reasoning and common-sense knowledge. + +**Reference:** +- The ARC dataset: https://huggingface.co/datasets/allenai/ai2_arc """ from datasets import load_dataset -from tasks.common import Task, render_mc +from .common import Task, render_mc class ARC(Task): + """ + The ARC (AI2 Reasoning Challenge) task. + + Args: + subset (str): "ARC-Easy" or "ARC-Challenge". + split (str): "train", "validation", or "test". + """ def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) @@ -16,12 +27,17 @@ class ARC(Task): @property def eval_type(self): + """Specifies that this is a categorical evaluation task.""" return 'categorical' def num_examples(self): + """Returns the total number of examples in the dataset.""" return len(self.ds) def get_example(self, index): + """ + Formats a single example from the dataset into a conversation dictionary. + """ row = self.ds[index] question = row["question"] # the question text choices = row["choices"]["text"] # the text of each choice @@ -41,6 +57,16 @@ class ARC(Task): return conversation def evaluate(self, conversation, assistant_response): + """ + Evaluates the model's response for a given example. + + Args: + conversation (dict): The conversation dictionary for the example. + assistant_response (str): The model's predicted answer. + + Returns: + bool: True if the prediction is correct, False otherwise. + """ # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" diff --git a/tasks/common.py b/tasks/common.py index dcd2e91..167b55d 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -1,15 +1,22 @@ """ -Base class for all Tasks. -A Task is basically a dataset of conversations, together with some -metadata and often also evaluation criteria. -Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. +This module provides the base classes and common utilities for defining evaluation +and training tasks in nanochat. + +The core components are: +- `Task`: An abstract base class that represents a dataset of conversations. +- `TaskMixture`: A class for combining multiple tasks into a single, shuffled dataset. 
+- `TaskSequence`: A class for combining multiple tasks in a sequential manner. +- `render_mc`: A helper function for formatting multiple-choice questions. """ import random class Task: """ - Base class of a Task. Allows for lightweight slicing of the underlying dataset. + Abstract base class for a task, which is essentially a dataset of conversations. + + This class supports lightweight slicing of the underlying dataset using `start`, + `stop`, and `step` parameters, similar to Python's list slicing. """ def __init__(self, start=0, stop=None, step=1): @@ -23,16 +30,20 @@ class Task: @property def eval_type(self): + """The type of evaluation for this task, either 'generative' or 'categorical'.""" # one of 'generative' | 'categorical' raise NotImplementedError def num_examples(self): + """Returns the total number of examples in the underlying dataset.""" raise NotImplementedError def get_example(self, index): + """Retrieves a single example from the underlying dataset by its physical index.""" raise NotImplementedError def __len__(self): + """Returns the number of examples in the (potentially sliced) view of the dataset.""" start = self.start stop = self.num_examples() if self.stop is None else self.stop step = self.step @@ -42,19 +53,22 @@ class Task: return num def __getitem__(self, index: int): + """Retrieves an example by its logical index in the (sliced) view.""" assert isinstance(index, int), f"Index must be an integer, got {type(index)}" physical_index = self.start + index * self.step conversation = self.get_example(physical_index) return conversation def evaluate(self, problem, completion): + """Evaluates a model's completion for a given problem.""" raise NotImplementedError class TaskMixture(Task): """ - For SFT Training it becomes useful to train on a tax mixture of datasets. - Fun trick: if you wish to oversample any task, just pass it in multiple times in the list. + Combines multiple tasks into a single, deterministically shuffled dataset. + This is useful for creating a diverse training mixture for SFT. To oversample + a task, simply include it multiple times in the `tasks` list. """ def __init__(self, tasks, **kwargs): @@ -88,8 +102,8 @@ class TaskMixture(Task): class TaskSequence(Task): """ - For SFT Training sometimes we want to sequentially train on a list of tasks. - This is useful for cases that require a training curriculum. + Combines multiple tasks sequentially, which is useful for creating a + training curriculum. """ def __init__(self, tasks, **kwargs): @@ -111,19 +125,15 @@ class TaskSequence(Task): def render_mc(question, letters, choices): """ - The common multiple choice rendering format we will use. + Formats a multiple-choice question into a standardized prompt. - Note two important design decisions: - 1) - Bigger models don't care as much, but smaller models prefer to have - the letter *after* the choice, which results in better binding. - 2) - There is no whitespace between the delimiter (=) and the letter. - This is actually critical because the tokenizer has different token ids - for " A" vs. "A". The assistant responses will be just the letter itself, - i.e. "A", so it is important that here in the prompt it is the exact same - token, i.e. "A" with no whitespace before it. Again, bigger models don't care - about this too much, but smaller models do care about some of these details. + Args: + question (str): The question text. + letters (list[str]): The letters for the choices (e.g., ['A', 'B', 'C']). + choices (list[str]): The text of the choices. 
+
+    Returns:
+        str: The formatted prompt.
     """
     query = f"Multiple Choice question: {question}\n"
     query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
diff --git a/tasks/customjson.py b/tasks/customjson.py
index f4683c8..6f80bd1 100644
--- a/tasks/customjson.py
+++ b/tasks/customjson.py
@@ -1,11 +1,15 @@
 """
-CustomJSON task for loading conversations from JSONL files.
-Each line in the JSONL file should be a JSON array of messages.
+This module implements the `CustomJSON` task, which allows for loading conversational
+data from a custom JSONL file. This is useful for fine-tuning the model on specific
+datasets, such as synthetic data for instilling a persona.
+
+Each line in the JSONL file should be a JSON array of message objects, where each
+object has a "role" and a "content" field.
 """
 import os
 import json
-from tasks.common import Task
+from .common import Task

 class CustomJSON(Task):
     """
@@ -15,6 +19,12 @@ class CustomJSON(Task):
     """

     def __init__(self, filepath, **kwargs):
+        """
+        Initializes the CustomJSON task.
+
+        Args:
+            filepath (str): The path to the JSONL file.
+        """
         super().__init__(**kwargs)
         self.filepath = filepath
         self.conversations = []
@@ -54,9 +64,19 @@ class CustomJSON(Task):
         self.length = len(self.conversations)

     def num_examples(self):
+        """Returns the total number of conversations loaded from the file."""
         return self.length

     def get_example(self, index):
+        """
+        Retrieves a single conversation by its index.
+
+        Args:
+            index (int): The index of the conversation to retrieve.
+
+        Returns:
+            dict: A dictionary representing the conversation.
+        """
         messages = self.conversations[index]
         conversation = {
             "messages": messages,
diff --git a/tasks/gsm8k.py b/tasks/gsm8k.py
index c05e21c..2e31493 100644
--- a/tasks/gsm8k.py
+++ b/tasks/gsm8k.py
@@ -1,31 +1,23 @@
 """
-GSM8K evaluation.
-https://huggingface.co/datasets/openai/gsm8k
+This module implements the GSM8K (Grade School Math 8K) task. This dataset consists
+of grade school math word problems that require multi-step reasoning.

-Example problem instance:
+A unique feature of this dataset is its use of "tool calls" (calculator annotations)
+in the answers, denoted by `<<...>>` markers. For example:

-Question:
-Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
-Answer:
-Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
-Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
-#### 10
-
-Notice that GSM8K uses tool calls inside << >> tags.
+Question:
+Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
+Answer:
+Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
+Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
+#### 10
+
+This module parses these tool calls into a structured conversational format for
+fine-tuning the model's tool-use capabilities.
+
+**Reference:**
+- The GSM8K dataset: https://huggingface.co/datasets/openai/gsm8k
 """
 import re
 from datasets import load_dataset
-from tasks.common import Task
+from .common import Task

 GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

 def extract_answer(completion):
-    """
-    Extract the numerical answer after #### marker.
-    Follows official code for normalization:
-    https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
-    """
+    """
+    Extracts the numerical answer after the #### marker from a GSM8K completion string.
+    Follows the official code for normalization:
+    https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
+    """
     match = GSM_RE.search(completion)
     if match:
         match_str = match.group(1).strip()
@@ -35,6 +27,13 @@ def extract_answer(completion):

 class GSM8K(Task):
+    """
+    The GSM8K (Grade School Math 8K) task.
+
+    Args:
+        subset (str): The subset of the dataset, either "main" or "socratic".
+        split (str): The data split, either "train" or "test".
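+
+    Example (a rough sketch of how the task is used for evaluation):
+
+        task = GSM8K(subset="main", split="test")
+        conversation = task[0]
+        is_correct = task.evaluate(conversation, assistant_response="... #### 10")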
+ """ def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) @@ -44,13 +43,17 @@ class GSM8K(Task): @property def eval_type(self): + """Specifies that this is a generative evaluation task.""" return 'generative' def num_examples(self): + """Returns the total number of examples in the dataset.""" return len(self.ds) def get_example(self, index): - """ Get a single problem from the dataset. """ + """ + Formats a single example, parsing tool calls into a structured conversation. + """ row = self.ds[index] question = row['question'] # string of the question prompt answer = row['answer'] # string of the full solution and the answer after #### marker @@ -86,13 +89,8 @@ class GSM8K(Task): def evaluate(self, conversation, assistant_response): """ - Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct) - Note that: - - the conversation has both user AND assistant message (containing the ground truth answer) - - the assistant_response is usually the alternative assistant message achieved via sampling - - TODO: Technically, assistant_response should be a Message (either a string or a list of parts) - We can handle this later possibly. For now just assume string. + Evaluates the model's response by comparing the extracted numerical answer + to the ground truth. """ assert isinstance(assistant_response, str), "Assuming simple string response for now" # First extract the ground truth answer @@ -109,8 +107,8 @@ class GSM8K(Task): def reward(self, conversation, assistant_response): """ - Used during RL. To keep things simple, just re-use the evaluation above. - Later this could be made more complex (e.g. format matching etc.) + Provides a reward for reinforcement learning, which is simply whether the + answer was correct or not. """ is_correct = self.evaluate(conversation, assistant_response) is_correct_float = float(is_correct) diff --git a/tasks/humaneval.py b/tasks/humaneval.py index e9dd489..ec3efee 100644 --- a/tasks/humaneval.py +++ b/tasks/humaneval.py @@ -1,13 +1,20 @@ """ -Evaluate the Chat model on HumanEval dataset. -Btw this dataset is a misnomer and has nothing to do with humans. -It is a coding benchmark. +This module implements the HumanEval task, a benchmark for evaluating the code +generation capabilities of language models. + +The task is implemented as a `generative` evaluation. For each problem, the model +is given a function signature and docstring and is expected to generate the body +of the function. The generated code is then executed in a sandboxed environment +against a set of unit tests to determine its correctness. + +**Reference:** +- The HumanEval dataset: https://huggingface.co/datasets/openai/openai_humaneval """ import re from datasets import load_dataset from nanochat.execution import execute_code -from tasks.common import Task +from .common import Task def extract_imports(prompt): """Extract import statements from the beginning of a code block.""" @@ -23,14 +30,8 @@ def extract_imports(prompt): def extract_program(completion): """ - Extract Python code from LLM completion. - - Handles various output formats: - - Code wrapped in ```python ... ``` or ``` ... ``` blocks - - Plain code without markdown blocks - - Extra text before/after code blocks - - Returns the first code block if found, otherwise returns the whole completion. + Extracts a Python code block from a language model's completion, + handling markdown formatting. 
""" # Try to find markdown code blocks (```python or just ```) # Match ```python\n...\n``` or ```\n...\n``` @@ -45,20 +46,26 @@ def extract_program(completion): return completion.strip() class HumanEval(Task): - + """ + The HumanEval code generation task. + """ def __init__(self, **kwargs): super().__init__(**kwargs) self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) @property def eval_type(self): + """Specifies that this is a generative evaluation task.""" return 'generative' def num_examples(self): + """Returns the total number of examples in the dataset.""" return len(self.ds) def get_example(self, index): - """ Get a single problem from the dataset. """ + """ + Formats a single problem from the dataset into a conversation dictionary. + """ row = self.ds[index] prompt = row['prompt'] # prompts in HumanEval are the beginning of the program solution = row['canonical_solution'] # the correct continuation of the program @@ -77,7 +84,10 @@ class HumanEval(Task): return conversation def evaluate(self, conversation, completion): - """ Given (conversation, completion), return boolean success of the completion. """ + """ + Evaluates the model's generated code by running it against the problem's + unit tests in a sandboxed environment. + """ # the prompt will contain the imports and the function signature imports = extract_imports(conversation['messages'][0]['content']) # the completion will usually contain the whole function diff --git a/tasks/mmlu.py b/tasks/mmlu.py index 3ba2254..7eb7702 100644 --- a/tasks/mmlu.py +++ b/tasks/mmlu.py @@ -1,17 +1,40 @@ -""" -The MMLU dataset. -https://huggingface.co/datasets/cais/mmlu -""" +#-*--*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-# +#_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_# +# # +# This is the MMLU dataset. # +# https://huggingface.co/datasets/cais/mmlu # +# # +#_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_# +#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-# from datasets import load_dataset from tasks.common import Task, render_mc class MMLU(Task): + """ + The MMLU class is a task that evaluates a model's performance on the MMLU dataset. + MMLU (Massive Multitask Language Understanding) is a benchmark designed to measure knowledge + acquired during pretraining by evaluating models exclusively in zero-shot and few-shot settings. + This makes the benchmark more challenging and more similar to how we evaluate humans. + The benchmark covers 57 subjects across STEM, the humanities, the social sciences, and more. + It ranges in difficulty from an elementary level to a professional level, + and it tests both world knowledge and problem solving abilities. + Subjects include elementary mathematics, US history, computer science, law, and more. + """ + # The letters used to label the multiple choice options. letters = ('A', 'B', 'C', 'D') + # A list of all the subject groups in the MMLU dataset. 
groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') def __init__(self, subset, split, **kwargs): + """ + Initializes the MMLU task. + Args: + subset (str): The subset of the dataset to use. Must be 'all' or 'auxiliary_train'. + split (str): The split of the dataset to use. Must be 'train', 'validation', 'dev', or 'test'. + **kwargs: Additional keyword arguments. + """ super().__init__(**kwargs) assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train" assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test" @@ -21,21 +44,36 @@ class MMLU(Task): self.split = split self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42) if subset == "auxiliary_train": - # I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper + # The 'auxiliary_train' subset has a nested structure where the actual data is in a 'train' column. + # This mapping function unnests the data, making it consistent with the other subsets. self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train']) @property def eval_type(self): + """ + Returns the evaluation type for this task. + MMLU is a multiple-choice task, so the evaluation is categorical. + """ return 'categorical' def num_examples(self): + """ + Returns the total number of examples in the dataset. + """ return len(self.ds) def get_example(self, index): + """ + Retrieves a single example from the dataset at the specified index. + Args: + index (int): The index of the example to retrieve. + Returns: + dict: A dictionary representing the conversation, including messages, subject, and letters for choices. + """ row = self.ds[index] - question = row["question"] # the question text - choices = row["choices"] # the text of each choice - answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) + question = row["question"] # The question text + choices = row["choices"] # The text of each choice + answer = row["answer"] # Index of the answer, e.g. 0,1,2,3 (for A,B,C,D) subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc. 
assert len(choices) == 4, "MMLU should have 4 choices" # create and return the Conversation object @@ -47,14 +85,22 @@ class MMLU(Task): ] conversation = { "messages": messages, - "subject": subject, # might be useful later for grouping metrics by subject - "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters + "subject": subject, # useful for grouping metrics by subject later + "letters": self.letters, # useful during evaluation to constrain the assistant's prediction } return conversation def evaluate(self, conversation, assistant_response): - # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true - # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. + """ + Evaluates the model's response against the correct answer. + Args: + conversation (dict): The conversation dictionary containing the context. + assistant_response (str): The model's response. + Returns: + bool: True if the assistant's response is correct, False otherwise. + """ + # This assert ensures that the model's response is one of the valid choices. + # This is a safeguard to prevent unexpected evaluation behavior. assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" assistant_message = conversation['messages'][-1]['content'] # e.g. "A" return assistant_response == assistant_message diff --git a/tasks/smoltalk.py b/tasks/smoltalk.py index b4d4f5f..db0480c 100644 --- a/tasks/smoltalk.py +++ b/tasks/smoltalk.py @@ -1,45 +1,92 @@ -""" -SmolTalk by HuggingFace. Good "general" conversational dataset. -https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk -We use the "smol" version, which is more appropriate for smaller models. -""" +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +# # +# SmolTalk by HuggingFace. Good "general" conversational dataset. # +# https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk # +# We use the "smol" version, which is more appropriate for smaller models.# +# # +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# from datasets import load_dataset from tasks.common import Task class SmolTalk(Task): - """ smol-smoltalk dataset. train is 460K rows, test is 24K rows. """ + """ + The SmolTalk class handles the smol-smoltalk dataset, a conversational dataset from HuggingFace. + It's designed for general-purpose conversational models and is particularly suited for smaller models due to its size. + The training set contains approximately 460,000 examples, while the test set has around 24,000. + + Python equivalent: + A dictionary where keys are split names ('train', 'test') and values are lists of conversations. + Each conversation is a list of dictionaries, where each dictionary has 'role' and 'content' keys. + Example: + { + "train": [ + [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there! How can I help you today?"} + ], + # ... more conversations + ], + "test": [ + # ... test conversations + ] + } + """ def __init__(self, split, **kwargs): + """ + Initializes the SmolTalk task. + Args: + split (str): The dataset split to load, must be either "train" or "test". 
+ **kwargs: Additional keyword arguments passed to the parent Task class. + """ super().__init__(**kwargs) assert split in ["train", "test"], "SmolTalk split must be train|test" + # Load the specified split of the dataset and shuffle it for randomness. self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42) self.length = len(self.ds) def num_examples(self): + """ + Returns the total number of examples in the loaded dataset split. + """ return self.length def get_example(self, index): + """ + Retrieves a single conversational example from the dataset. + Args: + index (int): The index of the example to retrieve. + Returns: + dict: A dictionary containing the conversation messages. + """ row = self.ds[index] messages = row["messages"] # --------------------------------------------------------------------- - # sanity checking asserts here - # TODO: we could remove these asserts later, for now just don't want any footguns - # there is an optional system message at the beginning + # Perform sanity checks to ensure the data format is as expected. + # These asserts can be removed later for performance, but are useful for debugging. + + # A conversation can optionally start with a system message. assert len(messages) >= 1 first_message = messages[0] if first_message["role"] == "system": - rest_messages = messages[1:] # optional system message is OK + rest_messages = messages[1:] # The rest of the conversation after the system message. else: rest_messages = messages + + # There should be at least one user-assistant exchange. assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages" + + # Check that roles alternate correctly (user, assistant, user, ...). for i, message in enumerate(rest_messages): - # user and assistant alternate as user,assistant,user,assistant,... expected_role = "user" if i % 2 == 0 else "assistant" assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" assert isinstance(message["content"], str), "Content must be a string" # --------------------------------------------------------------------- - # create and return the Conversation object (ok to emit the system message too) + + # Return the conversation in the standard format. conversation = { "messages": messages, } diff --git a/tasks/spellingbee.py b/tasks/spellingbee.py index c051fe7..5691577 100644 --- a/tasks/spellingbee.py +++ b/tasks/spellingbee.py @@ -1,29 +1,31 @@ +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +# # +# The Spelling Bee Task # +# # +#_-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*# +#--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*--*# """ -Task intended to make nanochat better in spelling and counting, for example: +This module defines tasks intended to improve a model's spelling and counting abilities. -"How many r are in strawberry?" -> 3 +For example, a task might be: "How many 'r's are in strawberry?" -> 3 -An interesting part of this task is that we will get the assistant to -solve the problem using a combination of manual counting and Python. -This is a good problem solving "instinct" to mix into the model and RL -may further refine it to trust one over the other. If we were extra fancy -(which we could/should be) we'd add small errors here and there to allow -the model also learn recoveries. 
We can do this in future versions. +A key feature of this task is that the assistant is guided to solve the problem +by combining manual counting with Python code verification. This promotes a robust +problem-solving process in the model. Future versions could introduce small errors +to train the model on error detection and recovery. -There are two tasks in this file: -1. SpellingBee: Counting the number of occurrences of a letter in a word -2. SimpleSpelling: Simply spelling words +This file contains two main tasks: +1. SpellingBee: Counts the occurrences of a specific letter in a word. +2. SimpleSpelling: A simpler task focused on correctly spelling words. -(1) is the goal, but (2) exists as a highly condensed version of the part -that makes (1) difficult, which is word spelling. This is non-trivial for an -LLM because it has to learn how every token (a little semantic chunk/atom) -maps to the sequence of individual characters that make it up. Larger models -learn this eventually on their own, but if we want this capability to exist -in smaller models, we have to actively encourage it by over-representing it -in the training data. Midtraining is a good place to do this. +The primary goal is (1), but (2) is included to address a fundamental challenge for LLMs: +mapping tokens (semantic units) to the individual characters that form a word. +While larger models often learn this implicitly, smaller models benefit from explicit +training on this skill. -To preview a few example conversations, run: -python -m tasks.spellingbee +To preview examples from these tasks, run this script directly: +`python -m tasks.spellingbee` """ import re @@ -31,16 +33,22 @@ import random from tasks.common import Task from nanochat.common import download_file_with_lock -# Letters of the alphabet +# Define the alphabet for random letter selection. LETTERS = "abcdefghijklmnopqrstuvwxyz" -# A list of 370K English words of large variety +# URL for a comprehensive list of English words. WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words_alpha.txt" -# Identical to gsm8k's answer extraction +# Regex to find the final numerical answer, same as in the gsm8k task. ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)") def extract_answer(completion): """ - Extract the numerical answer after #### marker. + Extracts the numerical answer from a string, which is marked with "####". + This function is designed to parse the final answer from the model's output. + It handles integers, floats, and commas. + + For example: + - "The answer is #### 42." -> "42" + - "After calculation, we get #### 3,141.59" -> "3141.59" """ match = ANSWER_RE.search(completion) if match: @@ -49,7 +57,9 @@ def extract_answer(completion): return match_str return None -# User message templates for data augmentation +# A diverse set of templates for user messages to augment the training data. +# This helps the model generalize to different ways a user might ask the same question. +# Includes templates in multiple languages for broader applicability. USER_MSG_TEMPLATES = [ "How many {letter} are in the word {word}", "How many {letter} are in {word}", @@ -111,12 +121,26 @@ USER_MSG_TEMPLATES = [ ] class SpellingBee(Task): + """ + A task to count the occurrences of a letter in a word. + The assistant's response is structured to first perform a manual count, + then verify the result using a Python tool call. This encourages a + "show your work" and "double-check" approach. 
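+
+    Example of the conversation this task generates (an illustrative sketch, not verbatim output):
+
+        task = SpellingBee(size=10, split="train")
+        ex = task.get_example(0)
+        # ex["messages"][0] -> a user turn such as "How many 'r' are in strawberry?"
+        # ex["messages"][1]["content"] -> a list of parts: a manual-count "text" part,
+        # a "python" tool call like "'strawberry'.count('r')", its "python_output",
+        # and a final "text" part ending in "#### <count>".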
+ """ def __init__(self, size=1000, split="train", **kwargs): + """ + Initializes the SpellingBee task. + Args: + size (int): The number of examples to generate for this task. + split (str): The dataset split, either "train" or "test". + **kwargs: Additional arguments for the parent Task class. + """ super().__init__(**kwargs) assert split in ["train", "test"], "SpellingBee split must be train|test" self.size = size self.split = split + # Download the word list if it's not already cached. filename = WORD_LIST_URL.split("/")[-1] word_list_path = download_file_with_lock(WORD_LIST_URL, filename) with open(word_list_path) as f: @@ -125,40 +149,50 @@ class SpellingBee(Task): @property def eval_type(self): + """ This task requires a generative evaluation, as the response format is complex. """ return 'generative' def num_examples(self): + """ Returns the number of examples in this task. """ return self.size def get_example(self, index): - seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + """ + Generates a single example for the SpellingBee task. + Args: + index (int): An index to seed the random number generator for reproducibility. + Returns: + dict: A conversation dictionary representing the task example. + """ + # Use the index to seed the random generator for deterministic example generation. + seed = index if self.split == "train" else -(index + 1) rng = random.Random(seed) - # pick a random word + # Select a random word and a letter to count. word = rng.choice(self.words) - # pick a letter from it (90%) or a random letter (10%) + # Usually pick a letter from the word, but sometimes a random one. letter = rng.choice(word) if rng.random() < 0.9 else rng.choice(LETTERS) - # get the correct answer by simply counting + # Calculate the correct answer. count = word.count(letter) - # create a user message, with a bunch of variations as data augmentation + # Create a user message using a random template for variety. template = rng.choice(USER_MSG_TEMPLATES) - # 30% chance to lowercase the template (lazy people don't use shift) if rng.random() < 0.3: template = template.lower() quote_options = ['', "'", '"'] - letter_quote = rng.choice(quote_options) # is the letter quoted? - word_quote = rng.choice(quote_options) # is the word quoted? + letter_quote = rng.choice(quote_options) + word_quote = rng.choice(quote_options) letter_wrapped = f"{letter_quote}{letter}{letter_quote}" word_wrapped = f"{word_quote}{word}{word_quote}" user_msg = template.format(letter=letter_wrapped, word=word_wrapped) - if rng.random() < 0.5: # 50% of people don't even use question marks + if rng.random() < 0.5: user_msg += "?" - # Now create the ideal assistant response - build as parts (text + tool calls) + # Construct the ideal assistant response as a series of parts. assistant_parts = [] word_letters = ",".join(list(word)) + # Part 1: Manual counting process. manual_text = f"""We are asked to find the number '{letter}' in the word '{word}'. Let me try a manual approach first. First spell the word out: @@ -166,33 +200,27 @@ First spell the word out: Then count the occurrences of '{letter}': """ - # Little simulated loop of the solution process - # TODO: This is where the fun starts, we could simulate cute little mistakes - # and get the model to review its work and recover from them. - # You might of course hope this could arise in RL too, but realistically you'd want to help it out a bit. 
running_count = 0 for i, char in enumerate(word, 1): if char == letter: running_count += 1 - # note: there deliberately cannot be a space here between i and char - # because this would create a different token! (e.g. " a" and "a" are different tokens) manual_text += f"{i}:{char} hit! count={running_count}\n" else: manual_text += f"{i}:{char}\n" manual_text += f"\nThis gives us {running_count}." assistant_parts.append({"type": "text", "text": manual_text}) - # Part 2: Python verification + # Part 2: Transition to Python verification. assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"}) - # Part 3: Python tool call + # Part 3: The Python tool call itself. python_expr = f"'{word}'.count('{letter}')" assistant_parts.append({"type": "python", "text": python_expr}) - # Part 4: Python output + # Part 4: The output from the Python tool. assistant_parts.append({"type": "python_output", "text": str(count)}) - # Part 5: Final answer + # Part 5: The final conclusion. assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}) - # return the full conversation + # Assemble the full conversation. messages = [ {"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts} @@ -204,34 +232,53 @@ Then count the occurrences of '{letter}': def evaluate(self, conversation, assistant_response): """ - Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct) - Identical to gsm8k's evaluation. + Evaluates the assistant's response to determine if it's correct. + This is similar to the evaluation in the gsm8k task. + Args: + conversation (dict): The original conversation. + assistant_response (str): The generated response from the assistant. + Returns: + int: 1 if the answer is correct, 0 otherwise. """ - assert isinstance(assistant_response, str), "Assuming simple string response for now" - # First extract the ground truth answer from the conversation + assert isinstance(assistant_response, str), "Assuming a simple string response for now" + # Extract the ground truth answer from the original conversation. assistant_message = conversation['messages'][-1] - assert assistant_message['role'] == "assistant", "Last message must be from the Assistant" - assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts" - # The last text part contains the final answer with #### + assert assistant_message['role'] == "assistant", "The last message should be from the assistant" + assert isinstance(assistant_message['content'], list), "Content is expected to be a list of parts" last_text_part = assistant_message['content'][-1]['text'] - # Extract both the ground truth answer and the predicted answer + + # Extract the reference number and the predicted number. ref_num = extract_answer(last_text_part) pred_num = extract_answer(assistant_response) - # Compare and return the success as int + + # Compare and return the result. is_correct = int(pred_num == ref_num) return is_correct def reward(self, conversation, assistant_response): - """ Use simple 0-1 reward just like gsm8k.""" + """ + Provides a simple binary reward (0 or 1) based on the evaluation result. + This is used during reinforcement learning. 
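+
+        Example (illustrative; the response string is hypothetical):
+
+            reward(conversation, "I count 3 of them.\n\n#### 3")  # -> 1.0 if the reference answer is 3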
+ """ is_correct = self.evaluate(conversation, assistant_response) is_correct_float = float(is_correct) return is_correct_float class SimpleSpelling(Task): - """Much simpler task designed to get the model to just practice spelling words.""" + """ + A simpler task designed to train the model on basic spelling. + This helps smaller models learn the correspondence between tokens and characters. + """ def __init__(self, size=1000, split="train", **kwargs): + """ + Initializes the SimpleSpelling task. + Args: + size (int): The number of examples to generate. + split (str): The dataset split, "train" or "test". + **kwargs: Additional arguments for the parent Task class. + """ super().__init__(**kwargs) assert split in ["train", "test"], "SpellingBee split must be train|test" self.size = size @@ -241,23 +288,31 @@ class SimpleSpelling(Task): with open(word_list_path) as f: words = [line.strip() for line in f] rng = random.Random(42) - rng.shuffle(words) # use a different word order than the SpellingBee task + rng.shuffle(words) # Use a different word order than SpellingBee for variety. self.words = words @property def eval_type(self): + """ This task uses generative evaluation. """ return 'generative' def num_examples(self): + """ Returns the number of examples in this task. """ return self.size def get_example(self, index): - seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + """ + Generates a single example for the SimpleSpelling task. + Args: + index (int): An index for seeding the random number generator. + Returns: + dict: A conversation dictionary for the task. + """ + seed = index if self.split == "train" else -(index + 1) rng = random.Random(seed) - # pick a random word word = rng.choice(self.words) word_letters = ",".join(list(word)) - # return the full conversation + messages = [ {"role": "user", "content": f"Spell the word: {word}"}, {"role": "assistant", "content": f"{word}:{word_letters}"} @@ -269,37 +324,40 @@ class SimpleSpelling(Task): if __name__ == "__main__": + # This block allows for previewing the generated examples from the tasks. - # preview the SpellingBee task, first 10 examples + # Preview the SpellingBee task. + print("--- SpellingBee Task Preview ---") task = SpellingBee() for i in range(10): ex = task.get_example(i) print("=" * 100) - print(ex['messages'][0]['content']) + print(f"User: {ex['messages'][0]['content']}") print("-" * 100) - # Assistant content is now a list of parts + print("Assistant:") assistant_parts = ex['messages'][1]['content'] for part in assistant_parts: if part['type'] == 'text': print(part['text'], end='') elif part['type'] == 'python': - print(f"<<{part['text']}=", end='') + print(f"< ", end='') elif part['type'] == 'python_output': - print(f"{part['text']}>>", end='') - print() - print("-" * 100) + print(f"Out: {part['text']}>>", end='') + print("\n" + "-" * 100) - # # preview the SimpleSpelling task, first 10 examples + # # To preview the SimpleSpelling task, uncomment the following lines. + # print("\n\n--- SimpleSpelling Task Preview ---") # task = SimpleSpelling() # for i in range(10): # ex = task.get_example(i) # print("=" * 100) - # print(ex['messages'][0]['content']) + # print(f"User: {ex['messages'][0]['content']}") # print("-" * 100) - # print(ex['messages'][1]['content']) + # print(f"Assistant: {ex['messages'][1]['content']}") - # # also scrutinize the tokenization (last example only) + # # To scrutinize the tokenization of the last example, uncomment these lines. 
# from nanochat.tokenizer import get_tokenizer # tokenizer = get_tokenizer() # ids, mask = tokenizer.render_conversation(ex) + # print("\n--- Tokenization of Last Example ---") # print(tokenizer.visualize_tokenization(ids, mask, with_token_id=True))
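+
+    # # Illustrative sketch (not part of the original preview code): the tasks defined in
+    # # this package are typically combined for SFT via TaskMixture from tasks.common.
+    # # The sizes below are arbitrary example values; uncomment to try it.
+    # from tasks.common import TaskMixture
+    # mixture = TaskMixture([SpellingBee(size=1000, split="train"), SimpleSpelling(size=300, split="train")])
+    # print(len(mixture))  # total number of examples across the mixture
+    # print(mixture[0])    # a single conversation dict, served via Task.__getitem__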