Compare commits

...

3 Commits

Author SHA1 Message Date
Eyal Frishman
59ed9392ed Add pre-commit documentation to README and GitHub workflow
Sets up a pre-commit workflow to automate code linting and formatting.

This ensures code quality and consistency by running checks before code is committed.
2025-12-05 19:59:38 +02:00
Eyal Frishman
449494c8b6 Fix (automatically) all pre-commit errors 2025-12-05 19:59:35 +02:00
Eyal Frishman
6587063479 Add pre-commit hooks for code formatting (not yet executed)
Adds pre-commit configuration to automate code formatting and linting.

This includes:
- ruff-check for linting
- ruff-format for code formatting
- pre-commit-hooks for various checks (whitespace, large files, etc.)
- codespell for fixing common misspellings in text

Special configurations added to `pyproject.toml`:
- Line length: 120 (more flexible than black's 88)
- Quote style: preserve (keeps existing quotes)
Rename ruff hook (avoid using legacy alias)
2025-12-05 19:59:35 +02:00
44 changed files with 1757 additions and 1074 deletions

44
.github/workflows/pre-commit.yml vendored Normal file
View File

@ -0,0 +1,44 @@
name: Pre-commit
on:
push:
branches:
- main
- "release/**"
pull_request:
workflow_dispatch:
permissions:
contents: read
jobs:
run-pre-commit:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Cache uv & pre-commit
uses: actions/cache@v4
with:
path: |
.venv
~/.cache/uv
~/.cache/pre-commit
key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock', '.pre-commit-config.yaml') }}
restore-keys: |
${{ runner.os }}-uv-
- name: Install dev dependencies
run: uv sync --group dev
- name: Run pre-commit
run: uv run pre-commit run --all-files

2
.gitignore vendored
View File

@ -4,4 +4,4 @@ __pycache__/
rustbpe/target/
dev-ignore/
report.md
eval_bundle/
eval_bundle/

27
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,27 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
args: [--maxkb=128]
- id: fix-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: check-yaml
- id: end-of-file-fixer
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.8
hooks:
- id: ruff-check
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1 # Use the latest stable version
hooks:
- id: codespell
additional_dependencies:
- tomli

View File

@ -123,6 +123,22 @@ I haven't invested too much here but some tests exist, especially for the tokeni
python -m pytest tests/test_rustbpe.py -v -s
```
## Pre-commit hooks
Linting and formatting are enforced with [pre-commit](https://pre-commit.com/) both locally and in CI via GitHub Actions. To match the checks that run in PRs:
- Make sure the dev extras are installed (`uv sync --group dev`).
- Run the suite on demand: `uv run pre-commit run --all-files`.
- (optional) Install the git hook once (for automation during `git commit`): `uv run pre-commit install`.
Hook coverage (auto-fixes most issues; review and stage the changes afterward):
- [`ruff`](https://github.com/astral-sh/ruff): a fast Rust-based linter and formatter that replaces multiple tools:
- **Linting** (`ruff-check`): removes unused imports (like autoflake), upgrades syntax (like pyupgrade), and sorts imports (like isort).
- **Formatting** (`ruff-format`): applies consistent code formatting (like black), with quote style preserved.
- [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks): repo hygiene (trim trailing whitespace, enforce LF endings/newlines, detect merge conflicts, block oversized files).
- [`codespell`](https://github.com/codespell-project/codespell): catches common spelling mistakes in code and docs (add false positives to `[tool.codespell].ignore-words-list` in `pyproject.toml`).
## File structure
```

View File

@ -28,24 +28,23 @@ NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the
(obviously you can tune this arbitrarily to your liking)
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import copy
import json
import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
api_key = open("openroutertoken.txt", encoding="utf-8").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
readme = open("README.md", "r", encoding="utf-8").read().strip()
readme = open("README.md", encoding="utf-8").read().strip()
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
@ -276,48 +275,46 @@ prompt = prompt.replace("%README%", readme)
# Define the JSON schema for structured output
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'",
},
"content": {"type": "string", "description": "The message content"},
},
"required": ["role", "content"],
"additionalProperties": False,
},
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
"required": ["messages"],
"additionalProperties": False,
},
},
}
# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}
def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
@ -325,7 +322,7 @@ def generate_conversation(idx: int):
"""
# pick 5 example user first messages and insert them into prompt as inspiration
rng = random.Random(idx) # use idx as seed to the rng
rng = random.Random(idx) # use idx as seed to the rng
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
payload = copy.deepcopy(base_payload)
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
@ -357,7 +354,6 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers.
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
@ -369,7 +365,9 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Lightly validate the conversation structure
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert message['role'] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
# If all looks good, write the messages to file
with open(output_file, 'a') as f:
@ -384,4 +382,3 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")

View File

@ -26,4 +26,4 @@
svg.innerHTML += `<path d="M200,-12 L212,0 L200,12 L188,0 Z" transform="translate(0,200)" fill="#000"/>`;
</script>
</body>
</html>
</html>

View File

@ -13,24 +13,25 @@ training latency.
NOTE: This file is meant only as reference/documentation of the
dataset preparation and it is not used during the project runtime.
"""
import os
import time
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset
# Source dataset
dataset_kwargs = {
"path": "HuggingFaceFW/fineweb-edu",
"split": "train",
"name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
"name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
}
ds = load_dataset(**dataset_kwargs)
# Shuffle to scramble the order
ds = ds.shuffle(seed=42)
ndocs = len(ds) # total number of documents to process
ndocs = len(ds) # total number of documents to process
print(f"Total number of documents: {ndocs}")
# Repackage into parquet files
@ -39,7 +40,7 @@ os.makedirs(output_dir, exist_ok=True)
# Write to parquet files
chars_per_shard = 250_000_000
row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later
row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later
shard_docs = []
shard_index = 0
shard_characters = 0
@ -52,20 +53,20 @@ for doc in ds:
shard_characters += len(text)
collected_enough_chars = shard_characters >= chars_per_shard
docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
shard_table = pa.Table.from_pydict({"text": shard_docs})
pq.write_table(
shard_table,
shard_path,
row_group_size=row_group_size,
use_dictionary=False, # this is usually used for categorical data
compression="zstd", # Valid values: {NONE, SNAPPY, GZIP, BROTLI, LZ4, ZSTD}
use_dictionary=False, # this is usually used for categorical data
compression="zstd", # Valid values: {NONE, SNAPPY, GZIP, BROTLI, LZ4, ZSTD}
compression_level=3,
write_statistics=False, # not needed for text
write_statistics=False, # not needed for text
)
t1 = time.time()
dt = t1 - t0 # for this shard alone
dt = t1 - t0 # for this shard alone
t0 = t1
total_docs_processed += len(shard_docs)
total_time_spent += dt
@ -73,15 +74,20 @@ for doc in ds:
avg_time_per_doc = total_time_spent / total_docs_processed
remaining_time = remaining_docs * avg_time_per_doc
remaining_time_hours = remaining_time / 3600
print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
print(
f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h"
)
shard_docs = []
shard_characters = 0
shard_index += 1
# Demonstration of how the data was later uploaded to HuggingFace
def upload():
import os
from huggingface_hub import HfApi
token = os.getenv("HF_TOKEN")
api = HfApi(token=token)
api.upload_large_folder(
@ -89,4 +95,6 @@ def upload():
repo_id="karpathy/fineweb-edu-100b-shuffle",
repo_type="dataset",
)
# upload()

View File

@ -2,6 +2,7 @@
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
"""
import torch
import torch.distributed as dist
from torch import Tensor
@ -12,7 +13,15 @@ class DistAdamW(torch.optim.Optimizer):
Distributed AdamW optimizer.
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
def __init__(
self,
param_groups,
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.01,
):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@ -30,7 +39,9 @@ class DistAdamW(torch.optim.Optimizer):
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
reduce_scatter_futures.append(
dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
)
grad_slices.append(grad_slice)
idx = 0
@ -43,7 +54,7 @@ class DistAdamW(torch.optim.Optimizer):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
p_slice = p[rank * rank_size : (rank + 1) * rank_size]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
@ -64,8 +75,8 @@ class DistAdamW(torch.optim.Optimizer):
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
# bias corrections
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
bias1 = 1 - beta1**t
bias2 = 1 - beta2**t
# compute step
denom = exp_avg_sq.sqrt().add_(eps)
step_size = lr * (torch.sqrt(bias2) / bias1)

View File

@ -1,25 +1,29 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import os
import re
import torch
from nanochat.common import get_base_dir
from nanochat.common import get_base_dir, setup_default_logging
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging
# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
@ -38,6 +42,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@ -49,7 +54,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r", encoding="utf-8") as f:
with open(meta_path, encoding="utf-8") as f:
meta_data = json.load(f)
return model_data, optimizer_data, meta_data
@ -66,10 +71,7 @@ def build_model(checkpoint_dir, step, device, phase):
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type in {"cpu", "mps"}:
# Convert bfloat16 tensors to float for CPU inference
model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v
for k, v in model_data.items()
}
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
@ -79,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase):
model = GPT(model_config)
# Load the model state
model.to_empty(device=device)
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.load_state_dict(model_data, strict=True, assign=True)
# Put the model in the right training phase / mode
if phase == "eval":
@ -121,9 +123,11 @@ def find_last_step(checkpoint_dir):
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
return last_step
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
if model_tag is None:
# guess the model tag by defaulting to the largest model
@ -139,6 +143,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
model_dir = {
"base": "base_checkpoints",

View File

@ -2,26 +2,30 @@
Common utilities for nanochat.
"""
import logging
import os
import re
import logging
import urllib.request
import torch
import torch.distributed as dist
from filelock import FileLock
class ColoredFormatter(logging.Formatter):
"""Custom formatter that adds colors to log messages."""
# ANSI color codes
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
}
RESET = '\033[0m'
BOLD = '\033[1m'
def format(self, record):
# Add color to the level name
levelname = record.levelname
@ -36,17 +40,17 @@ class ColoredFormatter(logging.Formatter):
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
return message
def setup_default_logging():
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logging.basicConfig(
level=logging.INFO,
handlers=[handler]
)
logging.basicConfig(level=logging.INFO, handlers=[handler])
setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir():
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"):
@ -58,6 +62,7 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir
def download_file_with_lock(url, filename, postprocess_fn=None):
"""
Downloads a file from a URL to a local path in the base directory.
@ -81,7 +86,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
# Download the content as bytes
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read() # bytes
content = response.read() # bytes
# Write to local file
with open(file_path, 'wb') as f:
@ -94,11 +99,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
return file_path
def print0(s="",**kwargs):
def print0(s="", **kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
def print_banner():
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
banner = """
@ -113,10 +120,12 @@ def print_banner():
"""
print0(banner)
def is_ddp():
# TODO is there a proper way
return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
if is_ddp():
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
@ -127,6 +136,7 @@ def get_dist_info():
else:
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
@ -138,14 +148,19 @@ def autodetect_device_type():
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
assert torch.cuda.is_available(), (
"Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
)
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
assert torch.backends.mps.is_available(), (
"Your PyTorch installation is not configured for MPS but device_type is 'mps'"
)
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
@ -158,7 +173,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -168,23 +183,28 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device(device_type) # mps|cpu
device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
"""Companion function to compute_init, to clean things up before script exit"""
if is_ddp():
dist.destroy_process_group()
class DummyWandb:
"""Useful if we wish to not use wandb but have all the same signatures"""
def __init__(self):
pass
def log(self, *args, **kwargs):
pass
def finish(self):
pass

View File

@ -18,11 +18,13 @@ import os
import sys
from ast import literal_eval
def print0(s="",**kwargs):
def print0(s="", **kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
for arg in sys.argv[1:]:
if '=' not in arg:
# assume it's the name of a config file

View File

@ -5,15 +5,17 @@ https://arxiv.org/abs/2406.11794
TODOs:
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
"""
import random
from jinja2 import Template
import torch
import torch.distributed as dist
from jinja2 import Template
# -----------------------------------------------------------------------------
# Prompt rendering utilities
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a multiple choice question"""
template_str = """
@ -24,11 +26,7 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
return prompts
@ -43,13 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(context=context_option, **context)
for context_option in item['context_options']]
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
prompts = [template.render(context=context_option, **context) for context_option in item['context_options']]
return prompts
@ -67,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
# Return two prompts: without and with the continuation
prompt_without = template.render(include_continuation=False, **context)
prompt_with = template.render(include_continuation=True, **context)
@ -89,10 +78,7 @@ def find_common_length(token_sequences, direction='left'):
- direction: 'left' for prefix, 'right' for suffix
"""
min_len = min(len(seq) for seq in token_sequences)
indices = {
'left': range(min_len),
'right': range(-1, -min_len-1, -1)
}[direction]
indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction]
# Find the first position where the token sequences differ
for i, idx in enumerate(indices):
token = token_sequences[0][idx]
@ -106,7 +92,7 @@ def stack_sequences(tokens, pad_token_id):
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
for i, x in enumerate(tokens):
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long)
return input_ids
@ -153,9 +139,7 @@ def forward_model(model, input_ids):
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction='none'
outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float('nan')
@ -201,19 +185,19 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
for t, s, e in zip(tokens, start_idxs, end_idxs):
if len(t) > max_tokens:
num_to_crop = len(t) - max_tokens
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
new_end_idxs.append(e - num_to_crop)
assert s - num_to_crop >= 0, "this should never happen right?"
assert e - num_to_crop >= 0, "this should never happen right?"
else:
new_tokens.append(t) # keep unchanged
new_tokens.append(t) # keep unchanged
new_start_idxs.append(s)
new_end_idxs.append(e)
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
# Stack up all the sequences into a batch
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
input_ids = stack_sequences(tokens, pad_token_id)
input_ids = input_ids.to(device)
@ -226,13 +210,12 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
si = start_idxs[0]
ei = end_idxs[0]
# predictions[i] predict input_ids[i+1] autoregressively
predicted_tokens = predictions[0, si-1:ei-1]
predicted_tokens = predictions[0, si - 1 : ei - 1]
actual_tokens = input_ids[0, si:ei]
is_correct = torch.all(predicted_tokens == actual_tokens).item()
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [losses[i, si-1:ei-1].mean().item()
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
pred_idx = mean_losses.index(min(mean_losses))
is_correct = pred_idx == item['gold']
else:

View File

@ -1,13 +1,16 @@
from collections import deque
import torch
import pyarrow.parquet as pq
import torch
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
def tokenizing_distributed_data_loader_with_state(
B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
@ -24,42 +27,44 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token
tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
token_buffer = deque() # we stream tokens on the right and pop from the left
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
@ -71,16 +76,20 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
state_dict = {
"pq_idx": pq_idx,
"rg_idx": rg_idx,
} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):

View File

@ -7,13 +7,14 @@ This file contains utilities for:
For details of how the dataset was prepared, see `repackage_data_reference.py`.
"""
import os
import argparse
import os
import time
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool
import pyarrow.parquet as pq
import requests
from nanochat.common import get_base_dir
# -----------------------------------------------------------------------------
@ -21,8 +22,8 @@ from nanochat.common import get_base_dir
# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
@ -30,16 +31,15 @@ os.makedirs(DATA_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported
def list_parquet_files(data_dir=None):
""" Looks into a data dir and returns full paths to all parquet files. """
"""Looks into a data dir and returns full paths to all parquet files."""
data_dir = DATA_DIR if data_dir is None else data_dir
parquet_files = sorted([
f for f in os.listdir(data_dir)
if f.endswith('.parquet') and not f.endswith('.tmp')
])
parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')])
parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
return parquet_paths
def parquets_iter_batched(split, start=0, step=1):
"""
Iterate through the dataset, in batches of underlying row_groups for efficiency.
@ -56,9 +56,10 @@ def parquets_iter_batched(split, start=0, step=1):
texts = rg.column('text').to_pylist()
yield texts
# -----------------------------------------------------------------------------
def download_single_file(index):
""" Downloads a single file index, with some backoff """
"""Downloads a single file index, with some backoff"""
# Construct the local filepath for this file and skip if it already exists
filename = index_to_filename(index)
@ -78,7 +79,7 @@ def download_single_file(index):
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
temp_path = filepath + ".tmp"
with open(temp_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
if chunk:
@ -88,10 +89,10 @@ def download_single_file(index):
print(f"Successfully downloaded {filename}")
return True
except (requests.RequestException, IOError) as e:
except (OSError, requests.RequestException) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files
for path in [filepath + f".tmp", filepath]:
for path in [filepath + ".tmp", filepath]:
if os.path.exists(path):
try:
os.remove(path)
@ -99,7 +100,7 @@ def download_single_file(index):
pass
# Try a few times with exponential backoff: 2^attempt seconds
if attempt < max_attempts:
wait_time = 2 ** attempt
wait_time = 2**attempt
print(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
else:
@ -111,8 +112,12 @@ def download_single_file(index):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
parser.add_argument(
"-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable"
)
parser.add_argument(
"-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)"
)
args = parser.parse_args()
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)

View File

@ -11,15 +11,17 @@ Notes:
The whole thing is made as efficient as possible.
"""
import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from contextlib import contextmanager, nullcontext
import torch
import torch.nn.functional as F
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
from nanochat.common import autodetect_device_type, compute_init
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -33,17 +35,19 @@ def timeout(duration, formula):
yield
signal.alarm(0)
def eval_with_timeout(formula, max_time=3):
try:
with timeout(max_time, formula):
with warnings.catch_warnings():
warnings.simplefilter("ignore", SyntaxWarning)
return eval(formula, {"__builtins__": {}}, {})
except Exception as e:
except Exception:
signal.alarm(0)
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
return None
def use_calculator(expr):
"""
Evaluate a Python expression safely.
@ -65,9 +69,25 @@ def use_calculator(expr):
return None
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
dangerous_patterns = [
'__',
'import',
'exec',
'eval',
'compile',
'open',
'file',
'input',
'raw_input',
'globals',
'locals',
'vars',
'dir',
'getattr',
'setattr',
'delattr',
'hasattr',
]
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
@ -79,6 +99,7 @@ def use_calculator(expr):
# Evaluate with timeout
return eval_with_timeout(expr)
# -----------------------------------------------------------------------------
class KVCache:
"""
@ -90,7 +111,7 @@ class KVCache:
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
self.pos = 0 # current position in time in the cache
def reset(self):
self.pos = 0
@ -122,7 +143,7 @@ class KVCache:
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
self.kv_cache[:, :, :, :, : other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
@ -135,8 +156,8 @@ class KVCache:
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
@ -173,22 +194,24 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=rng)
# -----------------------------------------------------------------------------
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.completed = False # Whether this row has completed generation
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.completed = False # Whether this row has completed generation
class Engine:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer # needed for tool use
self.tokenizer = tokenizer # needed for tool use
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
@ -204,8 +227,8 @@ class Engine:
python_end = get_special("<|python_end|>")
output_start = get_special("<|output_start|>")
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
@ -229,7 +252,7 @@ class Engine:
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
@ -259,12 +282,12 @@ class Engine:
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
token_column = [] # contains the next token id along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
for i, state in enumerate(row_states):
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update the state of this row to include the next token
@ -327,10 +350,13 @@ if __name__ == "__main__":
is equivalent to the faster Engine.generate function here.
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval")
@ -357,12 +383,12 @@ if __name__ == "__main__":
# generate tokens with Engine
generated_tokens = []
engine = Engine(model, tokenizer)
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
t0 = time.time()
with autocast_ctx:
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
token = token_column[0] # only print out the first row
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)

View File

@ -30,17 +30,18 @@ import platform
import signal
import tempfile
from dataclasses import dataclass
from typing import Optional
# -----------------------------------------------------------------------------
@dataclass
class ExecutionResult:
"""Result of executing Python code in a sandbox."""
success: bool
stdout: str
stderr: str
error: Optional[str] = None
error: str | None = None
timeout: bool = False
memory_exceeded: bool = False
@ -101,13 +102,13 @@ class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""
def read(self, *args, **kwargs):
raise IOError
raise OSError
def readline(self, *args, **kwargs):
raise IOError
raise OSError
def readlines(self, *args, **kwargs):
raise IOError
raise OSError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
@ -131,7 +132,7 @@ def chdir(root):
os.chdir(cwd)
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
def reliability_guard(maximum_memory_bytes: int | None = None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
@ -147,6 +148,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
if platform.uname().system != "Darwin":
# These resource limit calls seem to fail on macOS (Darwin), skip?
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
@ -211,10 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
sys.modules["tkinter"] = None
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict):
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
@ -228,14 +229,16 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
# Default to failure
result_dict.update({
"success": False,
"stdout": "",
"stderr": "",
"timeout": False,
"memory_exceeded": False,
"error": None,
})
result_dict.update(
{
"success": False,
"stdout": "",
"stderr": "",
"timeout": False,
"memory_exceeded": False,
"error": None,
}
)
try:
exec_globals = {}
@ -253,28 +256,36 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
# uncomment the following line and proceed at your own risk:
exec(code, exec_globals)
result_dict.update({
"success": True,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
})
result_dict.update(
{
"success": True,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
}
)
except TimeoutException:
result_dict.update({
"timeout": True,
"error": "Execution timed out",
})
result_dict.update(
{
"timeout": True,
"error": "Execution timed out",
}
)
except MemoryError as e:
result_dict.update({
"memory_exceeded": True,
"error": f"Memory limit exceeded: {e}",
})
result_dict.update(
{
"memory_exceeded": True,
"error": f"Memory limit exceeded: {e}",
}
)
except BaseException as e:
result_dict.update({
"error": f"{type(e).__name__}: {e}",
})
result_dict.update(
{
"error": f"{type(e).__name__}: {e}",
}
)
# Needed for cleaning up.
shutil.rmtree = rmtree
@ -285,8 +296,8 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
def execute_code(
code: str,
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
"""
Execute Python code in a sandboxed environment.
@ -310,10 +321,7 @@ def execute_code(
manager = multiprocessing.Manager()
result_dict = manager.dict()
p = multiprocessing.Process(
target=_unsafe_execute,
args=(code, timeout, maximum_memory_bytes, result_dict)
)
p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict))
p.start()
p.join(timeout=timeout + 1)
@ -346,4 +354,3 @@ def execute_code(
timeout=result_dict["timeout"],
memory_exceeded=result_dict["memory_exceeded"],
)

View File

@ -12,24 +12,25 @@ Notable features:
"""
import math
from functools import partial
from dataclasses import dataclass
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.common import get_dist_info
from nanochat.muon import DistMuon, Muon
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
@ -41,13 +42,14 @@ def norm(x):
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
@ -73,18 +75,24 @@ class CausalSelfAttention(nn.Module):
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = (
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
enable_gqa = (
self.n_head != self.n_kv_head
) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
@ -96,9 +104,9 @@ class CausalSelfAttention(nn.Module):
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
@ -139,19 +147,21 @@ class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.transformer = nn.ModuleDict(
{
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
@ -195,18 +205,23 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def get_device(self):
return self.transformer.wte.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
"""Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
l, h, q, t = (
self.config.n_layer,
self.config.n_head,
self.config.n_embd // self.config.n_head,
self.config.sequence_len,
)
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
@ -245,12 +260,16 @@ class GPT(nn.Module):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert T <= self.cos.size(1), (
f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
)
assert idx.device == self.cos.device, (
f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
)
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
cos_sin = self.cos[:, T0 : T0 + T], self.sin[:, T0 : T0 + T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
@ -265,14 +284,16 @@ class GPT(nn.Module):
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction
)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits
@torch.inference_mode()
@ -289,10 +310,10 @@ class GPT(nn.Module):
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')

View File

@ -1,10 +1,13 @@
"""
A number of functions that help with evaluating a base model.
"""
import math
import torch
import torch.distributed as dist
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
"""
@ -30,20 +33,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
y_safe = torch.where(valid, y, torch.zeros_like(y))
# map valid targets to their byte length; ignored targets contribute 0 bytes
num_bytes2d = torch.where(
valid,
token_bytes[y_safe],
torch.zeros_like(y, dtype=token_bytes.dtype)
)
num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype))
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
else:

View File

@ -2,9 +2,11 @@
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
"""
import torch
from torch import Tensor
import torch.distributed as dist
from torch import Tensor
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
@ -17,8 +19,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
@ -28,13 +32,16 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
@ -57,6 +64,7 @@ class Muon(torch.optim.Optimizer):
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
@ -80,7 +88,7 @@ class Muon(torch.optim.Optimizer):
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5)
class DistMuon(torch.optim.Optimizer):
@ -104,14 +112,14 @@ class DistMuon(torch.optim.Optimizer):
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
@ -129,7 +137,9 @@ class DistMuon(torch.optim.Optimizer):
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), (
"All params must have grads"
)
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
@ -141,7 +151,7 @@ class DistMuon(torch.optim.Optimizer):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank
# each rank stacks up its chunk of world_size params into a list
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
rs_input = [p.grad for p in params[base_i : base_i + world_size]]
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
@ -159,9 +169,9 @@ class DistMuon(torch.optim.Optimizer):
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank # calculate the index of the param that this rank owns
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
@ -174,12 +184,12 @@ class DistMuon(torch.optim.Optimizer):
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
scale = max(1.0, p.size(-2) / p.size(-1)) ** 0.5
p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
ag_output = params[base_i : base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)

View File

@ -2,16 +2,18 @@
Utilities for generating training report cards. More messy code than usual, will fix.
"""
import datetime
import os
import platform
import re
import shutil
import subprocess
import socket
import datetime
import platform
import subprocess
import psutil
import torch
def run_command(cmd):
"""Run a shell command and return output, or None if it fails."""
try:
@ -22,6 +24,7 @@ def run_command(cmd):
except:
return None
def get_git_info():
"""Get current git commit, branch, and dirty status."""
info = {}
@ -38,18 +41,14 @@ def get_git_info():
return info
def get_gpu_info():
"""Get GPU information."""
if not torch.cuda.is_available():
return {"available": False}
num_devices = torch.cuda.device_count()
info = {
"available": True,
"count": num_devices,
"names": [],
"memory_gb": []
}
info = {"available": True, "count": num_devices, "names": [], "memory_gb": []}
for i in range(num_devices):
props = torch.cuda.get_device_properties(i)
@ -61,6 +60,7 @@ def get_gpu_info():
return info
def get_system_info():
"""Get system information."""
info = {}
@ -83,6 +83,7 @@ def get_system_info():
return info
def estimate_cost(gpu_info, runtime_hours=None):
"""Estimate training cost based on GPU type and runtime."""
@ -111,9 +112,10 @@ def estimate_cost(gpu_info, runtime_hours=None):
return {
"hourly_rate": hourly_rate,
"gpu_type": gpu_name,
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None,
}
def generate_header():
"""Generate the header for a training report."""
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -165,12 +167,12 @@ Generated: {timestamp}
num_chars = len(packaged)
num_lines = len(packaged.split('\n'))
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
num_tokens = num_chars // 4 # assume approximately 4 chars per token
num_tokens = num_chars // 4 # assume approximately 4 chars per token
# count dependencies via uv.lock
uv_lock_lines = 0
if os.path.exists('uv.lock'):
with open('uv.lock', 'r', encoding='utf-8') as f:
with open('uv.lock', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines())
header += f"""
@ -184,12 +186,15 @@ Generated: {timestamp}
"""
return header
# -----------------------------------------------------------------------------
def slugify(text):
"""Slugify a text string."""
return text.lower().replace(" ", "-")
# the expected files and their order
EXPECTED_FILES = [
"tokenizer-training.md",
@ -207,10 +212,11 @@ EXPECTED_FILES = [
# the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
def extract(section, keys):
"""simple def to extract a single key from a section"""
if not isinstance(keys, list):
keys = [keys] # convenience
keys = [keys] # convenience
out = {}
for line in section.split("\n"):
for key in keys:
@ -218,6 +224,7 @@ def extract(section, keys):
out[key] = line.split(":")[1].strip()
return out
def extract_timestamp(content, prefix):
"""Extract timestamp from content with given prefix."""
for line in content.split('\n'):
@ -229,6 +236,7 @@ def extract_timestamp(content, prefix):
pass
return None
class Report:
"""Maintains a bunch of logs, generates a final markdown report."""
@ -269,14 +277,14 @@ class Report:
report_dir = self.report_dir
report_file = os.path.join(report_dir, "report.md")
print(f"Generating report to {report_file}")
final_metrics = {} # the most important final metrics we'll add as table at the end
final_metrics = {} # the most important final metrics we'll add as table at the end
start_time = None
end_time = None
with open(report_file, "w", encoding="utf-8") as out_file:
# write the header first
header_file = os.path.join(report_dir, "header.md")
if os.path.exists(header_file):
with open(header_file, "r", encoding="utf-8") as f:
with open(header_file, encoding="utf-8") as f:
header_content = f.read()
out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:")
@ -284,7 +292,7 @@ class Report:
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = None # will cause us to not write the total wall clock time
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
@ -293,7 +301,7 @@ class Report:
if not os.path.exists(section_file):
print(f"Warning: {section_file} does not exist, skipping")
continue
with open(section_file, "r", encoding="utf-8") as in_file:
with open(section_file, encoding="utf-8") as in_file:
section = in_file.read()
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
if "rl" not in file_name:
@ -307,7 +315,7 @@ class Report:
if file_name == "chat-evaluation-sft.md":
final_metrics["sft"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-rl.md":
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
# append this section of the report
out_file.write(section)
out_file.write("\n")
@ -354,7 +362,7 @@ class Report:
else:
out_file.write("Total wall clock time: unknown\n")
# also cp the report.md file to current directory
print(f"Copying report.md to current directory for convenience")
print("Copying report.md to current directory for convenience")
shutil.copy(report_file, "report.md")
return report_file
@ -378,18 +386,23 @@ class Report:
f.write(f"Run started: {start_time}\n\n---\n\n")
print(f"Reset report and wrote header to {header_file}")
# -----------------------------------------------------------------------------
# nanochat-specific convenience functions
class DummyReport:
def log(self, *args, **kwargs):
pass
def reset(self, *args, **kwargs):
pass
def get_report():
# just for convenience, only rank 0 logs to report
from nanochat.common import get_base_dir, get_dist_info
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp_rank == 0:
report_dir = os.path.join(get_base_dir(), "report")
@ -397,10 +410,18 @@ def get_report():
else:
return DummyReport()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
parser.add_argument(
"command",
nargs="?",
default="generate",
choices=["generate", "reset"],
help="Operation to perform (default: generate)",
)
args = parser.parse_args()
if args.command == "generate":
get_report().generate()

View File

@ -6,36 +6,39 @@ Two implementations are available:
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""
import os
import copy
import os
from functools import lru_cache
SPECIAL_TOKENS = [
# every document begins with the Beginning of Sequence (BOS) token that delimits documents
"<|bos|>",
# tokens below are only used during finetuning to render Conversations into token ids
"<|user_start|>", # user messages
"<|user_start|>", # user messages
"<|user_end|>",
"<|assistant_start|>", # assistant messages
"<|assistant_start|>", # assistant messages
"<|assistant_end|>",
"<|python_start|>", # assistant invokes python REPL tool
"<|python_start|>", # assistant invokes python REPL tool
"<|python_end|>",
"<|output_start|>", # python REPL outputs back to assistant
"<|output_start|>", # python REPL outputs back to assistant
"<|output_end|>",
]
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
# -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
@ -59,11 +62,13 @@ class HuggingFaceTokenizer:
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
# Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
tokenizer = HFTokenizer(
BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
)
)
# Normalizer: None
tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style
@ -71,11 +76,13 @@ class HuggingFaceTokenizer:
# NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
# (but I haven't validated this! TODO)
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
]
)
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None
@ -84,7 +91,7 @@ class HuggingFaceTokenizer:
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=SPECIAL_TOKENS,
)
@ -146,12 +153,16 @@ class HuggingFaceTokenizer:
self.tokenizer.save(tokenizer_path)
print(f"Saved tokenizer to {tokenizer_path}")
# -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo
import pickle
import rustbpe
import tiktoken
import rustbpe
class RustBPETokenizer:
"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
@ -176,8 +187,8 @@ class RustBPETokenizer:
enc = tiktoken.Encoding(
name="rustbpe",
pat_str=pattern,
mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
)
return cls(enc, "<|bos|>")
@ -225,14 +236,14 @@ class RustBPETokenizer:
if isinstance(text, str):
ids = self.enc.encode_ordinary(text)
if prepend is not None:
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
if append is not None:
ids.append(append_id)
elif isinstance(text, list):
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
if prepend is not None:
for ids_row in ids:
ids_row.insert(0, prepend_id) # TODO: same
ids_row.insert(0, prepend_id) # TODO: same
if append is not None:
for ids_row in ids:
ids_row.append(append_id)
@ -264,6 +275,7 @@ class RustBPETokenizer:
"""
# ids, masks that we will return and a helper function to help build them up.
ids, mask = [], []
def add_tokens(token_ids, mask_val):
if isinstance(token_ids, int):
token_ids = [token_ids]
@ -274,7 +286,7 @@ class RustBPETokenizer:
# => just merge it with the second (user) message
if conversation["messages"][0]["role"] == "system":
# some conversation surgery is necessary here for now...
conversation = copy.deepcopy(conversation) # avoid mutating the original
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert messages[1]["role"] == "user", "System message must be followed by a user message"
messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
@ -286,17 +298,21 @@ class RustBPETokenizer:
# fetch all the special tokens we need
bos = self.get_bos_token_id()
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
assistant_start, assistant_end = (
self.encode_special("<|assistant_start|>"),
self.encode_special("<|assistant_end|>"),
)
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
# now we can tokenize the conversation
add_tokens(bos, 0)
for i, message in enumerate(messages):
# some sanity checking here around assumptions, to prevent footguns
must_be_from = "user" if i % 2 == 0 else "assistant"
assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
assert message["role"] == must_be_from, (
f"Message {i} is from {message['role']} but should be from {must_be_from}"
)
# content can be either a simple string or a list of parts (e.g. containing tool calls)
content = message["content"]
@ -363,10 +379,10 @@ class RustBPETokenizer:
Unlike the Chat SFT case, we don't need to return the mask.
"""
# We have some surgery to do: we need to pop the last message (of the Assistant)
conversation = copy.deepcopy(conversation) # avoid mutating the original
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
messages.pop() # remove the last message (of the Assistant) inplace
messages.pop() # remove the last message (of the Assistant) inplace
# Now tokenize the conversation
ids, mask = self.render_conversation(conversation)
@ -376,23 +392,31 @@ class RustBPETokenizer:
ids.append(assistant_start)
return ids
# -----------------------------------------------------------------------------
# nanochat-specific convenience functions
def get_tokenizer():
from nanochat.common import get_base_dir
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
# return HuggingFaceTokenizer.from_directory(tokenizer_dir)
return RustBPETokenizer.from_directory(tokenizer_dir)
def get_token_bytes(device="cpu"):
import torch
from nanochat.common import get_base_dir
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
assert os.path.exists(token_bytes_path), (
f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
)
with open(token_bytes_path, "rb") as f:
token_bytes = torch.load(f, map_location=device)
return token_bytes

View File

@ -32,6 +32,7 @@ manifest-path = "rustbpe/Cargo.toml"
dev = [
"maturin>=1.9.4",
"pytest>=8.0.0",
"pre-commit>=3.8.0",
]
[tool.pytest.ini_options]
@ -45,33 +46,58 @@ python_functions = ["test_*"]
# target torch to cuda 12.8 or CPU
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "gpu" },
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "gpu" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[project.optional-dependencies]
cpu = [
"torch>=2.8.0",
]
gpu = [
"torch>=2.8.0",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[project.optional-dependencies]
cpu = [
"torch>=2.8.0",
]
gpu = [
"torch>=2.8.0",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]
[tool.ruff]
target-version = "py310"
line-length = 120
fix = true
unsafe-fixes = true
[tool.ruff.lint]
select = [
"F", # Pyflakes (unused imports) - replaces autoflake
"I", # isort - replaces isort
"UP", # pyupgrade - replaces pyupgrade
]
[tool.ruff.lint.isort]
known-first-party = ["nanochat"]
[tool.ruff.format]
quote-style = "preserve"
[tool.codespell]
write-changes = true
interactive = 1
skip = "tests/*,dev/*,scripts/tok_eval.py,tasks/spellingbee.py"
ignore-words-list = "re-use,astroid"

View File

@ -9,23 +9,31 @@ torchrun --nproc_per_node=8 -m scripts.base_eval
The script will print the CORE metric to the console.
"""
import os
import csv
import time
import json
import yaml
import shutil
import os
import random
import zipfile
import shutil
import tempfile
import time
import zipfile
from contextlib import nullcontext
import torch
import yaml
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.common import (
autodetect_device_type,
compute_cleanup,
compute_init,
download_file_with_lock,
get_base_dir,
print0,
)
from nanochat.core_eval import evaluate_task
from nanochat.tokenizer import HuggingFaceTokenizer
# -----------------------------------------------------------------------------
# nanochat specific function dealing with I/O etc.
@ -33,6 +41,7 @@ from nanochat.core_eval import evaluate_task
# ~162MB of data needed to evaluate the CORE metric
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
def place_eval_bundle(file_path):
# here file_path is the path to the eval_bundle.zip file
# we need to unzip it and place it in the base directory
@ -45,6 +54,7 @@ def place_eval_bundle(file_path):
shutil.move(extracted_bundle_dir, eval_bundle_dir)
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
def evaluate_model(model, tokenizer, device, max_per_task=-1):
"""
Evaluate a base model on the CORE benchmark.
@ -59,13 +69,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, encoding='utf-8') as f:
config = yaml.safe_load(f)
tasks = config['icl_tasks']
# Load random baseline values from eval metadata
random_baselines = {}
with open(eval_meta_data, 'r', encoding='utf-8') as f:
with open(eval_meta_data, encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
task_name = row['Eval Task']
@ -82,13 +92,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
'task_type': task['icl_task_type'],
'dataset_uri': task['dataset_uri'],
'num_fewshot': task['num_fewshot'][0],
'continuation_delimiter': task.get('continuation_delimiter', ' ')
'continuation_delimiter': task.get('continuation_delimiter', ' '),
}
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
# Load data for this task
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, 'r', encoding='utf-8') as f:
with open(data_path, encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
# shuffle the data because in many cases it appears ordered but we want
@ -109,18 +119,17 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
core_metric = sum(centered_results.values()) / len(centered_results)
out = {
"results": results,
"centered_results": centered_results,
"core_metric": core_metric
}
out = {"results": results, "centered_results": centered_results, "core_metric": core_metric}
return out
# -----------------------------------------------------------------------------
# HuggingFace loading utilities and light wrappers for a model
class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
@ -130,10 +139,12 @@ class ModelWrapper:
logits = outputs.logits
return logits
def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}")
# Load the model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
@ -143,9 +154,11 @@ def load_hf_model(hf_path: str, device):
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
# -----------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
@ -154,7 +167,9 @@ def main():
# distributed / precision setup
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Load model and tokenizer from command line or from file system
if args.hf_path is not None:
@ -162,13 +177,13 @@ def main():
hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model_name = hf_path # just for logging
model_slug = hf_path.replace("/", "-") # for the output csv file
model_name = hf_path # just for logging
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval")
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
# Evaluate the model
with autocast_ctx:
@ -190,23 +205,28 @@ def main():
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
# Print the content of the csv file to console too
print0("="*80)
print0("=" * 80)
print0(f"Model: {model_name}")
print0("="*80)
with open(output_csv_path, 'r', encoding='utf-8') as f:
print0("=" * 80)
with open(output_csv_path, encoding='utf-8') as f:
print0(f.read())
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model evaluation", data=[
{
"Model": model_name,
"CORE metric": core_metric,
},
centered_results, # the full table
])
get_report().log(
section="Base model evaluation",
data=[
{
"Model": model_name,
"CORE metric": core_metric,
},
centered_results, # the full table
],
)
compute_cleanup()
if __name__ == "__main__":
main()

View File

@ -6,30 +6,35 @@ Loads a checkpoint, and:
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes
# Configuration
device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
split_tokens = 20 * 524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -67,13 +72,17 @@ if ddp_rank == 0:
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
{
"train bpb": bpb_results["train"],
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
])
get_report().log(
section="Base model loss",
data=[
{
"train bpb": bpb_results["train"],
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
],
)
# Cleanup
compute_cleanup()

View File

@ -12,67 +12,83 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext
import wandb
import torch
import wandb
from nanochat.gpt import GPT, GPTConfig
from nanochat.checkpoint_manager import load_checkpoint, save_checkpoint
from nanochat.common import (
DummyWandb,
autodetect_device_type,
compute_cleanup,
compute_init,
get_base_dir,
print0,
print_banner,
)
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes, get_tokenizer
from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = (
-1.0
) # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = (
20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20 * 524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
@ -88,9 +104,9 @@ print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
@ -98,8 +114,8 @@ print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
@ -110,7 +126,14 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(
sequence_len=max_seq_len,
vocab_size=vocab_size,
n_layer=num_layers,
n_head=num_heads,
n_kv_head=num_kv_heads,
n_embd=model_dim,
)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
@ -119,17 +142,19 @@ model.init_weights()
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
if resuming:
print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
model_data, optimizer_data, meta_data = load_checkpoint(
checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank
)
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
@ -152,30 +177,37 @@ else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
for opt, dat in zip(optimizers, optimizer_data):
opt.load_state_dict(dat)
del optimizer_data # free up the memory
del optimizer_data # free up the memory
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
train_loader = tokenizing_distributed_data_loader_with_state(
device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict
)
build_val_loader = lambda: tokenizing_distributed_data_loader(
device_batch_size, max_seq_len, split="val", device=device
)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
@ -188,20 +220,22 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
if not resuming:
step = 0
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
total_training_time = 0 # total wall-clock time of training
smooth_train_loss = 0 # EMA of training loss
total_training_time = 0 # total wall-clock time of training
else:
step = meta_data["step"]
loop_state = meta_data["loop_state"]
@ -212,7 +246,7 @@ else:
# -----------------------------------------------------------------------------
# Training loop
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
@ -225,12 +259,14 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
wandb_run.log(
{
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
}
)
model.train()
# once in a while: estimate the CORE metric (all ranks participate)
@ -241,12 +277,14 @@ while True:
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
})
wandb_run.log(
{
"step": step,
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
}
)
model.train()
# once in a while: sample from the model (only on master process)
@ -262,7 +300,7 @@ while True:
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
@ -275,17 +313,17 @@ while True:
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
{ # metadata saved as json
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
@ -306,15 +344,17 @@ while True:
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
x, y, dataloader_state_dict = next(
train_loader
) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
@ -332,18 +372,20 @@ while True:
# -------------------------------------------------------------------------
# logging
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
total_training_time += dt # only count the time after the first 10 steps
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 100 == 0:
log_data = {
"step": step,
@ -364,35 +406,39 @@ while True:
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Total training time: {total_training_time / 60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])
get_report().log(
section="Base model training",
data=[
user_config, # CLI args
{ # stats about the training setup
"Number of parameters": num_params,
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time / 60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
},
],
)
# cleanup
wandb_run.finish() # wandb run finish
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -4,12 +4,15 @@ New and upgraded chat mode because a lot of the code has changed since the last
Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
@ -18,7 +21,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument(
'--device-type',
type=str,
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
@ -33,7 +42,10 @@ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag
# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
assistant_start, assistant_end = (
tokenizer.encode_special("<|assistant_start|>"),
tokenizer.encode_special("<|assistant_end|>"),
)
# Create Engine for efficient generation
engine = Engine(model, tokenizer)
@ -47,7 +59,6 @@ print("-" * 50)
conversation_tokens = [bos]
while True:
if args.prompt:
# Get the prompt from the launch command
user_input = args.prompt
@ -89,7 +100,7 @@ while True:
print("\nAssistant: ", end="", flush=True)
with autocast_ctx:
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
token = token_column[0] # pop the batch dimension (num_samples=1)
token = token_column[0] # pop the batch dimension (num_samples=1)
response_tokens.append(token)
token_text = tokenizer.decode([token])
print(token_text, end="", flush=True)

View File

@ -9,27 +9,28 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""
import argparse
from functools import partial
from contextlib import nullcontext
from functools import partial
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0
from nanochat.engine import Engine
from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.spellingbee import SpellingBee
# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
def run_generative_eval(
task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None
):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device()
@ -62,7 +63,7 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
num_passed += int(passed)
# Logging (overwrite the same line in the console)
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100 * num_passed / total:.2f}%)", end='', flush=True)
# Finish the in-place progress line with a newline before final summary
print()
@ -77,21 +78,22 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
total = total_tensor.item()
print0("=" * 50)
print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
print0(f"Final: {num_passed}/{total} ({100 * num_passed / total:.2f}%)")
# Return the accuracy
return num_passed/total
return num_passed / total
# -----------------------------------------------------------------------------
# Categorical evaluation loop
# A lot easier because we don't have to sample. Therefore, we can actually go
# batches at a time and just check the logits for correct answer choices.
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device()
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
# We'll process batches of independent problems at a time because there is no sampling needed
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
@ -99,22 +101,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
num_batches = ceil_div(num_problems, batch_size)
# Run the evaluation
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
num_passed, total = 0, 0
for i in range(ddp_rank, num_batches, ddp_world_size):
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
conversations = [task_object[ii] for ii in range(i0, i1)]
prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
prompt_ids = [
tokenizer.render_for_completion(conversation) for conversation in conversations
] # TODO: remake the way this works
max_length = max(len(ids) for ids in prompt_ids)
answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
answer_time_positions = [
len(ids) - 1 for ids in prompt_ids
] # where the last token is (and the predicted answer)
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
with torch.no_grad():
logits = model(prompt_ids) # (B, T, V)
logits = model(prompt_ids) # (B, T, V)
# Focus on the available answer on just the letters corresponding to choices
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
@ -150,15 +156,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
num_passed = num_passed_tensor.item()
total = total_tensor.item()
average = num_passed/total
print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
average = num_passed / total
print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)")
return average
# -----------------------------------------------------------------------------
def run_chat_eval(task_name, model, tokenizer, engine,
batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
max_problems=None):
def run_chat_eval(
task_name,
model,
tokenizer,
engine,
batch_size=1,
num_samples=1,
max_new_tokens=512,
temperature=0.0,
top_k=50,
max_problems=None,
):
# Create the evaluation object
task_module = {
'HumanEval': HumanEval,
@ -171,20 +188,36 @@ def run_chat_eval(task_name, model, tokenizer, engine,
task_object = task_module()
# Run the evaluation
if task_object.eval_type == 'generative':
acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
acc = run_generative_eval(
task_object,
tokenizer,
model,
engine,
num_samples,
max_new_tokens,
temperature,
top_k,
max_problems=max_problems,
)
elif task_object.eval_type == 'categorical':
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
else:
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
return acc
# -----------------------------------------------------------------------------
if __name__ == "__main__":
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
parser.add_argument(
'-a',
'--task-name',
type=str,
default=None,
help="Task name. Default = all tasks. Use | to split multiple tasks.",
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0)
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
@ -194,13 +227,21 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument(
'--device-type',
type=str,
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
args = parser.parse_args()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)
@ -208,12 +249,12 @@ if __name__ == "__main__":
# Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
baseline_accuracies = {
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
}
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
@ -223,7 +264,9 @@ if __name__ == "__main__":
with autocast_ctx:
acc = run_chat_eval(
task_name,
model, tokenizer, engine,
model,
tokenizer,
engine,
batch_size=args.batch_size,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
@ -236,6 +279,7 @@ if __name__ == "__main__":
# Log to report
from nanochat.report import get_report
all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
# calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
# this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
@ -248,10 +292,13 @@ if __name__ == "__main__":
centered_mean += centered_acc
chatcore_metric = centered_mean / len(results)
chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
get_report().log(section="Chat evaluation " + args.source, data=[
vars(args), # CLI args
results,
chatcore_metric_dict,
])
get_report().log(
section="Chat evaluation " + args.source,
data=[
vars(args), # CLI args
results,
chatcore_metric_dict,
],
)
compute_cleanup()

View File

@ -16,46 +16,46 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""
import os
import itertools
import re
import wandb
import os
import torch
import torch.distributed as dist
import wandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
run = "dummy" # wandb run name
source = "sft" # mid|sft
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
top_k = 50 # TODO: try None?
top_k = 50 # TODO: try None?
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
@ -65,7 +65,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval")
engine = Engine(model, tokenizer) # for sampling rollouts
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
# Rollout / sampling generator loop that yields batches of examples for training
@ -75,12 +75,16 @@ val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
def get_batch():
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
assistant_end = tokenizer.encode_special(
"<|assistant_end|>"
) # ok to use this token, it's only for padding and isn't used in the loss.
rank_indices = range(
ddp_rank, len(train_task), ddp_world_size
) # each rank is responsible for different examples in the training data
for example_idx in itertools.cycle(rank_indices):
# First get the full conversation of both user and assistant messages
conversation = train_task[example_idx]
@ -90,12 +94,12 @@ def get_batch():
prefix_length = len(tokens)
# Generate num_samples samples using batched generation, use loop to avoid OOMs
model.eval() # ensure the model is in eval mode
model.eval() # ensure the model is in eval mode
generated_token_sequences = []
masks = []
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
@ -103,7 +107,7 @@ def get_batch():
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=seed, # must make sure to change the seed for each sampling step
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
masks.extend(masks_batch)
@ -121,15 +125,17 @@ def get_batch():
# Pad the sequences so that their lengths (in time) match
max_length = max(len(seq) for seq in generated_token_sequences)
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
padded_generated_token_sequences = [
seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences
]
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
# Stack up the sequences and masks into PyTorch tensors
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
# Generate autoregressive inputs and targets to the Transformer
inputs = ids[:, :-1]
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
@ -139,14 +145,11 @@ def get_batch():
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
yield generated_token_sequences, inputs, targets, rewards, advantages
# -----------------------------------------------------------------------------
# Simple evaluation loop for GSM8K pass@k
def run_gsm8k_eval(task, tokenizer, engine,
max_examples=None,
num_samples=1,
max_completion_tokens=256,
temperature=0.0,
top_k=50
def run_gsm8k_eval(
task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50
):
"""
Evaluates GSM8K task and returns a list of records of evaluation outcomes.
@ -160,13 +163,9 @@ def run_gsm8k_eval(task, tokenizer, engine,
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
# Generate k samples using batched generation inside the Engine
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch(
tokens,
num_samples=num_samples,
max_tokens=max_completion_tokens,
temperature=temperature,
top_k=top_k
tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k
)
# Check each sample for correctness
outcomes = []
@ -174,9 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
generated_tokens = sample_tokens[prefix_length:]
generated_text = tokenizer.decode(generated_tokens)
is_correct = task.evaluate(conversation, generated_text)
outcomes.append({
"is_correct": is_correct
})
outcomes.append({"is_correct": is_correct})
# A bit bloated because I wanted to do more complex logging at one point.
record = {
"idx": idx,
@ -184,6 +181,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
}
yield record
# -----------------------------------------------------------------------------
# Training loop
@ -199,44 +197,49 @@ optimizers = model.setup_optimizers(
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it):
lrm = 1.0 - it / num_steps
return lrm
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
batch_iterator = get_batch()
for step in range(num_steps):
# Evaluate the model once in a while and log to wandb
if step % eval_every == 0:
model.eval()
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
records_iter = run_gsm8k_eval(
val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0
)
records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
if ddp:
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
passk = passk / num_records.item() # normalize by the total number of records
passk = passk / num_records.item() # normalize by the total number of records
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
wandb_run.log({
"step": step,
**log_passk,
})
wandb_run.log(
{
"step": step,
**log_passk,
}
)
# Forward/Backward on rollouts over multiple examples in the dataset
rewards_list = []
@ -245,7 +248,7 @@ for step in range(num_steps):
# Get one batch corresponding to one example in the training dataset
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
assert inputs_all.size(0) % device_batch_size == 0
num_passes = inputs_all.size(0) // device_batch_size
@ -258,7 +261,7 @@ for step in range(num_steps):
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx:
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank
@ -268,7 +271,9 @@ for step in range(num_steps):
# Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
loss = -pg_obj
loss.backward()
print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
print0(
f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}"
)
# For logging
rewards_list.append(rewards_all.mean().item())
sequence_lengths.extend(len(seq) for seq in sequences_all)
@ -276,56 +281,66 @@ for step in range(num_steps):
# A bunch of logging for how the rollouts went this step
mean_reward = sum(rewards_list) / len(rewards_list)
mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
if ddp: # aggregate across ranks
if ddp: # aggregate across ranks
mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
mean_reward = mean_reward_tensor.item()
mean_sequence_length = mean_sequence_length_tensor.item()
print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
wandb_run.log({
"step": step,
"reward": mean_reward,
"sequence_length": mean_sequence_length,
})
print0(
f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}"
)
wandb_run.log(
{
"step": step,
"reward": mean_reward,
"sequence_length": mean_sequence_length,
}
)
# Update the model parameters
lrm = get_lr_multiplier(step)
for opt in optimizers: # first set the learning rate
for opt in optimizers: # first set the learning rate
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
for opt in optimizers: # then step the optimizers
for opt in optimizers: # then step the optimizers
opt.step()
model.zero_grad(set_to_none=True)
wandb_run.log({
"step": step,
"lrm": lrm,
})
wandb_run.log(
{
"step": step,
"lrm": lrm,
}
)
# Master process saves the model once in a while. Skip first step. Save last step.
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None, # note: we don't bother to save the optimizer state
None, # note: we don't bother to save the optimizer state
{
"model_config": model_config_kwargs,
}
},
)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Chat RL", data=[
user_config, # CLI args
])
wandb_run.finish() # wandb run finish
get_report().log(
section="Chat RL",
data=[
user_config, # CLI args
],
)
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -10,40 +10,40 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint
import torch
import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.common import TaskMixture
from tasks.customjson import CustomJSON
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# SFT Hyperparameters
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# input model options
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
@ -56,9 +56,9 @@ eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Compute init
@ -70,52 +70,63 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
train_ds = TaskMixture(
[
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------
# DataLoader
def sft_data_generator(dataset, batch_size):
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
pad_token_id = tokenizer.encode_special(
"<|assistant_end|>"
) # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
# prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
for i, (ids, mask) in enumerate(batch):
n = len(ids)
ids_tensor = torch.tensor(ids, dtype=torch.long)
inputs[i, :n-1] = ids_tensor[:-1]
inputs[i, : n - 1] = ids_tensor[:-1]
# recall -1 is the ignore index, so mask out targets where mask is 0
row_targets = ids_tensor[1:]
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
targets[i, :n-1] = row_targets
inputs = inputs.to(device) # move to device
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
targets[i, : n - 1] = row_targets
inputs = inputs.to(device) # move to device
targets = targets.to(device)
return inputs, targets
# iterates over the dataset in epochs, tokenizes
batch = []
while True:
@ -127,11 +138,14 @@ def sft_data_generator(dataset, batch_size):
yield collate_and_yield(batch)
batch = []
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
assert target_examples_per_step % examples_per_step == 0, (
"Target examples per step must be divisible by examples per step"
)
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
@ -155,16 +169,18 @@ optimizers = model.setup_optimizers(
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# -----------------------------------------------------------------------------
# Training loop
# Learning rate scheduler
def get_lr_multiplier(it):
lrm = 1.0 - it / num_iterations
return lrm
# Go!
step = 0
train_iter = iter(train_loader)
@ -181,15 +197,17 @@ for step in range(num_iterations):
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
losses.append(loss)
val_loss = torch.stack(losses).mean() # average over eval_steps
val_loss = torch.stack(losses).mean() # average over eval_steps
if ddp:
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
wandb_run.log({
"step": step,
"val_loss": val_loss,
})
wandb_run.log(
{
"step": step,
"val_loss": val_loss,
}
)
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
@ -198,31 +216,47 @@ for step in range(num_iterations):
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["mmlu_acc"] = run_chat_eval(
"MMLU",
model,
tokenizer,
engine,
batch_size=device_batch_size * 2,
max_problems=eval_metrics_max_problems,
)
metrics["arc_easy_acc"] = run_chat_eval(
"ARC-Easy",
model,
tokenizer,
engine,
batch_size=device_batch_size * 2,
max_problems=eval_metrics_max_problems,
)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
"step": step,
**metrics,
})
wandb_run.log(
{
"step": step,
**metrics,
}
)
model.train()
if last_step:
break
# evaluate the gradient
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
for micro_step in range(grad_accum_steps):
train_inputs, train_targets = next(train_iter)
with autocast_ctx:
loss = model(train_inputs, train_targets)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() # accumulate the gradient
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() # accumulate the gradient
num_tokens += (train_targets >= 0).sum()
if ddp:
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
# learning rate scheduler
lrm = get_lr_multiplier(step)
@ -238,47 +272,55 @@ for step in range(num_iterations):
# logging
train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item()
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
wandb_run.log({
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
})
print0(
f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}"
)
wandb_run.log(
{
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
}
)
step += 1
# Save the model at the end of the run
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None, # note: we don't bother to save the optimizer state
None, # note: we don't bother to save the optimizer state
{
"step": step,
"val_loss": val_loss,
**metrics,
"model_config": model_config_kwargs,
}
},
)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Chat SFT", data=[
user_config, # CLI args
{
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
"Validation loss": val_loss,
},
])
get_report().log(
section="Chat SFT",
data=[
user_config, # CLI args
{
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
"Validation loss": val_loss,
},
],
)
# Cleanup
wandb_run.finish()

View File

@ -31,22 +31,23 @@ Abuse Prevention:
"""
import argparse
import json
import os
import torch
import asyncio
import json
import logging
import os
import random
from contextlib import asynccontextmanager
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, nullcontext
from dataclasses import dataclass
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine
# Abuse prevention limits
@ -70,70 +71,70 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument(
'--device-type',
type=str,
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
class Worker:
"""A worker with a model loaded on a specific GPU."""
gpu_id: int
device: torch.device
engine: Engine
tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
def __init__(self, num_gpus: int | None = None):
if num_gpus is None:
if device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
num_gpus = 1 # e.g. cpu|mps
self.num_gpus = num_gpus
self.workers: List[Worker] = []
self.workers: list[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus):
if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
else:
device = torch.device(device_type) # e.g. cpu|mps
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
)
worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx)
self.workers.append(worker)
await self.available_workers.put(worker)
@ -147,15 +148,18 @@ class WorkerPool:
"""Return a worker to the pool."""
await self.available_workers.put(worker)
class ChatMessage(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: List[ChatMessage]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
messages: list[ChatMessage]
temperature: float | None = None
max_tokens: int | None = None
top_k: int | None = None
def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse."""
@ -165,7 +169,7 @@ def validate_chat_request(request: ChatRequest):
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
raise HTTPException(
status_code=400,
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request",
)
# Check individual message lengths and total conversation length
@ -178,48 +182,43 @@ def validate_chat_request(request: ChatRequest):
if msg_length > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message",
)
total_length += msg_length
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed",
)
# Validate role values
for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
)
# Validate temperature
if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException(
status_code=400,
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
)
# Validate top_k
if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException(
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
)
raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
# Validate max_tokens
if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException(
status_code=400,
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
@ -229,6 +228,7 @@ async def lifespan(app: FastAPI):
print(f"Server ready at http://localhost:{args.port}")
yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@ -239,16 +239,16 @@ app.add_middleware(
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f:
with open(ui_html_path, encoding="utf-8") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';"
)
return HTMLResponse(content=html_content)
@ -259,12 +259,9 @@ async def logo():
logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
top_k=None
worker: Worker, tokens, temperature=None, max_new_tokens=None, top_k=None
) -> AsyncGenerator[str, None]:
"""Generate assistant response with streaming."""
temperature = temperature if temperature is not None else args.temperature
@ -286,7 +283,7 @@ async def generate_stream(
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
seed=random.randint(0, 2**31 - 1),
):
token = token_column[0]
@ -303,13 +300,14 @@ async def generate_stream(
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith('<EFBFBD>'):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
new_text = current_text[len(last_clean_text) :]
if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
@ -318,10 +316,10 @@ async def chat_completions(request: ChatRequest):
validate_chat_request(request)
# Log incoming conversation to console
logger.info("="*20)
logger.info("=" * 20)
for i, message in enumerate(request.messages):
logger.info(f"[{message.role.upper()}]: {message.content}")
logger.info("-"*20)
logger.info("-" * 20)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
@ -350,6 +348,7 @@ async def chat_completions(request: ChatRequest):
# Streaming response with worker release after completion
response_tokens = []
async def stream_and_release():
try:
async for chunk in generate_stream(
@ -357,7 +356,7 @@ async def chat_completions(request: ChatRequest):
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
top_k=request.top_k,
):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
@ -368,19 +367,17 @@ async def chat_completions(request: ChatRequest):
# Log the assistant response to console
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
logger.info("="*20)
logger.info("=" * 20)
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
return StreamingResponse(
stream_and_release(),
media_type="text/event-stream"
)
return StreamingResponse(stream_and_release(), media_type="text/event-stream")
except Exception as e:
# Make sure to release worker even on error
await worker_pool.release_worker(worker)
raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
@ -389,9 +386,10 @@ async def health():
"status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0,
}
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
@ -400,16 +398,13 @@ async def stats():
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
"workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers],
}
if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server")
print("Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)

View File

@ -9,55 +9,58 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
"""
from collections import deque
import os
from collections import deque
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
import torch
import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes
from tasks.common import TaskMixture
from tasks.customjson import CustomJSON
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
eval_every = 150 # -1 = disable
eval_tokens = 20 * 524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
@ -69,13 +72,15 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
print0(
f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?"
)
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
@ -84,48 +89,58 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
train_dataset = TaskMixture(
[
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(
subset="auxiliary_train", split="train"
), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture(
[
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]
) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
global last_step, approx_progress
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding
while len(token_buffer) < needed_tokens:
@ -134,49 +149,55 @@ def mid_data_generator(split):
token_buffer.extend(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor -= dataset_size # wrap around for another epoch
cursor -= dataset_size # wrap around for another epoch
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
last_step = True # toggle last_step to True, which will terminate the training loop
# Stopping condition to respect num_iterations, if given
it += 1
if num_iterations > 0 and it >= num_iterations:
last_step = True # toggle last_step to True, which will terminate the training loop
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(
device=device, dtype=torch.int64, non_blocking=True
)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
approx_progress = it / num_iterations # calculate progress from the max number of iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
@ -197,26 +218,28 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
wandb_run.log(
{
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
}
)
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
output_dirname = f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
"step": step,
"val_bpb": val_bpb, # loss at last step
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
@ -225,8 +248,8 @@ while True:
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
},
"user_config": user_config, # inputs to the training script
}
"user_config": user_config, # inputs to the training script
},
)
if last_step:
@ -240,11 +263,11 @@ while True:
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers
lrm = get_lr_multiplier(progress)
for opt in optimizers:
@ -265,47 +288,55 @@ while True:
step += 1
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
total_training_time += dt # only count the time after the first 10 steps
print0(
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 10 == 0:
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
})
wandb_run.log(
{
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
}
)
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Total training time: {total_training_time / 60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
get_report().log(
section="Midtraining",
data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
},
],
)
# cleanup
wandb_run.finish() # wandb run finish
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -2,8 +2,8 @@
Evaluate compression ratio of the tokenizer.
"""
from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import RustBPETokenizer, get_tokenizer
# Random text I got from a random website this morning
news_text = r"""
@ -165,11 +165,10 @@ tokenizer_results = {}
vocab_sizes = {}
for tokenizer_name in ["gpt2", "gpt4", "ours"]:
if tokenizer_name == "gpt2":
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
elif tokenizer_name == "gpt4":
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
else:
tokenizer = get_tokenizer()
@ -183,11 +182,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
encoded_bytes = text.encode('utf-8')
ratio = len(encoded_bytes) / len(encoded)
tokenizer_results[tokenizer_name][name] = {
'bytes': len(encoded_bytes),
'tokens': len(encoded),
'ratio': ratio
}
tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio}
# ANSI color codes
GREEN = '\033[92m'
@ -195,11 +190,12 @@ RED = '\033[91m'
RESET = '\033[0m'
# Print vocab sizes
print(f"\nVocab sizes:")
print("\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}")
print(f"GPT-4: {vocab_sizes['gpt4']}")
print(f"Ours: {vocab_sizes['ours']}")
def print_comparison(baseline_name, baseline_results, ours_results, all_text):
"""Print comparison table between baseline tokenizer and ours."""
print(f"\nComparison with {baseline_name}:")
@ -230,13 +226,16 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
better = "Tie"
diff_color = ""
print(f"{name:<10} {baseline_data['bytes']:<8} "
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
f"{ours_color}{ours_data['tokens']:<7}{RESET} "
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
f"{diff_color}{relative_diff:+7.1f}%{RESET} "
f"{better:<10}")
print(
f"{name:<10} {baseline_data['bytes']:<8} "
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
f"{ours_color}{ours_data['tokens']:<7}{RESET} "
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
f"{diff_color}{relative_diff:+7.1f}%{RESET} "
f"{better:<10}"
)
# Print comparisons
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
@ -244,6 +243,7 @@ print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'],
# Log to report
from nanochat.report import get_report
lines = []
for baseline_name in ["GPT-2", "GPT-4"]:
baseline_key = baseline_name.lower().replace('-', '')
@ -251,15 +251,26 @@ for baseline_name in ["GPT-2", "GPT-4"]:
ours_results = tokenizer_results['ours']
lines.append(f"### Comparison with {baseline_name}")
lines.append("")
lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |")
lines.append(
"| Text Type | Bytes | "
+ baseline_name
+ " Tokens | "
+ baseline_name
+ " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |"
)
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
for name, text in all_text:
baseline_data = baseline_results[name]
ours_data = ours_results[name]
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |")
lines.append(
f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |"
)
lines.append("")
report_markdown = "\n".join(lines)
get_report().log(section="Tokenizer evaluation", data=[
report_markdown,
])
get_report().log(
section="Tokenizer evaluation",
data=[
report_markdown,
],
)

View File

@ -2,19 +2,24 @@
Train a tokenizer using the HuggingFace Tokenizers library.
In the style of GPT-4 tokenizer.
"""
import argparse
import os
import time
import argparse
import torch
from nanochat.tokenizer import RustBPETokenizer
from nanochat.common import get_base_dir
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import RustBPETokenizer
# -----------------------------------------------------------------------------
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument(
'--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)'
)
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
@ -25,6 +30,7 @@ print(f"vocab_size: {args.vocab_size:,}")
# -----------------------------------------------------------------------------
# Text iterator
def text_iterator():
"""
1) Flatten the batches into a single iterator
@ -36,11 +42,13 @@ def text_iterator():
for doc in batch:
doc_text = doc
if len(doc_text) > args.doc_cap:
doc_text = doc_text[:args.doc_cap]
doc_text = doc_text[: args.doc_cap]
nchars += len(doc_text)
yield doc_text
if nchars > args.max_chars:
return
text_iter = text_iterator()
# -----------------------------------------------------------------------------
@ -78,11 +86,11 @@ special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
token_bytes = []
for token_id in range(vocab_size):
token_str = token_strings[token_id] # the Python string representation of this token
token_str = token_strings[token_id] # the Python string representation of this token
if token_str in special_set:
token_bytes.append(0) # special characters are not counted
token_bytes.append(0) # special characters are not counted
else:
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
token_bytes.append(id_bytes)
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
@ -92,15 +100,19 @@ print(f"Saved token_bytes to {token_bytes_path}")
# Log to report
from nanochat.report import get_report
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
get_report().log(section="Tokenizer training", data=[
vars(args), # argparse command line arguments
{"train_time": train_time},
{"num_special_tokens": len(special_set)},
{
"token_bytes_min": int(token_bytes_nonzero.min().item()),
"token_bytes_max": int(token_bytes_nonzero.max().item()),
"token_bytes_mean": token_bytes_nonzero.mean().item(),
"token_bytes_std": token_bytes_nonzero.std().item(),
}
])
get_report().log(
section="Tokenizer training",
data=[
vars(args), # argparse command line arguments
{"train_time": train_time},
{"num_special_tokens": len(special_set)},
{
"token_bytes_min": int(token_bytes_nonzero.min().item()),
"token_bytes_max": int(token_bytes_nonzero.max().item()),
"token_bytes_mean": token_bytes_nonzero.mean().item(),
"token_bytes_std": token_bytes_nonzero.std().item(),
},
],
)

View File

@ -4,10 +4,11 @@ https://huggingface.co/datasets/allenai/ai2_arc
"""
from datasets import load_dataset
from tasks.common import Task, render_mc
class ARC(Task):
class ARC(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
@ -23,26 +24,25 @@ class ARC(Task):
def get_example(self, index):
row = self.ds[index]
question = row["question"] # the question text
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
question = row["question"] # the question text
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
# create and return the Conversation object
user_message = render_mc(question, letters, choices)
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": answer_string}
]
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}]
conversation = {
"messages": messages,
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
}
return conversation
def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
assert assistant_response in conversation['letters'], (
f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message

View File

@ -7,6 +7,7 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
import random
class Task:
"""
Base class of a Task. Allows for lightweight slicing of the underlying dataset.
@ -18,7 +19,7 @@ class Task:
assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
assert step >= 1, f"Step must be strictly positive, got {step}"
self.start = start
self.stop = stop # could be None here
self.stop = stop # could be None here
self.step = step
@property
@ -37,8 +38,8 @@ class Task:
stop = self.num_examples() if self.stop is None else self.stop
step = self.step
span = stop - start
num = (span + step - 1) // step # ceil_div(span, step)
assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
num = (span + step - 1) // step # ceil_div(span, step)
assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
return num
def __getitem__(self, index: int):
@ -81,7 +82,9 @@ class TaskMixture(Task):
Access conversations according to a deterministic shuffle of all examples.
This ensures tasks are mixed throughout training, regardless of dataset size.
"""
assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for mixture with {self.num_conversations} conversations"
)
task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx]
@ -102,7 +105,9 @@ class TaskSequence(Task):
return self.num_conversations
def get_example(self, index):
assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for sequence with {self.num_conversations} conversations"
)
for task_idx, task_length in enumerate(self.lengths):
if index < task_length:
return self.tasks[task_idx][index]

View File

@ -3,10 +3,12 @@ CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages.
"""
import os
import json
import os
from tasks.common import Task
class CustomJSON(Task):
"""
Load conversations from a JSONL file.
@ -25,14 +27,18 @@ class CustomJSON(Task):
print("-" * 80)
print(f"Warning: File {filepath} does not exist")
print("HINT (Oct 21 2025)")
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
print(
"If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations"
)
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:")
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
print(
f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
)
print("-" * 80)
else:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line: # skip empty lines
@ -46,7 +52,9 @@ class CustomJSON(Task):
assert "role" in message, f"Message {i} missing 'role' field"
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), f"Message {i} content must be a string"
self.conversations.append(messages)
@ -62,4 +70,3 @@ class CustomJSON(Task):
"messages": messages,
}
return conversation

View File

@ -15,11 +15,14 @@ Notice that GSM8K uses tool calls inside << >> tags.
"""
import re
from datasets import load_dataset
from tasks.common import Task
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
@ -35,7 +38,6 @@ def extract_answer(completion):
class GSM8K(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
@ -50,10 +52,10 @@ class GSM8K(Task):
return len(self.ds)
def get_example(self, index):
""" Get a single problem from the dataset. """
"""Get a single problem from the dataset."""
row = self.ds[index]
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
# Create and return the Conversation object
# This is tricky because GSM8K uses tool calls, which we need to parse here.
assistant_message_parts = []
@ -76,8 +78,8 @@ class GSM8K(Task):
assistant_message_parts.append({"type": "text", "text": part})
# No put it all together
messages = [
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
]
conversation = {
"messages": messages,
@ -99,7 +101,7 @@ class GSM8K(Task):
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)

View File

@ -5,10 +5,13 @@ It is a coding benchmark.
"""
import re
from datasets import load_dataset
from nanochat.execution import execute_code
from tasks.common import Task
def extract_imports(prompt):
"""Extract import statements from the beginning of a code block."""
imports = []
@ -21,6 +24,7 @@ def extract_imports(prompt):
break
return '\n'.join(imports)
def extract_program(completion):
"""
Extract Python code from LLM completion.
@ -44,8 +48,8 @@ def extract_program(completion):
# No code blocks found, return the whole completion
return completion.strip()
class HumanEval(Task):
class HumanEval(Task):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
@ -58,12 +62,12 @@ class HumanEval(Task):
return len(self.ds)
def get_example(self, index):
""" Get a single problem from the dataset. """
"""Get a single problem from the dataset."""
row = self.ds[index]
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
solution = row['canonical_solution'] # the correct continuation of the program
entry_point = row['entry_point'] # the function to check
test = row['test'] # the test cases
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
solution = row['canonical_solution'] # the correct continuation of the program
entry_point = row['entry_point'] # the function to check
test = row['test'] # the test cases
complete_solution = f"{prompt}\n{solution}"
messages = [
{"role": "user", "content": prompt},
@ -71,13 +75,13 @@ class HumanEval(Task):
]
conversation = {
"messages": messages,
"entry_point": entry_point, # needed during evaluation
"test": test, # needed during evaluation
"entry_point": entry_point, # needed during evaluation
"test": test, # needed during evaluation
}
return conversation
def evaluate(self, conversation, completion):
""" Given (conversation, completion), return boolean success of the completion. """
"""Given (conversation, completion), return boolean success of the completion."""
# the prompt will contain the imports and the function signature
imports = extract_imports(conversation['messages'][0]['content'])
# the completion will usually contain the whole function

View File

@ -4,12 +4,71 @@ https://huggingface.co/datasets/cais/mmlu
"""
from datasets import load_dataset
from tasks.common import Task, render_mc
class MMLU(Task):
class MMLU(Task):
letters = ('A', 'B', 'C', 'D')
groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions')
groups = (
'abstract_algebra',
'anatomy',
'astronomy',
'business_ethics',
'clinical_knowledge',
'college_biology',
'college_chemistry',
'college_computer_science',
'college_mathematics',
'college_medicine',
'college_physics',
'computer_security',
'conceptual_physics',
'econometrics',
'electrical_engineering',
'elementary_mathematics',
'formal_logic',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_computer_science',
'high_school_european_history',
'high_school_geography',
'high_school_government_and_politics',
'high_school_macroeconomics',
'high_school_mathematics',
'high_school_microeconomics',
'high_school_physics',
'high_school_psychology',
'high_school_statistics',
'high_school_us_history',
'high_school_world_history',
'human_aging',
'human_sexuality',
'international_law',
'jurisprudence',
'logical_fallacies',
'machine_learning',
'management',
'marketing',
'medical_genetics',
'miscellaneous',
'moral_disputes',
'moral_scenarios',
'nutrition',
'philosophy',
'prehistory',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_studies',
'sociology',
'us_foreign_policy',
'virology',
'world_religions',
)
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
@ -33,28 +92,27 @@ class MMLU(Task):
def get_example(self, index):
row = self.ds[index]
question = row["question"] # the question text
choices = row["choices"] # the text of each choice
answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D)
subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc.
question = row["question"] # the question text
choices = row["choices"] # the text of each choice
answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D)
subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc.
assert len(choices) == 4, "MMLU should have 4 choices"
# create and return the Conversation object
user_message = render_mc(question, self.letters, choices)
assistant_message = self.letters[answer]
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message}
]
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}]
conversation = {
"messages": messages,
"subject": subject, # might be useful later for grouping metrics by subject
"letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
"subject": subject, # might be useful later for grouping metrics by subject
"letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
}
return conversation
def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
assert assistant_response in self.letters, (
f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message

View File

@ -5,10 +5,12 @@ We use the "smol" version, which is more appropriate for smaller models.
"""
from datasets import load_dataset
from tasks.common import Task
class SmolTalk(Task):
""" smol-smoltalk dataset. train is 460K rows, test is 24K rows. """
"""smol-smoltalk dataset. train is 460K rows, test is 24K rows."""
def __init__(self, split, **kwargs):
super().__init__(**kwargs)
@ -29,14 +31,16 @@ class SmolTalk(Task):
assert len(messages) >= 1
first_message = messages[0]
if first_message["role"] == "system":
rest_messages = messages[1:] # optional system message is OK
rest_messages = messages[1:] # optional system message is OK
else:
rest_messages = messages
assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages"
for i, message in enumerate(rest_messages):
# user and assistant alternate as user,assistant,user,assistant,...
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), "Content must be a string"
# ---------------------------------------------------------------------
# create and return the Conversation object (ok to emit the system message too)

View File

@ -26,10 +26,11 @@ To preview a few example conversations, run:
python -m tasks.spellingbee
"""
import re
import random
from tasks.common import Task
import re
from nanochat.common import download_file_with_lock
from tasks.common import Task
# Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz"
@ -38,6 +39,8 @@ WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads
# Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
@ -49,6 +52,7 @@ def extract_answer(completion):
return match_str
return None
# User message templates for data augmentation
USER_MSG_TEMPLATES = [
"How many {letter} are in the word {word}",
@ -110,8 +114,8 @@ USER_MSG_TEMPLATES = [
"{word}{letter}が何回出てくる",
]
class SpellingBee(Task):
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
@ -119,7 +123,7 @@ class SpellingBee(Task):
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, 'r', encoding='utf-8') as f:
with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f]
self.words = words
@ -131,7 +135,7 @@ class SpellingBee(Task):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
@ -148,12 +152,12 @@ class SpellingBee(Task):
if rng.random() < 0.3:
template = template.lower()
quote_options = ['', "'", '"']
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
word_wrapped = f"{word_quote}{word}{word_quote}"
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
if rng.random() < 0.5: # 50% of people don't even use question marks
if rng.random() < 0.5: # 50% of people don't even use question marks
user_msg += "?"
# Now create the ideal assistant response - build as parts (text + tool calls)
@ -190,13 +194,12 @@ Then count the occurrences of '{letter}':
# Part 4: Python output
assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
assistant_parts.append(
{"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}
)
# return the full conversation
messages = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts}
]
messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}]
conversation = {
"messages": messages,
}
@ -222,7 +225,7 @@ Then count the occurrences of '{letter}':
return is_correct
def reward(self, conversation, assistant_response):
""" Use simple 0-1 reward just like gsm8k."""
"""Use simple 0-1 reward just like gsm8k."""
is_correct = self.evaluate(conversation, assistant_response)
is_correct_float = float(is_correct)
return is_correct_float
@ -238,10 +241,10 @@ class SimpleSpelling(Task):
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, 'r', encoding='utf-8') as f:
with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f]
rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task
rng.shuffle(words) # use a different word order than the SpellingBee task
self.words = words
@property
@ -252,7 +255,7 @@ class SimpleSpelling(Task):
return self.size
def get_example(self, index):
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
rng = random.Random(seed)
# pick a random word
word = rng.choice(self.words)
@ -260,7 +263,7 @@ class SimpleSpelling(Task):
# return the full conversation
messages = [
{"role": "user", "content": f"Spell the word: {word}"},
{"role": "assistant", "content": f"{word}:{word_letters}"}
{"role": "assistant", "content": f"{word}:{word_letters}"},
]
conversation = {
"messages": messages,
@ -269,7 +272,6 @@ class SimpleSpelling(Task):
if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples
task = SpellingBee()
for i in range(10):

View File

@ -5,8 +5,10 @@ python -m pytest tests/test_engine.py -v
"""
import torch
from nanochat.engine import KVCache
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
@ -21,11 +23,7 @@ def test_kv_cache_resize():
num_layers = 6
kv_cache = KVCache(
batch_size=batch_size,
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers
)
# Insert a single token with a distinct fill value to all layers
@ -47,7 +45,9 @@ def test_kv_cache_resize():
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
assert new_seq_len > original_seq_len, (
f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
)
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
@ -57,8 +57,12 @@ def test_kv_cache_resize():
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
assert (actual_k == expected_k).all(), (
f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
)
assert (actual_v == expected_v).all(), (
f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
)
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]

View File

@ -18,18 +18,23 @@ python -m pytest tests/test_rustbpe.py -v -s
-v is verbose, -s is show prints
"""
import regex as re
from collections import Counter, defaultdict
import time
import rustbpe
import tiktoken
import pytest
from collections import Counter, defaultdict
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
import pytest
import regex as re
import tiktoken
import rustbpe
GPT4_SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
# -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
def get_stats(ids, counts=None):
"""
Given a list of integers, return a dictionary of counts of consecutive pairs
@ -37,10 +42,11 @@ def get_stats(ids, counts=None):
Optionally allows to update an existing dictionary of counts
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
def merge(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
@ -51,7 +57,7 @@ def merge(ids, pair, idx):
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
newids.append(idx)
i += 2
else:
@ -59,8 +65,8 @@ def merge(ids, pair, idx):
i += 1
return newids
class RegexTokenizer:
class RegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
@ -68,7 +74,7 @@ class RegexTokenizer:
example: {'<|endoftext|>': 100257}
"""
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
self.merges = {} # (int, int) -> int
self.merges = {} # (int, int) -> int
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {}
self.inverse_special_tokens = {}
@ -97,8 +103,8 @@ class RegexTokenizer:
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
# count the number of times every consecutive pair appears
stats = {}
@ -125,11 +131,11 @@ class RegexTokenizer:
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
return ambiguous
def _encode_chunk(self, text_bytes):
@ -145,7 +151,7 @@ class RegexTokenizer:
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
@ -158,14 +164,16 @@ class RegexTokenizer:
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
# -----------------------------------------------------------------------------
# Faster Python tokenizer, optimized version of the reference tokenizer
def fast_merge_inplace(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
@ -175,16 +183,15 @@ def fast_merge_inplace(ids, pair, idx):
# Find all positions where the pair occurs
i = 0
while i < len(ids) - 1:
if ids[i] == pair[0] and ids[i+1] == pair[1]:
if ids[i] == pair[0] and ids[i + 1] == pair[1]:
ids[i] = idx
ids.pop(i+1)
ids.pop(i + 1)
else:
i += 1
return ids
class FastRegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
@ -229,8 +236,8 @@ class FastRegexTokenizer:
# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
# Initial count: build stats and position tracking
stats = defaultdict(int)
@ -262,31 +269,31 @@ class FastRegexTokenizer:
chunk_count = chunk_counts[chunk_idx]
ix = 0
while ix < len(chunk_ids) - 1:
if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
if chunk_ids[ix] == pair[0] and chunk_ids[ix + 1] == pair[1]:
# Track what pairs are being removed/added
# Remove: (prev, A), (A, B), (B, next)
if ix > 0:
old_left = (chunk_ids[ix-1], chunk_ids[ix])
old_left = (chunk_ids[ix - 1], chunk_ids[ix])
count_changes[old_left] -= chunk_count
# The merged pair disappears
count_changes[pair] -= chunk_count
if ix + 2 < len(chunk_ids):
old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
old_right = (chunk_ids[ix + 1], chunk_ids[ix + 2])
count_changes[old_right] -= chunk_count
# Apply the merge
chunk_ids[ix] = idx
chunk_ids.pop(ix+1)
chunk_ids.pop(ix + 1)
# Add: (prev, C), (C, next)
if ix > 0:
new_left = (chunk_ids[ix-1], chunk_ids[ix])
new_left = (chunk_ids[ix - 1], chunk_ids[ix])
count_changes[new_left] += chunk_count
if ix + 1 < len(chunk_ids):
new_right = (chunk_ids[ix], chunk_ids[ix+1])
new_right = (chunk_ids[ix], chunk_ids[ix + 1])
count_changes[new_right] += chunk_count
else:
ix += 1
@ -302,8 +309,9 @@ class FastRegexTokenizer:
# Update positions for changed pairs - only check affected chunks
for chunk_idx in affected_chunks:
chunk_ids = ids[chunk_idx]
contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair
for j in range(len(chunk_ids) - 1))
contains_pair = any(
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1)
)
if contains_pair:
positions[changed_pair].add(chunk_idx)
else:
@ -318,8 +326,8 @@ class FastRegexTokenizer:
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
@ -354,7 +362,7 @@ class FastRegexTokenizer:
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = fast_merge_inplace(ids, pair, idx)
@ -367,18 +375,20 @@ class FastRegexTokenizer:
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
# -----------------------------------------------------------------------------
# HuggingFace tokenizer
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
@ -389,19 +399,23 @@ class HuggingFaceTokenizer:
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
# Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
tokenizer = HFTokenizer(
BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
)
)
# Normalizer: None
tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
]
)
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None
@ -410,9 +424,9 @@ class HuggingFaceTokenizer:
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=[], # no special tokens
special_tokens=[], # no special tokens
)
# Kick off the training
tokenizer.train_from_iterator(text_iterator, trainer)
@ -422,15 +436,19 @@ class HuggingFaceTokenizer:
ids = self.tokenizer.encode(text, add_special_tokens=False).ids
return ids
# -----------------------------------------------------------------------------
# Test all of the above
@pytest.fixture(scope="module")
def enwik8_path():
"""Fixture to download and cache enwik8 dataset."""
import os
import zipfile
from nanochat.common import get_base_dir
base_dir = get_base_dir()
# download and unzip enwik8 to .cache directory
enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
@ -439,6 +457,7 @@ def enwik8_path():
if not os.path.exists(enwik8_local_path):
print(f"Downloading enwik8 to {enwik8_local_path_zip}")
import requests
response = requests.get(enwik8_url)
with open(enwik8_local_path_zip, "wb") as f:
f.write(response.content)
@ -455,15 +474,17 @@ def enwik8_path():
@pytest.fixture(scope="module")
def enwik8_small(enwik8_path):
"""Fixture providing 100KB of enwik8 for quick tests."""
with open(enwik8_path, "r", encoding="utf-8") as f:
with open(enwik8_path, encoding="utf-8") as f:
return f.read(100_000)
@pytest.fixture(scope="module")
def enwik8_large(enwik8_path):
"""Fixture providing 10MB of enwik8 for performance tests."""
with open(enwik8_path, "r", encoding="utf-8") as f:
with open(enwik8_path, encoding="utf-8") as f:
return f.read(10**7)
def time_function(func, *args, **kwargs):
"""Time a function call and return the result and elapsed time"""
start_time = time.time()
@ -472,6 +493,7 @@ def time_function(func, *args, **kwargs):
elapsed = end_time - start_time
return result, elapsed
def test_correctness(enwik8_small):
"""Test that all tokenizer implementations produce the same results."""
text = enwik8_small
@ -482,7 +504,9 @@ def test_correctness(enwik8_small):
print("\nTraining slow reference...")
slow_reference_tokenizer = RegexTokenizer()
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
slow_reference_ids, slow_reference_encode_time = time_function(
slow_reference_tokenizer.encode_ordinary, encode_text
)
print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
print(slow_reference_ids[:20])
@ -497,7 +521,9 @@ def test_correctness(enwik8_small):
print("\nTraining fast reference...")
fast_reference_tokenizer = FastRegexTokenizer()
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
fast_reference_ids, fast_reference_encode_time = time_function(
fast_reference_tokenizer.encode_ordinary, encode_text
)
print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
print(fast_reference_ids[:20])
@ -589,14 +615,16 @@ def test_training_performance(enwik8_large):
assert hf_train_time > 0, "Training should take some time"
# Print comparison
print(f"\n📊 Performance comparison:")
print("\n📊 Performance comparison:")
print(f" RustBPE: {rustbpe_train_time:.4f}s")
print(f" HuggingFace: {hf_train_time:.4f}s")
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x")
def test_interface(enwik8_small):
"""Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
import tempfile
from nanochat.tokenizer import RustBPETokenizer
# Simple train test

69
uv.lock
View File

@ -188,6 +188,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" },
]
[[package]]
name = "cfgv"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" },
]
[[package]]
name = "charset-normalizer"
version = "3.4.3"
@ -306,6 +315,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" },
]
[[package]]
name = "distlib"
version = "0.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
]
[[package]]
name = "exceptiongroup"
version = "1.3.0"
@ -528,6 +546,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" },
]
[[package]]
name = "identify"
version = "2.6.15"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" },
]
[[package]]
name = "idna"
version = "3.10"
@ -802,6 +829,7 @@ gpu = [
[package.dev-dependencies]
dev = [
{ name = "maturin" },
{ name = "pre-commit" },
{ name = "pytest" },
]
@ -826,6 +854,7 @@ provides-extras = ["cpu", "gpu"]
[package.metadata.requires-dev]
dev = [
{ name = "maturin", specifier = ">=1.9.4" },
{ name = "pre-commit", specifier = ">=3.8.0" },
{ name = "pytest", specifier = ">=8.0.0" },
]
@ -872,6 +901,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
]
[[package]]
name = "nodeenv"
version = "1.9.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" },
]
[[package]]
name = "numpy"
version = "1.26.4"
@ -1131,6 +1169,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pre-commit"
version = "4.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cfgv" },
{ name = "identify" },
{ name = "nodeenv" },
{ name = "pyyaml" },
{ name = "virtualenv" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f4/9b/6a4ffb4ed980519da959e1cf3122fc6cb41211daa58dbae1c73c0e519a37/pre_commit-4.5.0.tar.gz", hash = "sha256:dc5a065e932b19fc1d4c653c6939068fe54325af8e741e74e88db4d28a4dd66b", size = 198428, upload-time = "2025-11-22T21:02:42.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/c4/b2d28e9d2edf4f1713eb3c29307f1a63f3d67cf09bdda29715a36a68921a/pre_commit-4.5.0-py2.py3-none-any.whl", hash = "sha256:25e2ce09595174d9c97860a95609f9f852c0614ba602de3561e267547f2335e1", size = 226429, upload-time = "2025-11-22T21:02:40.836Z" },
]
[[package]]
name = "propcache"
version = "0.3.2"
@ -2016,6 +2070,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/96/06/5cc0542b47c0338c1cb676b348e24a1c29acabc81000bced518231dded6f/uvicorn-0.36.0-py3-none-any.whl", hash = "sha256:6bb4ba67f16024883af8adf13aba3a9919e415358604ce46780d3f9bdc36d731", size = 67675, upload-time = "2025-09-20T01:07:12.984Z" },
]
[[package]]
name = "virtualenv"
version = "20.35.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "distlib" },
{ name = "filelock" },
{ name = "platformdirs" },
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" },
]
[[package]]
name = "wandb"
version = "0.21.3"