Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)

Compare commits: 4 commits, 660f1bff6c ... 2288750906

| Author | SHA1 | Date |
|---|---|---|
| | 2288750906 | |
| | 59ed9392ed | |
| | 449494c8b6 | |
| | 6587063479 | |
.github/workflows/pre-commit.yml (vendored, new file, 44 lines added)

@@ -0,0 +1,44 @@
name: Pre-commit

on:
  push:
    branches:
      - main
      - "release/**"
  pull_request:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  run-pre-commit:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      - name: Install uv
        uses: astral-sh/setup-uv@v7

      - name: Cache uv & pre-commit
        uses: actions/cache@v4
        with:
          path: |
            .venv
            ~/.cache/uv
            ~/.cache/pre-commit
          key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock', '.pre-commit-config.yaml') }}
          restore-keys: |
            ${{ runner.os }}-uv-

      - name: Install dev dependencies
        run: uv sync --group dev

      - name: Run pre-commit
        run: uv run pre-commit run --all-files
.gitignore (vendored, 2 lines changed)

@@ -4,4 +4,4 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
-eval_bundle/
+eval_bundle/
.pre-commit-config.yaml (new file, 27 lines added)

@@ -0,0 +1,27 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
        args: [--maxkb=128]
      - id: fix-byte-order-marker
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-yaml
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: [--fix=lf]

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.14.8
    hooks:
      - id: ruff-check
      - id: ruff-format

  - repo: https://github.com/codespell-project/codespell
    rev: v2.4.1 # Use the latest stable version
    hooks:
      - id: codespell
        additional_dependencies:
          - tomli
README.md (16 lines added)

@@ -123,6 +123,22 @@ I haven't invested too much here but some tests exist, especially for the tokenizer
 python -m pytest tests/test_rustbpe.py -v -s
 ```

+## Pre-commit hooks
+
+Linting and formatting are enforced with [pre-commit](https://pre-commit.com/) both locally and in CI via GitHub Actions. To match the checks that run in PRs:
+
+- Make sure the dev extras are installed (`uv sync --group dev`).
+- Run the suite on demand: `uv run pre-commit run --all-files`.
+- Optionally, install the git hook once so the checks run automatically on `git commit`: `uv run pre-commit install`.
+
+Hook coverage (auto-fixes most issues; review and stage the changes afterward):
+
+- [`ruff`](https://github.com/astral-sh/ruff): a fast Rust-based linter and formatter that replaces multiple tools:
+  - **Linting** (`ruff-check`): removes unused imports (like autoflake), upgrades syntax (like pyupgrade), and sorts imports (like isort).
+  - **Formatting** (`ruff-format`): applies consistent code formatting (like black), with quote style preserved.
+- [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks): repo hygiene (trim trailing whitespace, enforce LF endings/newlines, detect merge conflicts, block oversized files).
+- [`codespell`](https://github.com/codespell-project/codespell): catches common spelling mistakes in code and docs (add false positives to `[tool.codespell].ignore-words-list` in `pyproject.toml`).
+
 ## File structure

 ```
@@ -28,24 +28,23 @@ NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the
 (obviously you can tune this arbitrarily to your liking)
 NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
 """
-import requests

+import copy
 import json
 import os
-import copy
 import random
 from concurrent.futures import ThreadPoolExecutor, as_completed

+import requests
+
 from nanochat.common import get_base_dir

-api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
+api_key = open("openroutertoken.txt", encoding="utf-8").read().strip()

 url = "https://openrouter.ai/api/v1/chat/completions"
-headers = {
-    "Authorization": f"Bearer {api_key}",
-    "Content-Type": "application/json"
-}
+headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

-readme = open("README.md", "r", encoding="utf-8").read().strip()
+readme = open("README.md", encoding="utf-8").read().strip()
 prompt = r"""
 I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
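The hunk above only reorders imports and tightens the `open(...)` calls (the redundant `"r"` mode is dropped); the key, URL, and headers it builds are used further down to POST chat-completion requests to OpenRouter. As a rough, illustrative sketch of that request shape (the payload below is a placeholder, not taken from the diff):

```python
import requests

api_key = open("openroutertoken.txt", encoding="utf-8").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

# Hypothetical minimal payload; the real script also attaches a JSON-schema response_format.
payload = {
    "model": "google/gemini-2.5-flash",
    "messages": [{"role": "user", "content": "Who are you?"}],
}
resp = requests.post(url, headers=headers, json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```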
@@ -276,48 +275,46 @@ prompt = prompt.replace("%README%", readme)

# Define the JSON schema for structured output
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'",
},
"content": {"type": "string", "description": "The message content"},
},
"required": ["role", "content"],
"additionalProperties": False,
},
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
"required": ["messages"],
"additionalProperties": False,
},
},
}

# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}


def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
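The `response_format` block above asks the provider for strictly structured JSON: an object with a `messages` array of `{role, content}` items. A small sketch of how such a returned JSON string might be validated locally (the field names follow the schema shown; the helper function itself is hypothetical):

```python
import json


def parse_conversation(raw: str) -> list[dict]:
    """Parse and lightly validate a structured-output response (hypothetical helper)."""
    data = json.loads(raw)
    messages = data["messages"]
    assert isinstance(messages, list) and len(messages) > 0
    for i, message in enumerate(messages):
        # the prompt asks for strictly alternating roles, starting with the user
        expected_role = "user" if i % 2 == 0 else "assistant"
        assert message["role"] == expected_role
        assert isinstance(message["content"], str)
    return messages
```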
@@ -325,7 +322,7 @@ def generate_conversation(idx: int):
    """

    # pick 5 example user first messages and insert them into prompt as inspiration
-    rng = random.Random(idx) # use idx as seed to the rng
+    rng = random.Random(idx)  # use idx as seed to the rng
    user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
    payload = copy.deepcopy(base_payload)
    modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
@@ -357,7 +354,6 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers.
 completed_count = 0
 error_count = 0
 with ThreadPoolExecutor(max_workers=num_workers) as executor:
-
     # Submit all tasks
     futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]

@@ -369,7 +365,9 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # Lightly validate the conversation structure
            for i, message in enumerate(messages):
                expected_role = "user" if i % 2 == 0 else "assistant"
-                assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
+                assert message['role'] == expected_role, (
+                    f"Message {i} has role {message['role']} but should be {expected_role}"
+                )

            # If all looks good, write the messages to file
            with open(output_file, 'a') as f:
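For context, the surrounding script fans the per-conversation work out over a thread pool and consumes results as they finish. A minimal sketch of that submit/as_completed pattern (the worker function and counts here are placeholders):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def generate_one(idx: int) -> list[dict]:
    # placeholder for the real generate_conversation(idx) call
    return [{"role": "user", "content": f"hello {idx}"}]


completed_count, error_count = 0, 0
with ThreadPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(generate_one, idx) for idx in range(32)]
    for future in as_completed(futures):
        try:
            messages = future.result()
            completed_count += 1
        except Exception:
            error_count += 1
print(f"done: {completed_count} ok, {error_count} errors")
```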
@@ -384,4 +382,3 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
 print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
 if error_count > 0:
     print(f"Encountered {error_count} errors during generation")
-
@@ -26,4 +26,4 @@
     svg.innerHTML += `<path d="M200,-12 L212,0 L200,12 L188,0 Z" transform="translate(0,200)" fill="#000"/>`;
   </script>
 </body>
-</html>
+</html>
@@ -13,24 +13,25 @@ training latency.
 NOTE: This file is meant only as reference/documentation of the
 dataset preparation and it is not used during the project runtime.
 """

 import os
 import time

-from datasets import load_dataset
-import pyarrow.parquet as pq
 import pyarrow as pa
+import pyarrow.parquet as pq
+from datasets import load_dataset

 # Source dataset
 dataset_kwargs = {
     "path": "HuggingFaceFW/fineweb-edu",
     "split": "train",
-    "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
+    "name": "sample-100BT",  # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
 }
 ds = load_dataset(**dataset_kwargs)

 # Shuffle to scramble the order
 ds = ds.shuffle(seed=42)
-ndocs = len(ds) # total number of documents to process
+ndocs = len(ds)  # total number of documents to process
 print(f"Total number of documents: {ndocs}")

 # Repackage into parquet files
@@ -39,7 +40,7 @@ os.makedirs(output_dir, exist_ok=True)

 # Write to parquet files
 chars_per_shard = 250_000_000
-row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later
+row_group_size = 1024  # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later
 shard_docs = []
 shard_index = 0
 shard_characters = 0
@@ -52,20 +53,20 @@ for doc in ds:
     shard_characters += len(text)
     collected_enough_chars = shard_characters >= chars_per_shard
     docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
-    if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
+    if collected_enough_chars and docs_multiple_of_row_group_size:  # leads to ~100MB of text (compressed)
         shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
         shard_table = pa.Table.from_pydict({"text": shard_docs})
         pq.write_table(
             shard_table,
             shard_path,
             row_group_size=row_group_size,
-            use_dictionary=False, # this is usually used for categorical data
-            compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’}
+            use_dictionary=False,  # this is usually used for categorical data
+            compression="zstd",  # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’}
             compression_level=3,
-            write_statistics=False, # not needed for text
+            write_statistics=False,  # not needed for text
         )
         t1 = time.time()
-        dt = t1 - t0 # for this shard alone
+        dt = t1 - t0  # for this shard alone
         t0 = t1
         total_docs_processed += len(shard_docs)
         total_time_spent += dt
@@ -73,15 +74,20 @@ for doc in ds:
         avg_time_per_doc = total_time_spent / total_docs_processed
         remaining_time = remaining_docs * avg_time_per_doc
         remaining_time_hours = remaining_time / 3600
-        print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
+        print(
+            f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h"
+        )
         shard_docs = []
         shard_characters = 0
         shard_index += 1


 # Demonstration of how the data was later uploaded to HuggingFace
 def upload():
     import os

     from huggingface_hub import HfApi

     token = os.getenv("HF_TOKEN")
     api = HfApi(token=token)
     api.upload_large_folder(

@@ -89,4 +95,6 @@ def upload():
         repo_id="karpathy/fineweb-edu-100b-shuffle",
         repo_type="dataset",
     )


 # upload()
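The shard-writing loop above compresses each batch of documents with zstd and keeps row groups at a power-of-two size. A self-contained sketch with the same writer settings (tiny toy data and a hypothetical output path):

```python
import pyarrow as pa
import pyarrow.parquet as pq

docs = [f"document number {i}" for i in range(2048)]  # toy stand-in for shard_docs
table = pa.Table.from_pydict({"text": docs})
pq.write_table(
    table,
    "shard_00000.parquet",  # hypothetical local path
    row_group_size=1024,  # multiple of 2, convenient for the distributed data loader
    use_dictionary=False,  # dictionary encoding targets categorical data, not free text
    compression="zstd",
    compression_level=3,
    write_statistics=False,  # column statistics are not useful for raw text
)
print(pq.ParquetFile("shard_00000.parquet").num_row_groups)  # -> 2
```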
@@ -2,6 +2,7 @@
 Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
 Not a general optimizer! But works for our specific use.
 """
+
 import torch
 import torch.distributed as dist
 from torch import Tensor

@@ -12,7 +13,15 @@ class DistAdamW(torch.optim.Optimizer):
     Distributed AdamW optimizer.
     In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
     """
-    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+
+    def __init__(
+        self,
+        param_groups,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.01,
+    ):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(param_groups, defaults)

@@ -30,7 +39,9 @@ class DistAdamW(torch.optim.Optimizer):
                 grad = params[base_i].grad
                 rank_size = grad.shape[0] // world_size
                 grad_slice = torch.empty_like(grad[:rank_size])
-                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                reduce_scatter_futures.append(
+                    dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
+                )
                 grad_slices.append(grad_slice)

         idx = 0

@@ -43,7 +54,7 @@ class DistAdamW(torch.optim.Optimizer):
             reduce_scatter_futures[idx].wait()
             p = params[base]
             rank_size = p.shape[0] // world_size
-            p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+            p_slice = p[rank * rank_size : (rank + 1) * rank_size]
             lr = group['lr'] * getattr(p, "lr_mul", 1.0)
             state = self.state[p]
             g_slice = grad_slices[idx]

@@ -64,8 +75,8 @@ class DistAdamW(torch.optim.Optimizer):
             exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
             exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
             # bias corrections
-            bias1 = 1 - beta1 ** t
-            bias2 = 1 - beta2 ** t
+            bias1 = 1 - beta1**t
+            bias2 = 1 - beta2**t
             # compute step
             denom = exp_avg_sq.sqrt().add_(eps)
             step_size = lr * (torch.sqrt(bias2) / bias1)
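The optimizer hunks above are formatting-only, but they pass through the bias-corrected AdamW update itself: with bias1 = 1 - beta1^t and bias2 = 1 - beta2^t, the parameter slice moves by lr * sqrt(bias2) / bias1 in the direction of m / (sqrt(v) + eps). A small single-tensor sketch of that arithmetic (no sharding, no weight decay; the toy gradient is made up):

```python
import math

import torch

lr, beta1, beta2, eps, t = 1e-3, 0.9, 0.999, 1e-8, 10
p = torch.zeros(4)
g = torch.ones(4)            # stand-in for the reduce-scattered gradient slice
exp_avg = torch.zeros(4)     # first moment (m)
exp_avg_sq = torch.zeros(4)  # second moment (v)

exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
bias1 = 1 - beta1**t
bias2 = 1 - beta2**t
step_size = lr * (math.sqrt(bias2) / bias1)
p.add_(exp_avg / exp_avg_sq.sqrt().add_(eps), alpha=-step_size)
print(p)
```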
@@ -1,25 +1,29 @@
 """
 Utilities for saving and loading model/optim/state checkpoints.
 """
-import os
-import re
+
 import glob
 import json
 import logging
+import os
+import re
+
 import torch

-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_default_logging
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.tokenizer import get_tokenizer
-from nanochat.common import setup_default_logging

 # Set up logging
 setup_default_logging()
 logger = logging.getLogger(__name__)

+
 def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)

+
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
         os.makedirs(checkpoint_dir, exist_ok=True)

@@ -38,6 +42,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
         torch.save(optimizer_data, optimizer_path)
         logger.info(f"Saved optimizer state to: {optimizer_path}")

+
 def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")

@@ -49,7 +54,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
         optimizer_data = torch.load(optimizer_path, map_location=device)
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "r", encoding="utf-8") as f:
+    with open(meta_path, encoding="utf-8") as f:
         meta_data = json.load(f)
     return model_data, optimizer_data, meta_data

@@ -66,10 +71,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
     if device.type in {"cpu", "mps"}:
         # Convert bfloat16 tensors to float for CPU inference
-        model_data = {
-            k: v.float() if v.dtype == torch.bfloat16 else v
-            for k, v in model_data.items()
-        }
+        model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
     model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]

@@ -79,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":

@@ -121,9 +123,11 @@ def find_last_step(checkpoint_dir):
     last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
     return last_step

+
 # -----------------------------------------------------------------------------
 # convenience functions that take into account nanochat's directory structure

+
 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     if model_tag is None:
         # guess the model tag by defaulting to the largest model

@@ -139,6 +143,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
     return model, tokenizer, meta_data

+
 def load_model(source, *args, **kwargs):
     model_dir = {
         "base": "base_checkpoints",
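Two small but load-bearing details in the checkpoint loader above: bfloat16 weights are upcast to float32 for CPU/MPS inference, and the `_orig_mod.` prefix that torch.compile prepends to parameter names is stripped before `load_state_dict`. A sketch of just those two transformations on a plain state dict (the example dict is made up):

```python
import torch

# pretend this came from torch.load(..., map_location="cpu") on a compiled, bf16 model
model_data = {
    "_orig_mod.transformer.wte.weight": torch.zeros(4, 8, dtype=torch.bfloat16),
    "_orig_mod.lm_head.weight": torch.zeros(8, 4, dtype=torch.bfloat16),
}

# upcast bf16 tensors so CPU/MPS kernels can run them
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
# drop the prefix that torch.compile prepends to every parameter name
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}

print(list(model_data))  # ['transformer.wte.weight', 'lm_head.weight']
```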
@@ -2,26 +2,30 @@
 Common utilities for nanochat.
 """

+import logging
 import os
 import re
-import logging
 import urllib.request

 import torch
 import torch.distributed as dist
 from filelock import FileLock

+
 class ColoredFormatter(logging.Formatter):
     """Custom formatter that adds colors to log messages."""
+
     # ANSI color codes
     COLORS = {
-        'DEBUG': '\033[36m',    # Cyan
-        'INFO': '\033[32m',     # Green
+        'DEBUG': '\033[36m',  # Cyan
+        'INFO': '\033[32m',  # Green
         'WARNING': '\033[33m',  # Yellow
-        'ERROR': '\033[31m',    # Red
-        'CRITICAL': '\033[35m', # Magenta
+        'ERROR': '\033[31m',  # Red
+        'CRITICAL': '\033[35m',  # Magenta
     }
     RESET = '\033[0m'
     BOLD = '\033[1m'

     def format(self, record):
         # Add color to the level name
         levelname = record.levelname

@@ -36,17 +40,17 @@ class ColoredFormatter(logging.Formatter):
         message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
         return message

+
 def setup_default_logging():
     handler = logging.StreamHandler()
     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-    logging.basicConfig(
-        level=logging.INFO,
-        handlers=[handler]
-    )
+    logging.basicConfig(level=logging.INFO, handlers=[handler])

+
 setup_default_logging()
 logger = logging.getLogger(__name__)

+
 def get_base_dir():
     # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
     if os.environ.get("NANOCHAT_BASE_DIR"):
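The `ColoredFormatter`/`setup_default_logging` pair above wires ANSI colors into the root logging configuration. A minimal usage sketch, assuming the functions behave as defined in the hunk (the logger name is illustrative):

```python
import logging

setup_default_logging()  # installs the ColoredFormatter-backed handler at INFO level
logger = logging.getLogger("nanochat.demo")  # hypothetical logger name

logger.info("Shard 3 finished")    # level name in green, "Shard 3" highlighted
logger.warning("slow tokenizer")   # yellow
logger.error("download failed")    # red
```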
@ -58,6 +62,7 @@ def get_base_dir():
|
|||
os.makedirs(nanochat_dir, exist_ok=True)
|
||||
return nanochat_dir
|
||||
|
||||
|
||||
def download_file_with_lock(url, filename, postprocess_fn=None):
|
||||
"""
|
||||
Downloads a file from a URL to a local path in the base directory.
|
||||
|
|
@ -81,7 +86,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
# Download the content as bytes
|
||||
print(f"Downloading {url}...")
|
||||
with urllib.request.urlopen(url) as response:
|
||||
content = response.read() # bytes
|
||||
content = response.read() # bytes
|
||||
|
||||
# Write to local file
|
||||
with open(file_path, 'wb') as f:
|
||||
|
|
@ -94,11 +99,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
|
||||
return file_path
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
|
||||
def print0(s="", **kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
print(s, **kwargs)
|
||||
|
||||
|
||||
def print_banner():
|
||||
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
|
||||
banner = """
|
||||
|
|
@ -113,10 +120,12 @@ def print_banner():
|
|||
"""
|
||||
print0(banner)
|
||||
|
||||
|
||||
def is_ddp():
|
||||
# TODO is there a proper way
|
||||
return int(os.environ.get('RANK', -1)) != -1
|
||||
|
||||
|
||||
def get_dist_info():
|
||||
if is_ddp():
|
||||
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
||||
|
|
@ -127,6 +136,7 @@ def get_dist_info():
|
|||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
|
||||
def autodetect_device_type():
|
||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||
if torch.cuda.is_available():
|
||||
|
|
@@ -138,14 +148,19 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     print0(f"Autodetected device type: {device_type}")
     return device_type

-def compute_init(device_type="cuda"): # cuda|cpu|mps
+
+def compute_init(device_type="cuda"):  # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""

     assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
     if device_type == "cuda":
-        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+        assert torch.cuda.is_available(), (
+            "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+        )
     if device_type == "mps":
-        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+        assert torch.backends.mps.is_available(), (
+            "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+        )

     # Reproducibility
     # Note that we set the global seeds here, but most of the code uses explicit rng objects.
@ -158,7 +173,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
|
||||
# Precision
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
|
@ -168,23 +183,28 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
||||
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
|
||||
|
||||
|
||||
def compute_cleanup():
|
||||
"""Companion function to compute_init, to clean things up before script exit"""
|
||||
if is_ddp():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
class DummyWandb:
|
||||
"""Useful if we wish to not use wandb but have all the same signatures"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
|
|
|||
|
|
@@ -18,11 +18,13 @@ import os
 import sys
 from ast import literal_eval

-def print0(s="",**kwargs):
+
+def print0(s="", **kwargs):
     ddp_rank = int(os.environ.get('RANK', 0))
     if ddp_rank == 0:
         print(s, **kwargs)

+
 for arg in sys.argv[1:]:
     if '=' not in arg:
         # assume it's the name of a config file
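The configurator above lets any `key=value` CLI argument override a module-level default, parsing the value with `ast.literal_eval` when possible. A standalone sketch of that idea (the variable names here are illustrative, not the project's):

```python
import sys
from ast import literal_eval

# defaults that a script might expose for overriding
batch_size = 32
learning_rate = 3e-4
run_name = "demo"

for arg in sys.argv[1:]:
    if "=" not in arg:
        continue  # (the real configurator treats this case as a config file name)
    key, val = arg.lstrip("-").split("=", 1)
    if key in globals():
        try:
            parsed = literal_eval(val)  # numbers, booleans, tuples, ...
        except (SyntaxError, ValueError):
            parsed = val  # fall back to the raw string
        globals()[key] = parsed

print(batch_size, learning_rate, run_name)
```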
@ -5,15 +5,17 @@ https://arxiv.org/abs/2406.11794
|
|||
TODOs:
|
||||
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from jinja2 import Template
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from jinja2 import Template
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Prompt rendering utilities
|
||||
|
||||
|
||||
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
|
||||
"""Render complete prompts for a multiple choice question"""
|
||||
template_str = """
|
||||
|
|
@ -24,11 +26,7 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
'fewshot_examples': fewshot_examples,
|
||||
'continuation_delimiter': continuation_delimiter,
|
||||
'item': item
|
||||
}
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
|
||||
return prompts
|
||||
|
||||
|
|
@ -43,13 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
'fewshot_examples': fewshot_examples,
|
||||
'continuation_delimiter': continuation_delimiter,
|
||||
'item': item
|
||||
}
|
||||
prompts = [template.render(context=context_option, **context)
|
||||
for context_option in item['context_options']]
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
prompts = [template.render(context=context_option, **context) for context_option in item['context_options']]
|
||||
return prompts
|
||||
|
||||
|
||||
|
|
@ -67,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
'fewshot_examples': fewshot_examples,
|
||||
'continuation_delimiter': continuation_delimiter,
|
||||
'item': item
|
||||
}
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
# Return two prompts: without and with the continuation
|
||||
prompt_without = template.render(include_continuation=False, **context)
|
||||
prompt_with = template.render(include_continuation=True, **context)
|
||||
|
|
@ -89,10 +78,7 @@ def find_common_length(token_sequences, direction='left'):
|
|||
- direction: 'left' for prefix, 'right' for suffix
|
||||
"""
|
||||
min_len = min(len(seq) for seq in token_sequences)
|
||||
indices = {
|
||||
'left': range(min_len),
|
||||
'right': range(-1, -min_len-1, -1)
|
||||
}[direction]
|
||||
indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction]
|
||||
# Find the first position where the token sequences differ
|
||||
for i, idx in enumerate(indices):
|
||||
token = token_sequences[0][idx]
|
||||
|
|
@@ -106,7 +92,7 @@ def stack_sequences(tokens, pad_token_id):
     bsz, seq_len = len(tokens), max(len(x) for x in tokens)
     input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
     for i, x in enumerate(tokens):
-        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
+        input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long)
     return input_ids

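`stack_sequences` above right-pads ragged token lists into one `(batch, seq_len)` tensor with a chosen pad id (the evaluator reuses the BOS id for this). A quick usage sketch with toy token ids:

```python
import torch


def stack_sequences(tokens, pad_token_id):
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids


batch = stack_sequences([[5, 6, 7], [8, 9]], pad_token_id=0)
print(batch.tolist())  # [[5, 6, 7], [8, 9, 0]]
```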
@ -153,9 +139,7 @@ def forward_model(model, input_ids):
|
|||
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
|
||||
# Calculate cross entropy at all positions
|
||||
losses = torch.nn.functional.cross_entropy(
|
||||
outputs.view(batch_size * seq_len, -1),
|
||||
target_ids.view(batch_size * seq_len),
|
||||
reduction='none'
|
||||
outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none'
|
||||
).view(batch_size, seq_len)
|
||||
# Set the last column to be nan because there is no autoregressive loss there
|
||||
losses[:, -1] = float('nan')
|
||||
|
|
@ -201,19 +185,19 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
for t, s, e in zip(tokens, start_idxs, end_idxs):
|
||||
if len(t) > max_tokens:
|
||||
num_to_crop = len(t) - max_tokens
|
||||
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
|
||||
new_start_idxs.append(s - num_to_crop) # shift the indices down
|
||||
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
|
||||
new_start_idxs.append(s - num_to_crop) # shift the indices down
|
||||
new_end_idxs.append(e - num_to_crop)
|
||||
assert s - num_to_crop >= 0, "this should never happen right?"
|
||||
assert e - num_to_crop >= 0, "this should never happen right?"
|
||||
else:
|
||||
new_tokens.append(t) # keep unchanged
|
||||
new_tokens.append(t) # keep unchanged
|
||||
new_start_idxs.append(s)
|
||||
new_end_idxs.append(e)
|
||||
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
|
||||
|
||||
# Stack up all the sequences into a batch
|
||||
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
|
||||
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
|
||||
input_ids = stack_sequences(tokens, pad_token_id)
|
||||
input_ids = input_ids.to(device)
|
||||
|
||||
|
|
@ -226,13 +210,12 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
si = start_idxs[0]
|
||||
ei = end_idxs[0]
|
||||
# predictions[i] predict input_ids[i+1] autoregressively
|
||||
predicted_tokens = predictions[0, si-1:ei-1]
|
||||
predicted_tokens = predictions[0, si - 1 : ei - 1]
|
||||
actual_tokens = input_ids[0, si:ei]
|
||||
is_correct = torch.all(predicted_tokens == actual_tokens).item()
|
||||
elif task_type in ['multiple_choice', 'schema']:
|
||||
# For MC/schema: find the option with lowest average loss
|
||||
mean_losses = [losses[i, si-1:ei-1].mean().item()
|
||||
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
pred_idx = mean_losses.index(min(mean_losses))
|
||||
is_correct = pred_idx == item['gold']
|
||||
else:
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,16 @@
 from collections import deque

-import torch
 import pyarrow.parquet as pq
+import torch

 from nanochat.common import get_dist_info
 from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer

-def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+
+def tokenizing_distributed_data_loader_with_state(
+    B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
+):
     """
     Stream pretraining text from parquet files, tokenize, yield training batches.

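Later in this file's diff, the loader accumulates `B * T + 1` tokens in a deque and then splits them into shifted inputs/targets. A compact sketch of that slicing with toy numbers (no tokenizer or parquet streaming):

```python
from collections import deque

import torch

B, T = 2, 4
needed_tokens = B * T + 1  # +1 so the last position still has a next-token target

token_buffer = deque(range(100))  # stand-in for the streaming tokenizer output
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]

scratch = torch.tensor(tokens, dtype=torch.long)
inputs = scratch[:-1].view(B, T)
targets = scratch[1:].view(B, T)
print(inputs.tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(targets.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```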
@ -24,42 +27,44 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
|
|||
|
||||
# infinite iterator over document batches (list of text strings)
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if resume_rg_idx is not None:
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
|
||||
batches = document_batches()
|
||||
|
||||
# Now emit batches of tokens.
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
# get the tokenizer and the bos token
|
||||
tokenizer = get_tokenizer()
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
while len(token_buffer) < needed_tokens:
|
||||
|
|
@ -71,16 +76,20 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
|
|||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
|
||||
state_dict = {
|
||||
"pq_idx": pq_idx,
|
||||
"rg_idx": rg_idx,
|
||||
} # we need this in case we wish to approximately resume training
|
||||
yield inputs, targets, state_dict
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@ This file contains utilities for:
|
|||
For details of how the dataset was prepared, see `repackage_data_reference.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import pyarrow.parquet as pq
|
||||
from multiprocessing import Pool
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
import requests
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -21,8 +22,8 @@ from nanochat.common import get_base_dir
|
|||
|
||||
# The URL on the internet where the data is hosted and downloaded from on demand
|
||||
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
|
||||
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
|
||||
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
||||
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
|
||||
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
||||
base_dir = get_base_dir()
|
||||
DATA_DIR = os.path.join(base_dir, "base_data")
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
|
|
@ -30,16 +31,15 @@ os.makedirs(DATA_DIR, exist_ok=True)
|
|||
# -----------------------------------------------------------------------------
|
||||
# These functions are useful utilities to other modules, can/should be imported
|
||||
|
||||
|
||||
def list_parquet_files(data_dir=None):
|
||||
""" Looks into a data dir and returns full paths to all parquet files. """
|
||||
"""Looks into a data dir and returns full paths to all parquet files."""
|
||||
data_dir = DATA_DIR if data_dir is None else data_dir
|
||||
parquet_files = sorted([
|
||||
f for f in os.listdir(data_dir)
|
||||
if f.endswith('.parquet') and not f.endswith('.tmp')
|
||||
])
|
||||
parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')])
|
||||
parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
|
||||
return parquet_paths
|
||||
|
||||
|
||||
def parquets_iter_batched(split, start=0, step=1):
|
||||
"""
|
||||
Iterate through the dataset, in batches of underlying row_groups for efficiency.
|
||||
|
|
@ -56,9 +56,10 @@ def parquets_iter_batched(split, start=0, step=1):
|
|||
texts = rg.column('text').to_pylist()
|
||||
yield texts
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def download_single_file(index):
|
||||
""" Downloads a single file index, with some backoff """
|
||||
"""Downloads a single file index, with some backoff"""
|
||||
|
||||
# Construct the local filepath for this file and skip if it already exists
|
||||
filename = index_to_filename(index)
|
||||
|
|
@ -78,7 +79,7 @@ def download_single_file(index):
|
|||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
# Write to temporary file first
|
||||
temp_path = filepath + f".tmp"
|
||||
temp_path = filepath + ".tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
|
||||
if chunk:
|
||||
|
|
@ -88,10 +89,10 @@ def download_single_file(index):
|
|||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
|
||||
except (requests.RequestException, IOError) as e:
|
||||
except (OSError, requests.RequestException) as e:
|
||||
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
||||
# Clean up any partial files
|
||||
for path in [filepath + f".tmp", filepath]:
|
||||
for path in [filepath + ".tmp", filepath]:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
|
|
@@ -99,7 +100,7 @@ def download_single_file(index):
                        pass
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
-                wait_time = 2 ** attempt
+                wait_time = 2**attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
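The retry loop above backs off by 2^attempt seconds between failed download attempts. A small self-contained sketch of the same policy (the fetch helper and attempt count are placeholders):

```python
import time

import requests


def fetch_with_backoff(url: str, max_attempts: int = 5) -> bytes:
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            return response.content
        except (OSError, requests.RequestException) as e:
            if attempt == max_attempts:
                raise
            wait_time = 2**attempt  # 2, 4, 8, ... seconds
            print(f"Attempt {attempt} failed ({e}); retrying in {wait_time}s")
            time.sleep(wait_time)
```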
@ -111,8 +112,12 @@ def download_single_file(index):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
|
||||
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
|
||||
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
|
||||
parser.add_argument(
|
||||
"-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
|
||||
|
|
|
|||
|
|
@@ -11,15 +11,17 @@ Notes:
 The whole thing is made as efficient as possible.
 """

-import torch
-import torch.nn.functional as F
 import signal
 import warnings
-from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init, autodetect_device_type
+from contextlib import contextmanager, nullcontext
+
+import torch
+import torch.nn.functional as F
+
 from nanochat.checkpoint_manager import load_model
-from contextlib import nullcontext
+from nanochat.common import autodetect_device_type, compute_init


 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@ -33,17 +35,19 @@ def timeout(duration, formula):
|
|||
yield
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def eval_with_timeout(formula, max_time=3):
|
||||
try:
|
||||
with timeout(max_time, formula):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", SyntaxWarning)
|
||||
return eval(formula, {"__builtins__": {}}, {})
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
signal.alarm(0)
|
||||
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
|
||||
return None
|
||||
|
||||
|
||||
def use_calculator(expr):
|
||||
"""
|
||||
Evaluate a Python expression safely.
|
||||
|
|
@ -65,9 +69,25 @@ def use_calculator(expr):
|
|||
return None
|
||||
|
||||
# Disallow dangerous patterns
|
||||
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
|
||||
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
|
||||
'getattr', 'setattr', 'delattr', 'hasattr']
|
||||
dangerous_patterns = [
|
||||
'__',
|
||||
'import',
|
||||
'exec',
|
||||
'eval',
|
||||
'compile',
|
||||
'open',
|
||||
'file',
|
||||
'input',
|
||||
'raw_input',
|
||||
'globals',
|
||||
'locals',
|
||||
'vars',
|
||||
'dir',
|
||||
'getattr',
|
||||
'setattr',
|
||||
'delattr',
|
||||
'hasattr',
|
||||
]
|
||||
expr_lower = expr.lower()
|
||||
if any(pattern in expr_lower for pattern in dangerous_patterns):
|
||||
return None
|
||||
|
|
@ -79,6 +99,7 @@ def use_calculator(expr):
|
|||
# Evaluate with timeout
|
||||
return eval_with_timeout(expr)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class KVCache:
|
||||
"""
|
||||
|
|
@ -90,7 +111,7 @@ class KVCache:
|
|||
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
|
||||
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
||||
self.kv_cache = None
|
||||
self.pos = 0 # current position in time in the cache
|
||||
self.pos = 0 # current position in time in the cache
|
||||
|
||||
def reset(self):
|
||||
self.pos = 0
|
||||
|
|
@ -122,7 +143,7 @@ class KVCache:
|
|||
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
||||
# 3) copy the data over
|
||||
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
|
||||
self.kv_cache[:, :, :, :, : other.pos, :] = other.kv_cache
|
||||
# 4) update the pos
|
||||
self.pos = other.pos
|
||||
|
||||
|
|
@@ -135,8 +156,8 @@ class KVCache:
         t0, t1 = self.pos, self.pos + T_add
         # Dynamically grow the cache if needed
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
+            t_needed = t1 + 1024  # as much as we need plus buffer of 1024
+            t_needed = (t_needed + 1023) & ~1023  # then round up to the nearest multiple of 1024
             additional_shape = list(self.kv_cache.shape)
             additional_shape[4] = t_needed - self.kv_cache.size(4)
             additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
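The cache-growth hunk above rounds the needed sequence length up to the next multiple of 1024 with a bit trick: `(n + 1023) & ~1023` works because 1024 is a power of two, so `~1023` clears the low 10 bits. A tiny check of that identity:

```python
def round_up_1024(n: int) -> int:
    return (n + 1023) & ~1023


for n in (1, 1024, 1025, 3000):
    print(n, "->", round_up_1024(n))
# 1 -> 1024, 1024 -> 1024, 1025 -> 2048, 3000 -> 3072
```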
@ -173,22 +194,24 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
|
|||
probs = F.softmax(logits, dim=-1)
|
||||
return torch.multinomial(probs, num_samples=1, generator=rng)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RowState:
|
||||
# Per-row state tracking during generation
|
||||
def __init__(self, current_tokens=None):
|
||||
self.current_tokens = current_tokens or [] # Current token sequence for this row
|
||||
self.forced_tokens = deque() # Queue of tokens to force inject
|
||||
self.in_python_block = False # Whether we are inside a python block
|
||||
self.python_expr_tokens = [] # Tokens of the current python expression
|
||||
self.completed = False # Whether this row has completed generation
|
||||
self.current_tokens = current_tokens or [] # Current token sequence for this row
|
||||
self.forced_tokens = deque() # Queue of tokens to force inject
|
||||
self.in_python_block = False # Whether we are inside a python block
|
||||
self.python_expr_tokens = [] # Tokens of the current python expression
|
||||
self.completed = False # Whether this row has completed generation
|
||||
|
||||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer # needed for tool use
|
||||
self.tokenizer = tokenizer # needed for tool use
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
|
||||
|
|
@ -204,8 +227,8 @@ class Engine:
|
|||
python_end = get_special("<|python_end|>")
|
||||
output_start = get_special("<|output_start|>")
|
||||
output_end = get_special("<|output_end|>")
|
||||
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
|
||||
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
|
||||
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
|
||||
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
|
||||
|
||||
# 1) Run a batch 1 prefill of the prompt tokens
|
||||
m = self.model.config
|
||||
|
|
@ -229,7 +252,7 @@ class Engine:
|
|||
**kv_model_kwargs,
|
||||
)
|
||||
kv_cache_decode.prefill(kv_cache_prefill)
|
||||
del kv_cache_prefill # no need to keep this memory around
|
||||
del kv_cache_prefill # no need to keep this memory around
|
||||
|
||||
# 3) Initialize states for each sample
|
||||
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
|
||||
|
|
@ -259,12 +282,12 @@ class Engine:
|
|||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# Process each row: choose the next token, update state, optional tool use
|
||||
token_column = [] # contains the next token id along each row
|
||||
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
token_column = [] # contains the next token id along each row
|
||||
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
for i, state in enumerate(row_states):
|
||||
# Select the next token in this row
|
||||
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
||||
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
||||
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
|
||||
token_column.append(next_token)
|
||||
# Update the state of this row to include the next token
|
||||
|
|
@ -327,10 +350,13 @@ if __name__ == "__main__":
|
|||
is equivalent to the faster Engine.generate function here.
|
||||
"""
|
||||
import time
|
||||
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
|
|
@ -357,12 +383,12 @@ if __name__ == "__main__":
|
|||
# generate tokens with Engine
|
||||
generated_tokens = []
|
||||
engine = Engine(model, tokenizer)
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
|
|
|
|||
|
|
@@ -30,17 +30,18 @@ import platform
 import signal
 import tempfile
 from dataclasses import dataclass
-from typing import Optional

 # -----------------------------------------------------------------------------

+
 @dataclass
 class ExecutionResult:
     """Result of executing Python code in a sandbox."""
+
     success: bool
     stdout: str
     stderr: str
-    error: Optional[str] = None
+    error: str | None = None
     timeout: bool = False
     memory_exceeded: bool = False

@ -101,13 +102,13 @@ class WriteOnlyStringIO(io.StringIO):
|
|||
"""StringIO that throws an exception when it's read from"""
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
raise IOError
|
||||
raise OSError
|
||||
|
||||
def readline(self, *args, **kwargs):
|
||||
raise IOError
|
||||
raise OSError
|
||||
|
||||
def readlines(self, *args, **kwargs):
|
||||
raise IOError
|
||||
raise OSError
|
||||
|
||||
def readable(self, *args, **kwargs):
|
||||
"""Returns True if the IO object can be read."""
|
||||
|
|
@ -131,7 +132,7 @@ def chdir(root):
|
|||
os.chdir(cwd)
|
||||
|
||||
|
||||
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
||||
def reliability_guard(maximum_memory_bytes: int | None = None):
|
||||
"""
|
||||
This disables various destructive functions and prevents the generated code
|
||||
from interfering with the test (e.g. fork bomb, killing other processes,
|
||||
|
|
@ -147,6 +148,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
|
@ -211,10 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
sys.modules["tkinter"] = None
|
||||
|
||||
|
||||
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
|
||||
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict):
|
||||
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
|
||||
with create_tempdir():
|
||||
|
||||
# These system calls are needed when cleaning up tempdir.
|
||||
import os
|
||||
import shutil
|
||||
|
|
@ -228,14 +229,16 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
||||
|
||||
# Default to failure
|
||||
result_dict.update({
|
||||
"success": False,
|
||||
"stdout": "",
|
||||
"stderr": "",
|
||||
"timeout": False,
|
||||
"memory_exceeded": False,
|
||||
"error": None,
|
||||
})
|
||||
result_dict.update(
|
||||
{
|
||||
"success": False,
|
||||
"stdout": "",
|
||||
"stderr": "",
|
||||
"timeout": False,
|
||||
"memory_exceeded": False,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
exec_globals = {}
|
||||
|
|
@@ -253,28 +256,36 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
# uncomment the following line and proceed at your own risk:
exec(code, exec_globals)
result_dict.update({
"success": True,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
})
result_dict.update(
{
"success": True,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
}
)
except TimeoutException:
result_dict.update({
"timeout": True,
"error": "Execution timed out",
})
result_dict.update(
{
"timeout": True,
"error": "Execution timed out",
}
)
except MemoryError as e:
result_dict.update({
"memory_exceeded": True,
"error": f"Memory limit exceeded: {e}",
})
result_dict.update(
{
"memory_exceeded": True,
"error": f"Memory limit exceeded: {e}",
}
)
except BaseException as e:
result_dict.update({
"error": f"{type(e).__name__}: {e}",
})
result_dict.update(
{
"error": f"{type(e).__name__}: {e}",
}
)
# Needed for cleaning up.
shutil.rmtree = rmtree
@@ -285,8 +296,8 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
def execute_code(
code: str,
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
"""
Execute Python code in a sandboxed environment.
@@ -310,10 +321,7 @@ def execute_code(
manager = multiprocessing.Manager()
result_dict = manager.dict()
p = multiprocessing.Process(
target=_unsafe_execute,
args=(code, timeout, maximum_memory_bytes, result_dict)
)
p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict))
p.start()
p.join(timeout=timeout + 1)
@@ -346,4 +354,3 @@ def execute_code(
timeout=result_dict["timeout"],
memory_exceeded=result_dict["memory_exceeded"],
)
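For orientation on the sandbox API whose signatures are touched above, here is a minimal usage sketch. The import path `nanochat.execution` is an assumption made for illustration; the argument defaults mirror the ones shown in the diff.

```python
# Hypothetical usage of the sandboxed executor diffed above.
# Assumption: execute_code / ExecutionResult live in nanochat.execution.
from nanochat.execution import execute_code

result = execute_code(
    "print(sum(range(10)))",
    timeout=5.0,                             # 5 seconds, per the default above
    maximum_memory_bytes=256 * 1024 * 1024,  # 256MB, per the default above
)
if result.success:
    print(result.stdout)   # expected to contain "45"
else:
    print(result.error, result.timeout, result.memory_exceeded)
```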
@@ -12,24 +12,25 @@ Notable features:
"""
import math
from functools import partial
from dataclasses import dataclass
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.common import get_dist_info
from nanochat.muon import DistMuon, Muon
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
@@ -41,13 +42,14 @@ def norm(x):
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
@@ -73,18 +75,24 @@ class CausalSelfAttention(nn.Module):
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = (
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
enable_gqa = (
self.n_head != self.n_kv_head
) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
@@ -96,9 +104,9 @@ class CausalSelfAttention(nn.Module):
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
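To visualize the mask built in the chunked-decode branch above, a standalone toy sketch (not tied to the model code) of the prefix-plus-causal pattern:

```python
import torch

# Toy illustration: Tq new queries attend to a cached prefix plus themselves.
Tq, Tk = 3, 7                                        # 4 cached positions + 3 current
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool)  # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0:
    attn_mask[:, :prefix_len] = True                 # every query sees the full prefix
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
print(attn_mask.int())
# tensor([[1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 1]])
```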
@@ -139,19 +147,21 @@ class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.transformer = nn.ModuleDict(
{
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
@@ -195,18 +205,23 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def get_device(self):
return self.transformer.wte.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
"""Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
l, h, q, t = (
self.config.n_layer,
self.config.n_head,
self.config.n_embd // self.config.n_head,
self.config.sequence_len,
)
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
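To make the `estimate_flops` formula concrete, a rough sketch with the `GPTConfig` defaults from this diff; the parameter counts are left symbolic because they come from the instantiated model rather than the config alone:

```python
# flops/token ≈ 6 * (non-embedding params) + 12 * n_layer * n_head * head_dim * seq_len
n_layer, n_head, n_embd, seq_len = 12, 6, 768, 1024    # GPTConfig defaults above
head_dim = n_embd // n_head                            # 128
attn_term = 12 * n_layer * n_head * head_dim * seq_len
print(attn_term)                                       # 113_246_208 attention FLOPs per token

def flops_per_token(nparams, nparams_embedding):
    # nparams / nparams_embedding would come from the real model's parameters
    return 6 * (nparams - nparams_embedding) + attn_term
```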
@@ -245,12 +260,16 @@ class GPT(nn.Module):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert T <= self.cos.size(1), (
f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
)
assert idx.device == self.cos.device, (
f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
)
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
cos_sin = self.cos[:, T0 : T0 + T], self.sin[:, T0 : T0 + T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
@@ -265,14 +284,16 @@ class GPT(nn.Module):
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction
)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits
@torch.inference_mode()
@@ -289,10 +310,10 @@ class GPT(nn.Module):
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
@@ -1,10 +1,13 @@
"""
A number of functions that help with evaluating a base model.
"""
import math
import torch
import torch.distributed as dist
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
"""
@@ -30,20 +33,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
y_safe = torch.where(valid, y, torch.zeros_like(y))
# map valid targets to their byte length; ignored targets contribute 0 bytes
num_bytes2d = torch.where(
valid,
token_bytes[y_safe],
torch.zeros_like(y, dtype=token_bytes.dtype)
)
num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype))
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
else:
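For context on the accumulators above: the final bits-per-byte reduction is not shown in this hunk, but the usual convention (my reading, stated as an assumption) divides the summed nats by ln(2) times the summed byte count:

```python
import math

# Presumed final reduction for evaluate_bpb, given total_nats and total_bytes above.
def bits_per_byte(total_nats: float, total_bytes: float) -> float:
    return total_nats / (math.log(2) * total_bytes)
```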
@@ -2,9 +2,11 @@
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
"""
import torch
from torch import Tensor
import torch.distributed as dist
from torch import Tensor
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
@@ -17,8 +19,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
@@ -28,13 +32,16 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
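A quick sanity-check sketch for the orthogonalization above (the import path is an assumption): after a few Newton-Schulz steps the singular values of the output should land in a rough band around 1, per the docstring.

```python
import torch
from nanochat.muon import zeropower_via_newtonschulz5  # import path assumed

G = torch.randn(64, 32)
X = zeropower_via_newtonschulz5(G, steps=5)
# X approximates U V^T from the SVD of G, so its singular values sit near 1
# (only approximately, by design of the quintic iteration).
print(torch.linalg.svdvals(X.float()))
```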
@@ -57,6 +64,7 @@ class Muon(torch.optim.Optimizer):
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
@@ -80,7 +88,7 @@ class Muon(torch.optim.Optimizer):
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5)
class DistMuon(torch.optim.Optimizer):
@@ -104,14 +112,14 @@ class DistMuon(torch.optim.Optimizer):
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of Newton–Schulz iterations for the orthogonalization
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
@@ -129,7 +137,9 @@ class DistMuon(torch.optim.Optimizer):
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), (
"All params must have grads"
)
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
@@ -141,7 +151,7 @@ class DistMuon(torch.optim.Optimizer):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank
# each rank stacks up its chunk of world_size params into a list
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
rs_input = [p.grad for p in params[base_i : base_i + world_size]]
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
@@ -159,9 +169,9 @@ class DistMuon(torch.optim.Optimizer):
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank # calculate the index of the param that this rank owns
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
@@ -174,12 +184,12 @@ class DistMuon(torch.optim.Optimizer):
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
scale = max(1.0, p.size(-2) / p.size(-1)) ** 0.5
p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
ag_output = params[base_i : base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)
@ -2,16 +2,18 @@
|
|||
Utilities for generating training report cards. More messy code than usual, will fix.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import socket
|
||||
import datetime
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
|
||||
def run_command(cmd):
|
||||
"""Run a shell command and return output, or None if it fails."""
|
||||
try:
|
||||
|
|
@ -22,6 +24,7 @@ def run_command(cmd):
|
|||
except:
|
||||
return None
|
||||
|
||||
|
||||
def get_git_info():
|
||||
"""Get current git commit, branch, and dirty status."""
|
||||
info = {}
|
||||
|
|
@ -38,18 +41,14 @@ def get_git_info():
|
|||
|
||||
return info
|
||||
|
||||
|
||||
def get_gpu_info():
|
||||
"""Get GPU information."""
|
||||
if not torch.cuda.is_available():
|
||||
return {"available": False}
|
||||
|
||||
num_devices = torch.cuda.device_count()
|
||||
info = {
|
||||
"available": True,
|
||||
"count": num_devices,
|
||||
"names": [],
|
||||
"memory_gb": []
|
||||
}
|
||||
info = {"available": True, "count": num_devices, "names": [], "memory_gb": []}
|
||||
|
||||
for i in range(num_devices):
|
||||
props = torch.cuda.get_device_properties(i)
|
||||
|
|
@ -61,6 +60,7 @@ def get_gpu_info():
|
|||
|
||||
return info
|
||||
|
||||
|
||||
def get_system_info():
|
||||
"""Get system information."""
|
||||
info = {}
|
||||
|
|
@ -83,6 +83,7 @@ def get_system_info():
|
|||
|
||||
return info
|
||||
|
||||
|
||||
def estimate_cost(gpu_info, runtime_hours=None):
|
||||
"""Estimate training cost based on GPU type and runtime."""
|
||||
|
||||
|
|
@ -111,9 +112,10 @@ def estimate_cost(gpu_info, runtime_hours=None):
|
|||
return {
|
||||
"hourly_rate": hourly_rate,
|
||||
"gpu_type": gpu_name,
|
||||
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None
|
||||
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None,
|
||||
}
|
||||
|
||||
|
||||
def generate_header():
|
||||
"""Generate the header for a training report."""
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
|
@ -165,12 +167,12 @@ Generated: {timestamp}
|
|||
num_chars = len(packaged)
|
||||
num_lines = len(packaged.split('\n'))
|
||||
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
|
||||
# count dependencies via uv.lock
|
||||
uv_lock_lines = 0
|
||||
if os.path.exists('uv.lock'):
|
||||
with open('uv.lock', 'r', encoding='utf-8') as f:
|
||||
with open('uv.lock', encoding='utf-8') as f:
|
||||
uv_lock_lines = len(f.readlines())
|
||||
|
||||
header += f"""
|
||||
|
|
@ -184,12 +186,15 @@ Generated: {timestamp}
|
|||
"""
|
||||
return header
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def slugify(text):
|
||||
"""Slugify a text string."""
|
||||
return text.lower().replace(" ", "-")
|
||||
|
||||
|
||||
# the expected files and their order
|
||||
EXPECTED_FILES = [
|
||||
"tokenizer-training.md",
|
||||
|
|
@ -207,10 +212,11 @@ EXPECTED_FILES = [
|
|||
# the metrics we're currently interested in
|
||||
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
|
||||
|
||||
|
||||
def extract(section, keys):
|
||||
"""simple def to extract a single key from a section"""
|
||||
if not isinstance(keys, list):
|
||||
keys = [keys] # convenience
|
||||
keys = [keys] # convenience
|
||||
out = {}
|
||||
for line in section.split("\n"):
|
||||
for key in keys:
|
||||
|
|
@ -218,6 +224,7 @@ def extract(section, keys):
|
|||
out[key] = line.split(":")[1].strip()
|
||||
return out
|
||||
|
||||
|
||||
def extract_timestamp(content, prefix):
|
||||
"""Extract timestamp from content with given prefix."""
|
||||
for line in content.split('\n'):
|
||||
|
|
@ -229,6 +236,7 @@ def extract_timestamp(content, prefix):
|
|||
pass
|
||||
return None
|
||||
|
||||
|
||||
class Report:
|
||||
"""Maintains a bunch of logs, generates a final markdown report."""
|
||||
|
||||
|
|
@ -269,14 +277,14 @@ class Report:
|
|||
report_dir = self.report_dir
|
||||
report_file = os.path.join(report_dir, "report.md")
|
||||
print(f"Generating report to {report_file}")
|
||||
final_metrics = {} # the most important final metrics we'll add as table at the end
|
||||
final_metrics = {} # the most important final metrics we'll add as table at the end
|
||||
start_time = None
|
||||
end_time = None
|
||||
with open(report_file, "w", encoding="utf-8") as out_file:
|
||||
# write the header first
|
||||
header_file = os.path.join(report_dir, "header.md")
|
||||
if os.path.exists(header_file):
|
||||
with open(header_file, "r", encoding="utf-8") as f:
|
||||
with open(header_file, encoding="utf-8") as f:
|
||||
header_content = f.read()
|
||||
out_file.write(header_content)
|
||||
start_time = extract_timestamp(header_content, "Run started:")
|
||||
|
|
@ -284,7 +292,7 @@ class Report:
|
|||
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
|
||||
bloat_data = bloat_data.group(1) if bloat_data else ""
|
||||
else:
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
bloat_data = "[bloat data missing]"
|
||||
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
|
||||
# process all the individual sections
|
||||
|
|
@ -293,7 +301,7 @@ class Report:
|
|||
if not os.path.exists(section_file):
|
||||
print(f"Warning: {section_file} does not exist, skipping")
|
||||
continue
|
||||
with open(section_file, "r", encoding="utf-8") as in_file:
|
||||
with open(section_file, encoding="utf-8") as in_file:
|
||||
section = in_file.read()
|
||||
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
|
||||
if "rl" not in file_name:
|
||||
|
|
@ -307,7 +315,7 @@ class Report:
|
|||
if file_name == "chat-evaluation-sft.md":
|
||||
final_metrics["sft"] = extract(section, chat_metrics)
|
||||
if file_name == "chat-evaluation-rl.md":
|
||||
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
|
||||
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
|
||||
# append this section of the report
|
||||
out_file.write(section)
|
||||
out_file.write("\n")
|
||||
|
|
@ -354,7 +362,7 @@ class Report:
|
|||
else:
|
||||
out_file.write("Total wall clock time: unknown\n")
|
||||
# also cp the report.md file to current directory
|
||||
print(f"Copying report.md to current directory for convenience")
|
||||
print("Copying report.md to current directory for convenience")
|
||||
shutil.copy(report_file, "report.md")
|
||||
return report_file
|
||||
|
||||
|
|
@ -378,18 +386,23 @@ class Report:
|
|||
f.write(f"Run started: {start_time}\n\n---\n\n")
|
||||
print(f"Reset report and wrote header to {header_file}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanochat-specific convenience functions
|
||||
|
||||
|
||||
class DummyReport:
|
||||
def log(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def get_report():
|
||||
# just for convenience, only rank 0 logs to report
|
||||
from nanochat.common import get_base_dir, get_dist_info
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp_rank == 0:
|
||||
report_dir = os.path.join(get_base_dir(), "report")
|
||||
|
|
@ -397,10 +410,18 @@ def get_report():
|
|||
else:
|
||||
return DummyReport()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
|
||||
parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
|
||||
parser.add_argument(
|
||||
"command",
|
||||
nargs="?",
|
||||
default="generate",
|
||||
choices=["generate", "reset"],
|
||||
help="Operation to perform (default: generate)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.command == "generate":
|
||||
get_report().generate()
|
||||
|
|
|
|||
|
|
@ -6,36 +6,39 @@ Two implementations are available:
|
|||
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
|
||||
"""
|
||||
|
||||
import os
|
||||
import copy
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
SPECIAL_TOKENS = [
|
||||
# every document begins with the Beginning of Sequence (BOS) token that delimits documents
|
||||
"<|bos|>",
|
||||
# tokens below are only used during finetuning to render Conversations into token ids
|
||||
"<|user_start|>", # user messages
|
||||
"<|user_start|>", # user messages
|
||||
"<|user_end|>",
|
||||
"<|assistant_start|>", # assistant messages
|
||||
"<|assistant_start|>", # assistant messages
|
||||
"<|assistant_end|>",
|
||||
"<|python_start|>", # assistant invokes python REPL tool
|
||||
"<|python_start|>", # assistant invokes python REPL tool
|
||||
"<|python_end|>",
|
||||
"<|output_start|>", # python REPL outputs back to assistant
|
||||
"<|output_start|>", # python REPL outputs back to assistant
|
||||
"<|output_end|>",
|
||||
]
|
||||
|
||||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
SPLIT_PATTERN = (
|
||||
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
|
||||
from tokenizers import Regex, decoders, pre_tokenizers
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
from tokenizers import pre_tokenizers, decoders, Regex
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
|
||||
class HuggingFaceTokenizer:
|
||||
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
|
||||
|
||||
|
|
@ -59,11 +62,13 @@ class HuggingFaceTokenizer:
|
|||
def train_from_iterator(cls, text_iterator, vocab_size):
|
||||
# train from an iterator of text
|
||||
# Configure the HuggingFace Tokenizer
|
||||
tokenizer = HFTokenizer(BPE(
|
||||
byte_fallback=True, # needed!
|
||||
unk_token=None,
|
||||
fuse_unk=False,
|
||||
))
|
||||
tokenizer = HFTokenizer(
|
||||
BPE(
|
||||
byte_fallback=True, # needed!
|
||||
unk_token=None,
|
||||
fuse_unk=False,
|
||||
)
|
||||
)
|
||||
# Normalizer: None
|
||||
tokenizer.normalizer = None
|
||||
# Pre-tokenizer: GPT-4 style
|
||||
|
|
@ -71,11 +76,13 @@ class HuggingFaceTokenizer:
|
|||
# NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
|
||||
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
|
||||
# (but I haven't validated this! TODO)
|
||||
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
|
||||
])
|
||||
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
|
||||
]
|
||||
)
|
||||
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
# Post-processor: None
|
||||
|
|
@ -84,7 +91,7 @@ class HuggingFaceTokenizer:
|
|||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
show_progress=True,
|
||||
min_frequency=0, # no minimum frequency
|
||||
min_frequency=0, # no minimum frequency
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
||||
special_tokens=SPECIAL_TOKENS,
|
||||
)
|
||||
|
|
@ -146,12 +153,16 @@ class HuggingFaceTokenizer:
|
|||
self.tokenizer.save(tokenizer_path)
|
||||
print(f"Saved tokenizer to {tokenizer_path}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer based on rustbpe + tiktoken combo
|
||||
import pickle
|
||||
import rustbpe
|
||||
|
||||
import tiktoken
|
||||
|
||||
import rustbpe
|
||||
|
||||
|
||||
class RustBPETokenizer:
|
||||
"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
|
||||
|
||||
|
|
@ -176,8 +187,8 @@ class RustBPETokenizer:
|
|||
enc = tiktoken.Encoding(
|
||||
name="rustbpe",
|
||||
pat_str=pattern,
|
||||
mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
|
||||
special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
|
||||
mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
|
||||
special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
|
||||
)
|
||||
return cls(enc, "<|bos|>")
|
||||
|
||||
|
|
@ -225,14 +236,14 @@ class RustBPETokenizer:
|
|||
if isinstance(text, str):
|
||||
ids = self.enc.encode_ordinary(text)
|
||||
if prepend is not None:
|
||||
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
|
||||
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
|
||||
if append is not None:
|
||||
ids.append(append_id)
|
||||
elif isinstance(text, list):
|
||||
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
||||
if prepend is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.insert(0, prepend_id) # TODO: same
|
||||
ids_row.insert(0, prepend_id) # TODO: same
|
||||
if append is not None:
|
||||
for ids_row in ids:
|
||||
ids_row.append(append_id)
|
||||
|
|
@ -264,6 +275,7 @@ class RustBPETokenizer:
|
|||
"""
|
||||
# ids, masks that we will return and a helper function to help build them up.
|
||||
ids, mask = [], []
|
||||
|
||||
def add_tokens(token_ids, mask_val):
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
|
@ -274,7 +286,7 @@ class RustBPETokenizer:
|
|||
# => just merge it with the second (user) message
|
||||
if conversation["messages"][0]["role"] == "system":
|
||||
# some conversation surgery is necessary here for now...
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
messages = conversation["messages"]
|
||||
assert messages[1]["role"] == "user", "System message must be followed by a user message"
|
||||
messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
|
||||
|
|
@ -286,17 +298,21 @@ class RustBPETokenizer:
|
|||
# fetch all the special tokens we need
|
||||
bos = self.get_bos_token_id()
|
||||
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
|
||||
assistant_start, assistant_end = (
|
||||
self.encode_special("<|assistant_start|>"),
|
||||
self.encode_special("<|assistant_end|>"),
|
||||
)
|
||||
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
|
||||
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
|
||||
|
||||
# now we can tokenize the conversation
|
||||
add_tokens(bos, 0)
|
||||
for i, message in enumerate(messages):
|
||||
|
||||
# some sanity checking here around assumptions, to prevent footguns
|
||||
must_be_from = "user" if i % 2 == 0 else "assistant"
|
||||
assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
|
||||
assert message["role"] == must_be_from, (
|
||||
f"Message {i} is from {message['role']} but should be from {must_be_from}"
|
||||
)
|
||||
|
||||
# content can be either a simple string or a list of parts (e.g. containing tool calls)
|
||||
content = message["content"]
|
||||
|
|
@ -363,10 +379,10 @@ class RustBPETokenizer:
|
|||
Unlike the Chat SFT case, we don't need to return the mask.
|
||||
"""
|
||||
# We have some surgery to do: we need to pop the last message (of the Assistant)
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
messages = conversation["messages"]
|
||||
assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
|
||||
messages.pop() # remove the last message (of the Assistant) inplace
|
||||
messages.pop() # remove the last message (of the Assistant) inplace
|
||||
|
||||
# Now tokenize the conversation
|
||||
ids, mask = self.render_conversation(conversation)
|
||||
|
|
@ -376,23 +392,31 @@ class RustBPETokenizer:
|
|||
ids.append(assistant_start)
|
||||
return ids
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanochat-specific convenience functions
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
# return HuggingFaceTokenizer.from_directory(tokenizer_dir)
|
||||
return RustBPETokenizer.from_directory(tokenizer_dir)
|
||||
|
||||
|
||||
def get_token_bytes(device="cpu"):
|
||||
import torch
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
|
||||
assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
|
||||
assert os.path.exists(token_bytes_path), (
|
||||
f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
|
||||
)
|
||||
with open(token_bytes_path, "rb") as f:
|
||||
token_bytes = torch.load(f, map_location=device)
|
||||
return token_bytes
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ manifest-path = "rustbpe/Cargo.toml"
|
|||
dev = [
|
||||
"maturin>=1.9.4",
|
||||
"pytest>=8.0.0",
|
||||
"pre-commit>=3.8.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
|
@ -45,33 +46,58 @@ python_functions = ["test_*"]
|
|||
|
||||
# target torch to cuda 12.8 or CPU
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 120
|
||||
fix = true
|
||||
unsafe-fixes = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"F", # Pyflakes (unused imports) - replaces autoflake
|
||||
"I", # isort - replaces isort
|
||||
"UP", # pyupgrade - replaces pyupgrade
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["nanochat"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "preserve"
|
||||
|
||||
[tool.codespell]
|
||||
write-changes = true
|
||||
interactive = 1
|
||||
skip = "tests/*,dev/*,scripts/tok_eval.py,tasks/spellingbee.py"
|
||||
ignore-words-list = "re-use,astroid"
|
||||
|
|
|
|||
|
|
@ -9,23 +9,31 @@ torchrun --nproc_per_node=8 -m scripts.base_eval
|
|||
|
||||
The script will print the CORE metric to the console.
|
||||
"""
|
||||
import os
|
||||
|
||||
import csv
|
||||
import time
|
||||
import json
|
||||
import yaml
|
||||
import shutil
|
||||
import os
|
||||
import random
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import (
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
download_file_with_lock,
|
||||
get_base_dir,
|
||||
print0,
|
||||
)
|
||||
from nanochat.core_eval import evaluate_task
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# nanochat specific function dealing with I/O etc.
|
||||
|
|
@ -33,6 +41,7 @@ from nanochat.core_eval import evaluate_task
|
|||
# ~162MB of data needed to evaluate the CORE metric
|
||||
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
|
||||
def place_eval_bundle(file_path):
|
||||
# here file_path is the path to the eval_bundle.zip file
|
||||
# we need to unzip it and place it in the base directory
|
||||
|
|
@ -45,6 +54,7 @@ def place_eval_bundle(file_path):
|
|||
shutil.move(extracted_bundle_dir, eval_bundle_dir)
|
||||
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
|
||||
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
|
|
@ -59,13 +69,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
|
||||
# Load random baseline values from eval metadata
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||
with open(eval_meta_data, encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
task_name = row['Eval Task']
|
||||
|
|
@ -82,13 +92,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
'task_type': task['icl_task_type'],
|
||||
'dataset_uri': task['dataset_uri'],
|
||||
'num_fewshot': task['num_fewshot'][0],
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' ')
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' '),
|
||||
}
|
||||
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
|
||||
|
||||
# Load data for this task
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
with open(data_path, encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# shuffle the data because in many cases it appears ordered but we want
|
||||
|
|
@ -109,18 +119,17 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
out = {
|
||||
"results": results,
|
||||
"centered_results": centered_results,
|
||||
"core_metric": core_metric
|
||||
}
|
||||
out = {"results": results, "centered_results": centered_results, "core_metric": core_metric}
|
||||
return out
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace loading utilities and light wrappers for a model
|
||||
|
||||
|
||||
class ModelWrapper:
|
||||
"""Lightweight wrapper for a HuggingFace model"""
|
||||
|
||||
def __init__(self, model, max_seq_len=None):
|
||||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
|
@ -130,10 +139,12 @@ class ModelWrapper:
|
|||
logits = outputs.logits
|
||||
return logits
|
||||
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
# Load the model
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
|
@ -143,9 +154,11 @@ def load_hf_model(hf_path: str, device):
|
|||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
|
|
@ -154,7 +167,9 @@ def main():
|
|||
# distributed / precision setup
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if args.hf_path is not None:
|
||||
|
|
@ -162,13 +177,13 @@ def main():
|
|||
hf_path = args.hf_path
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model_name = hf_path # just for logging
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
model_name = hf_path # just for logging
|
||||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
|
|
@ -190,23 +205,28 @@ def main():
|
|||
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
|
||||
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
|
||||
# Print the content of the csv file to console too
|
||||
print0("="*80)
|
||||
print0("=" * 80)
|
||||
print0(f"Model: {model_name}")
|
||||
print0("="*80)
|
||||
with open(output_csv_path, 'r', encoding='utf-8') as f:
|
||||
print0("=" * 80)
|
||||
with open(output_csv_path, encoding='utf-8') as f:
|
||||
print0(f.read())
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model evaluation", data=[
|
||||
{
|
||||
"Model": model_name,
|
||||
"CORE metric": core_metric,
|
||||
},
|
||||
centered_results, # the full table
|
||||
])
|
||||
|
||||
get_report().log(
|
||||
section="Base model evaluation",
|
||||
data=[
|
||||
{
|
||||
"Model": model_name,
|
||||
"CORE metric": core_metric,
|
||||
},
|
||||
centered_results, # the full table
|
||||
],
|
||||
)
|
||||
|
||||
compute_cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -6,30 +6,35 @@ Loads a checkpoint, and:
|
|||
Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
|
||||
# Configuration
|
||||
device_batch_size = 32
|
||||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
split_tokens = 20 * 524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
|
|
@ -67,13 +72,17 @@ if ddp_rank == 0:
|
|||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
{
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
])
|
||||
|
||||
get_report().log(
|
||||
section="Base model loss",
|
||||
data=[
|
||||
{
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
],
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
compute_cleanup()
|
||||
|
|
|
|||
|
|
@ -12,67 +12,83 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
|
|||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.checkpoint_manager import load_checkpoint, save_checkpoint
|
||||
from nanochat.common import (
|
||||
DummyWandb,
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
get_base_dir,
|
||||
print0,
|
||||
print_banner,
|
||||
)
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.tokenizer import get_token_bytes, get_tokenizer
|
||||
from scripts.base_eval import evaluate_model
|
||||
|
||||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = (
|
||||
-1.0
|
||||
) # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = (
|
||||
20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20 * 524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
model_tag = ""  # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read())  # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# -----------------------------------------------------------------------------

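The `exec(...)` line above pulls in `nanochat/configurator.py` so that any of the module-level settings can be overridden from the launch command (for example appending `--depth=12 --device_batch_size=16`). A minimal sketch of that kind of override logic, offered as an assumption about what the configurator does rather than a copy of it:

```
# Hedged sketch of a "--key=value" override mechanism (not the actual nanochat/configurator.py).
import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if not (arg.startswith("--") and "=" in arg):
        continue  # the real configurator may instead accept a config .py file here
    key, raw = arg[2:].split("=", 1)
    if key not in globals():
        raise ValueError(f"unknown config key: {key}")
    try:
        value = literal_eval(raw)  # e.g. "--depth=12" becomes the int 12
    except (SyntaxError, ValueError):
        value = raw  # leave plain strings like "--run=myrun" untouched
    globals()[key] = value  # rebind the default defined above
```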
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
|
|
@@ -88,9 +104,9 @@ print0(f"Vocab size: {vocab_size:,}")

# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
model_dim = depth * 64  # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128)  # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads  # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
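As a quick sanity check of the derivation above, here is the arithmetic for the default `depth = 20`:

```
# Worked example of the shape derivation for depth=20 (illustration only).
depth = 20
model_dim = depth * 64                         # 1280
num_heads = max(1, (model_dim + 127) // 128)   # ceil(1280 / 128) = 10, i.e. head dim 128
num_kv_heads = num_heads                       # 10, GQA effectively disabled
print(model_dim, num_heads, num_kv_heads)      # 1280 10 10
```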
@@ -98,8 +114,8 @@ print0(f"num_kv_heads: {num_kv_heads}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
tokens_per_fwdbwd = device_batch_size * max_seq_len  # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
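With the default settings and an assumed 8-GPU node, the accumulation arithmetic works out as follows:

```
# Worked example of the gradient accumulation calculation (assumes ddp_world_size=8).
device_batch_size = 32
max_seq_len = 2048
ddp_world_size = 8
total_batch_size = 524288  # 2**19 tokens

tokens_per_fwdbwd = device_batch_size * max_seq_len             # 65,536 tokens per rank per micro-step
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 524,288 tokens per micro-step overall
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 1, so no accumulation is needed here
print(grad_accum_steps)
```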
@@ -110,7 +126,14 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
# Initialize the Model

# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(
    sequence_len=max_seq_len,
    vocab_size=vocab_size,
    n_layer=num_layers,
    n_head=num_heads,
    n_kv_head=num_kv_heads,
    n_embd=model_dim,
)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
@@ -119,17 +142,19 @@ model.init_weights()
|
|||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
resuming = resume_from_step != -1
|
||||
if resuming:
|
||||
print0(f"Resuming optimization from step {resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(
|
||||
checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank
|
||||
)
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@@ -152,30 +177,37 @@ else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}")  # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

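The branch that actually computes `num_iterations` sits outside this hunk. A sketch of the precedence described in the settings (explicit steps, then `target_flops`, then `target_param_data_ratio`), written as an assumption rather than a copy of the script:

```
# Assumed training-horizon selection; the real code may differ in rounding details.
def pick_num_iterations(num_iterations, target_flops, target_param_data_ratio,
                        total_batch_size, num_params, num_flops_per_token):
    if num_iterations > 0:
        return num_iterations  # explicit step count takes precedence
    if target_flops > 0:
        # steps such that flops_per_token * tokens_per_step * steps ~= target_flops
        return round(target_flops / (num_flops_per_token * total_batch_size))
    if target_param_data_ratio > 0:
        target_tokens = target_param_data_ratio * num_params  # e.g. Chinchilla-style ~20x params
        return target_tokens // total_batch_size
    raise ValueError("No training horizon specified")
```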
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
del optimizer_data # free up the memory
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(
|
||||
device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict
|
||||
)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(
|
||||
device_batch_size, max_seq_len, split="val", device=device
|
||||
)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
|
||||
# Learning rate scheduler
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
@@ -188,20 +220,22 @@ def get_lr_multiplier(it):
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac


# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

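Only the warmdown tail of `get_lr_multiplier` is visible in these hunks. The settings (`warmup_ratio`, `warmdown_ratio`, `final_lr_frac`) imply a warmup / constant / warmdown multiplier roughly like the sketch below; the exact branch conditions are an assumption:

```
# Assumed overall shape of the LR multiplier schedule (the unseen branches are guessed).
def lr_multiplier_sketch(it, num_iterations, warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters                  # linear warmup toward 1.0
    if it <= num_iterations - warmdown_iters:
        return 1.0                                      # constant plateau
    progress = (num_iterations - it) / warmdown_iters
    return progress * 1.0 + (1 - progress) * final_lr_frac  # linear decay to final_lr_frac
```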
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)

if not resuming:
    step = 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
    smooth_train_loss = 0  # EMA of training loss
    total_training_time = 0  # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
@@ -212,7 +246,7 @@ else:
|
|||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
|
|
@ -225,12 +259,14 @@ while True:
|
|||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
}
|
||||
)
|
||||
model.train()
|
||||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
|
|
@ -241,12 +277,14 @@ while True:
|
|||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"core_metric": results["core_metric"],
|
||||
"centered_results": results["centered_results"],
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"core_metric": results["core_metric"],
|
||||
"centered_results": results["centered_results"],
|
||||
}
|
||||
)
|
||||
model.train()
|
||||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
|
|
@ -262,7 +300,7 @@ while True:
|
|||
"My favorite color is",
|
||||
"If 5*x + 3 = 13, then x is",
|
||||
]
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
|
|
@ -275,17 +313,17 @@ while True:
|
|||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
|
|
@ -306,15 +344,17 @@ while True:
|
|||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
x, y, dataloader_state_dict = next(
|
||||
train_loader
|
||||
) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping
|
||||
grad_clip_enabled = grad_clip > 0.0
|
||||
if grad_clip_enabled:
|
||||
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
|
||||
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
|
|
@ -332,18 +372,20 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(
|
||||
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
|
||||
)
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
|
|
@ -364,35 +406,39 @@ while True:
|
|||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Total training time: {total_training_time / 60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model training", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of parameters": num_params,
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
"warmdown_ratio": warmdown_ratio,
|
||||
"final_lr_frac": final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
get_report().log(
|
||||
section="Base model training",
|
||||
data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of parameters": num_params,
|
||||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
"warmdown_ratio": warmdown_ratio,
|
||||
"final_lr_frac": final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time / 60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
@@ -4,12 +4,15 @@ New and upgraded chat mode because a lot of the code has changed since the last
Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid
"""

import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine

import torch

from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine

parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
@ -18,7 +21,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument(
|
||||
'--device-type',
|
||||
type=str,
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -33,7 +42,10 @@ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag
|
|||
# Special tokens for the chat state machine
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
|
||||
assistant_start, assistant_end = (
|
||||
tokenizer.encode_special("<|assistant_start|>"),
|
||||
tokenizer.encode_special("<|assistant_end|>"),
|
||||
)
|
||||
|
||||
# Create Engine for efficient generation
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
@ -47,7 +59,6 @@ print("-" * 50)
|
|||
conversation_tokens = [bos]
|
||||
|
||||
while True:
|
||||
|
||||
if args.prompt:
|
||||
# Get the prompt from the launch command
|
||||
user_input = args.prompt
|
||||
|
|
@ -89,7 +100,7 @@ while True:
|
|||
print("\nAssistant: ", end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
print(token_text, end="", flush=True)
|
||||
|
|
|
|||
|
|
@ -9,27 +9,28 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0
|
||||
from nanochat.engine import Engine
|
||||
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.humaneval import HumanEval
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.spellingbee import SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
|
||||
|
||||
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
|
||||
|
||||
def run_generative_eval(
|
||||
task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None
|
||||
):
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
|
||||
|
|
@ -62,7 +63,7 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
|
|||
num_passed += int(passed)
|
||||
|
||||
# Logging (overwrite the same line in the console)
|
||||
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
|
||||
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100 * num_passed / total:.2f}%)", end='', flush=True)
|
||||
|
||||
# Finish the in-place progress line with a newline before final summary
|
||||
print()
|
||||
|
|
@ -77,21 +78,22 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
|
|||
total = total_tensor.item()
|
||||
|
||||
print0("=" * 50)
|
||||
print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
|
||||
print0(f"Final: {num_passed}/{total} ({100 * num_passed / total:.2f}%)")
|
||||
|
||||
# Return the accuracy
|
||||
return num_passed/total
|
||||
return num_passed / total
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
# Categorical evaluation loop
# A lot easier because we don't have to sample. Therefore, we can actually go
# batches at a time and just check the logits for correct answer choices.

def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):

def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
    bos = tokenizer.get_bos_token_id()  # use BOS as pad token is ok, these positions are ignored

    # We'll process batches of independent problems at a time because there is no sampling needed
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
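The comments above describe the trick: no sampling, just one forward pass per batch and a comparison of the logits assigned to the available answer letters. A self-contained illustration of that idea (hypothetical helper, not the rest of this function):

```
# Hypothetical sketch of logit-based multiple-choice scoring.
import torch

def score_choices(model, prompt_ids, answer_pos, letter_token_ids):
    # prompt_ids: (B, T) padded prompts; answer_pos: (B,) index of each prompt's last real token
    # letter_token_ids: (B, K) token ids of the answer letters available for each problem
    with torch.no_grad():
        logits = model(prompt_ids)                                  # (B, T, V)
    last = logits[torch.arange(prompt_ids.size(0)), answer_pos]     # (B, V) logits where the answer is predicted
    choice_logits = torch.gather(last, 1, letter_token_ids)         # (B, K) keep only the candidate letters
    return choice_logits.argmax(dim=1)                              # predicted choice index per problem
```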
@@ -99,22 +101,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
num_batches = ceil_div(num_problems, batch_size)
|
||||
|
||||
# Run the evaluation
|
||||
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
|
||||
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
|
||||
num_passed, total = 0, 0
|
||||
for i in range(ddp_rank, num_batches, ddp_world_size):
|
||||
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
|
||||
|
||||
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
|
||||
conversations = [task_object[ii] for ii in range(i0, i1)]
|
||||
prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
|
||||
prompt_ids = [
|
||||
tokenizer.render_for_completion(conversation) for conversation in conversations
|
||||
] # TODO: remake the way this works
|
||||
max_length = max(len(ids) for ids in prompt_ids)
|
||||
answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
|
||||
answer_time_positions = [
|
||||
len(ids) - 1 for ids in prompt_ids
|
||||
] # where the last token is (and the predicted answer)
|
||||
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
|
||||
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
|
||||
|
||||
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
|
||||
with torch.no_grad():
|
||||
logits = model(prompt_ids) # (B, T, V)
|
||||
logits = model(prompt_ids) # (B, T, V)
|
||||
|
||||
# Focus on the available answer on just the letters corresponding to choices
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
|
||||
|
|
@ -150,15 +156,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
num_passed = num_passed_tensor.item()
|
||||
total = total_tensor.item()
|
||||
|
||||
average = num_passed/total
|
||||
print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
|
||||
average = num_passed / total
|
||||
print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)")
|
||||
return average
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def run_chat_eval(task_name, model, tokenizer, engine,
|
||||
batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
|
||||
max_problems=None):
|
||||
|
||||
def run_chat_eval(
|
||||
task_name,
|
||||
model,
|
||||
tokenizer,
|
||||
engine,
|
||||
batch_size=1,
|
||||
num_samples=1,
|
||||
max_new_tokens=512,
|
||||
temperature=0.0,
|
||||
top_k=50,
|
||||
max_problems=None,
|
||||
):
|
||||
# Create the evaluation object
|
||||
task_module = {
|
||||
'HumanEval': HumanEval,
|
||||
|
|
@ -171,20 +188,36 @@ def run_chat_eval(task_name, model, tokenizer, engine,
|
|||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
if task_object.eval_type == 'generative':
|
||||
acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
|
||||
acc = run_generative_eval(
|
||||
task_object,
|
||||
tokenizer,
|
||||
model,
|
||||
engine,
|
||||
num_samples,
|
||||
max_new_tokens,
|
||||
temperature,
|
||||
top_k,
|
||||
max_problems=max_problems,
|
||||
)
|
||||
elif task_object.eval_type == 'categorical':
|
||||
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
|
||||
else:
|
||||
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
|
||||
return acc
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument(
|
||||
'-a',
|
||||
'--task-name',
|
||||
type=str,
|
||||
default=None,
|
||||
help="Task name. Default = all tasks. Use | to split multiple tasks.",
|
||||
)
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
|
|
@ -194,13 +227,21 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument(
|
||||
'--device-type',
|
||||
type=str,
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
@ -208,12 +249,12 @@ if __name__ == "__main__":
|
|||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
baseline_accuracies = {
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
|
|
@ -223,7 +264,9 @@ if __name__ == "__main__":
|
|||
with autocast_ctx:
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
model,
|
||||
tokenizer,
|
||||
engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
|
|
@ -236,6 +279,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
||||
all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
|
||||
# calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
|
||||
# this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
|
||||
|
|
@ -248,10 +292,13 @@ if __name__ == "__main__":
|
|||
centered_mean += centered_acc
|
||||
chatcore_metric = centered_mean / len(results)
|
||||
chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
|
||||
get_report().log(section="Chat evaluation " + args.source, data=[
|
||||
vars(args), # CLI args
|
||||
results,
|
||||
chatcore_metric_dict,
|
||||
])
|
||||
get_report().log(
|
||||
section="Chat evaluation " + args.source,
|
||||
data=[
|
||||
vars(args), # CLI args
|
||||
results,
|
||||
chatcore_metric_dict,
|
||||
],
|
||||
)
|
||||
|
||||
compute_cleanup()
@@ -16,46 +16,46 @@ python -m scripts.chat_rl
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
||||
"""
|
||||
|
||||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
top_k = 50 # TODO: try None?
|
||||
top_k = 50 # TODO: try None?
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.05
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl
|
|||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Rollout / sampling generator loop that yields batches of examples for training
|
||||
|
|
@ -75,12 +75,16 @@ val_task = GSM8K(subset="main", split="test")
|
|||
num_steps = (len(train_task) // examples_per_step) * num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_batch():
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
|
||||
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
|
||||
assistant_end = tokenizer.encode_special(
|
||||
"<|assistant_end|>"
|
||||
) # ok to use this token, it's only for padding and isn't used in the loss.
|
||||
rank_indices = range(
|
||||
ddp_rank, len(train_task), ddp_world_size
|
||||
) # each rank is responsible for different examples in the training data
|
||||
for example_idx in itertools.cycle(rank_indices):
|
||||
|
||||
# First get the full conversation of both user and assistant messages
|
||||
conversation = train_task[example_idx]
|
||||
|
||||
|
|
@ -90,12 +94,12 @@ def get_batch():
|
|||
prefix_length = len(tokens)
|
||||
|
||||
# Generate num_samples samples using batched generation, use loop to avoid OOMs
|
||||
model.eval() # ensure the model is in eval mode
|
||||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
|
|
@ -103,7 +107,7 @@ def get_batch():
|
|||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
masks.extend(masks_batch)
|
||||
|
|
@ -121,15 +125,17 @@ def get_batch():
|
|||
|
||||
# Pad the sequences so that their lengths (in time) match
|
||||
max_length = max(len(seq) for seq in generated_token_sequences)
|
||||
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
|
||||
padded_generated_token_sequences = [
|
||||
seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences
|
||||
]
|
||||
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
|
||||
# Stack up the sequences and masks into PyTorch tensors
|
||||
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
|
||||
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
|
||||
# Generate autoregressive inputs and targets to the Transformer
|
||||
inputs = ids[:, :-1]
|
||||
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
|
||||
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
|
||||
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
|
||||
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
|
||||
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
|
||||
|
|
@@ -139,14 +145,11 @@ def get_batch():
        # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
        yield generated_token_sequences, inputs, targets, rewards, advantages


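The `advantages` yielded above are computed earlier in `get_batch`, outside the hunks shown here. A common choice for this group-relative setup, and presumably what happens upstream, is to center each example's sampled rewards on their mean; a hedged sketch:

```
# Assumed advantage computation (the actual lines are not visible in this diff).
import torch

def compute_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: (num_samples,) scalar reward for each sampled completion of one example.
    # Mean-centering makes better-than-average samples positive and worse ones negative,
    # which is all the policy-gradient update below needs.
    return rewards - rewards.mean()
```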
# -----------------------------------------------------------------------------
|
||||
# Simple evaluation loop for GSM8K pass@k
|
||||
def run_gsm8k_eval(task, tokenizer, engine,
|
||||
max_examples=None,
|
||||
num_samples=1,
|
||||
max_completion_tokens=256,
|
||||
temperature=0.0,
|
||||
top_k=50
|
||||
def run_gsm8k_eval(
|
||||
task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50
|
||||
):
|
||||
"""
|
||||
Evaluates GSM8K task and returns a list of records of evaluation outcomes.
|
||||
|
|
@ -160,13 +163,9 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
max_tokens=max_completion_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k
|
||||
)
|
||||
# Check each sample for correctness
|
||||
outcomes = []
|
||||
|
|
@ -174,9 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
generated_tokens = sample_tokens[prefix_length:]
|
||||
generated_text = tokenizer.decode(generated_tokens)
|
||||
is_correct = task.evaluate(conversation, generated_text)
|
||||
outcomes.append({
|
||||
"is_correct": is_correct
|
||||
})
|
||||
outcomes.append({"is_correct": is_correct})
|
||||
# A bit bloated because I wanted to do more complex logging at one point.
|
||||
record = {
|
||||
"idx": idx,
|
||||
|
|
@ -184,6 +181,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
}
|
||||
yield record
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
|
|
@ -199,44 +197,49 @@ optimizers = model.setup_optimizers(
|
|||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_steps
|
||||
return lrm
|
||||
|
||||
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
|
||||
|
||||
# Kick off the training loop
|
||||
batch_iterator = get_batch()
|
||||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
records_iter = run_gsm8k_eval(
|
||||
val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0
|
||||
)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
**log_passk,
|
||||
}
|
||||
)
|
||||
|
||||
# Forward/Backward on rollouts over multiple examples in the dataset
|
||||
rewards_list = []
|
||||
|
|
@ -245,7 +248,7 @@ for step in range(num_steps):
|
|||
# Get one batch corresponding to one example in the training dataset
|
||||
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // device_batch_size
|
||||
|
|
@@ -258,7 +261,7 @@ for step in range(num_steps):
            advantages = advantages_all[b0:b1]
            # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
            with autocast_ctx:
                logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
                logp = -model(inputs, targets, loss_reduction='none').view_as(inputs)  # (B, T)
            # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            # normalize by the number of valid tokens, number of passes, and examples_per_rank
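Putting the lines above together, the update is a REINFORCE-style policy gradient: each sampled token's log-probability is weighted by its sequence's advantage, summed, normalized, and negated into a loss. A compact sketch with illustrative names (the normalization constants are assumptions):

```
# Minimal sketch of the policy-gradient loss formed above.
import torch

def pg_loss(logp: torch.Tensor, advantages: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    # logp: (B, T) log-probs of sampled tokens (invalid positions already contribute 0)
    # advantages: (B,) per-sequence advantage; valid_mask: (B, T) 1 where tokens are trained on
    pg_obj = (logp * advantages.unsqueeze(-1)).sum()   # weight each token by its sequence advantage
    pg_obj = pg_obj / valid_mask.sum().clamp(min=1)    # normalize by the number of valid tokens
    return -pg_obj                                     # minimize the negative of the objective
```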
@@ -268,7 +271,9 @@ for step in range(num_steps):
|
|||
# Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
|
||||
loss = -pg_obj
|
||||
loss.backward()
|
||||
print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
|
||||
print0(
|
||||
f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}"
|
||||
)
|
||||
# For logging
|
||||
rewards_list.append(rewards_all.mean().item())
|
||||
sequence_lengths.extend(len(seq) for seq in sequences_all)
|
||||
|
|
@ -276,56 +281,66 @@ for step in range(num_steps):
|
|||
# A bunch of logging for how the rollouts went this step
|
||||
mean_reward = sum(rewards_list) / len(rewards_list)
|
||||
mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
|
||||
if ddp: # aggregate across ranks
|
||||
if ddp: # aggregate across ranks
|
||||
mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
|
||||
mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
|
||||
dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
|
||||
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
|
||||
mean_reward = mean_reward_tensor.item()
|
||||
mean_sequence_length = mean_sequence_length_tensor.item()
|
||||
print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"reward": mean_reward,
|
||||
"sequence_length": mean_sequence_length,
|
||||
})
|
||||
print0(
|
||||
f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}"
|
||||
)
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"reward": mean_reward,
|
||||
"sequence_length": mean_sequence_length,
|
||||
}
|
||||
)
|
||||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for opt in optimizers: # first set the learning rate
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
for opt in optimizers: # then step the optimizers
|
||||
for opt in optimizers: # then step the optimizers
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
}
|
||||
)
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
},
|
||||
)
|
||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Chat RL", data=[
|
||||
user_config, # CLI args
|
||||
])
|
||||
|
||||
wandb_run.finish() # wandb run finish
|
||||
get_report().log(
|
||||
section="Chat RL",
|
||||
data=[
|
||||
user_config, # CLI args
|
||||
],
|
||||
)
|
||||
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
@@ -10,40 +10,40 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
|||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.arc import ARC
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# input model options
|
||||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
|
|
@@ -56,9 +56,9 @@ eval_steps = 100
|
|||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
|
|
@@ -70,52 +70,63 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
|
|||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
wandb_run = (
|
||||
DummyWandb()
|
||||
if use_dummy_wandb
|
||||
else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
train_ds = TaskMixture(
|
||||
[
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]
|
||||
) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
|
||||
|
||||
def sft_data_generator(dataset, batch_size):
|
||||
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
|
||||
pad_token_id = tokenizer.encode_special(
|
||||
"<|assistant_end|>"
|
||||
) # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
|
||||
|
||||
# prepares a list of tokenized conversations into a batch and yields
|
||||
def collate_and_yield(batch):
|
||||
nrows = len(batch)
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
for i, (ids, mask) in enumerate(batch):
|
||||
n = len(ids)
|
||||
ids_tensor = torch.tensor(ids, dtype=torch.long)
|
||||
inputs[i, :n-1] = ids_tensor[:-1]
|
||||
inputs[i, : n - 1] = ids_tensor[:-1]
|
||||
# recall -1 is the ignore index, so mask out targets where mask is 0
|
||||
row_targets = ids_tensor[1:]
|
||||
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
|
||||
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
|
||||
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
|
||||
targets[i, :n-1] = row_targets
|
||||
inputs = inputs.to(device) # move to device
|
||||
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
|
||||
targets[i, : n - 1] = row_targets
|
||||
inputs = inputs.to(device) # move to device
|
||||
targets = targets.to(device)
|
||||
return inputs, targets
|
||||
|
||||
# iterates over the dataset in epochs, tokenizes
|
||||
batch = []
|
||||
while True:
|
||||
|
|
@@ -127,11 +138,14 @@ def sft_data_generator(dataset, batch_size):
|
|||
yield collate_and_yield(batch)
|
||||
batch = []
|
||||
|
||||
|
||||
examples_per_step = device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
assert target_examples_per_step % examples_per_step == 0, (
|
||||
"Target examples per step must be divisible by examples per step"
|
||||
)
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
|
|
@@ -155,16 +169,18 @@ optimizers = model.setup_optimizers(
|
|||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(it):
|
||||
lrm = 1.0 - it / num_iterations
|
||||
return lrm
|
||||
|
||||
|
||||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
|
|
@@ -181,15 +197,17 @@ for step in range(num_iterations):
|
|||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
losses.append(loss)
|
||||
val_loss = torch.stack(losses).mean() # average over eval_steps
|
||||
val_loss = torch.stack(losses).mean() # average over eval_steps
|
||||
if ddp:
|
||||
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
|
||||
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
|
||||
val_loss = val_loss.item()
|
||||
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
}
|
||||
)
|
||||
model.train()
|
||||
|
||||
# evaluate accuracy of the multiple choice tasks (which are quick to run)
|
||||
|
|
@@ -198,31 +216,47 @@ for step in range(num_iterations):
|
|||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["mmlu_acc"] = run_chat_eval(
|
||||
"MMLU",
|
||||
model,
|
||||
tokenizer,
|
||||
engine,
|
||||
batch_size=device_batch_size * 2,
|
||||
max_problems=eval_metrics_max_problems,
|
||||
)
|
||||
metrics["arc_easy_acc"] = run_chat_eval(
|
||||
"ARC-Easy",
|
||||
model,
|
||||
tokenizer,
|
||||
engine,
|
||||
batch_size=device_batch_size * 2,
|
||||
max_problems=eval_metrics_max_problems,
|
||||
)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**metrics,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
**metrics,
|
||||
}
|
||||
)
|
||||
model.train()
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
||||
# evaluate the gradient
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward() # accumulate the gradient
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward() # accumulate the gradient
|
||||
num_tokens += (train_targets >= 0).sum()
|
||||
if ddp:
|
||||
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
|
||||
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
|
||||
|
||||
# learning rate scheduler
|
||||
lrm = get_lr_multiplier(step)
|
||||
|
|
@@ -238,47 +272,55 @@ for step in range(num_iterations):
|
|||
# logging
|
||||
train_loss_item = train_loss.item()
|
||||
num_tokens_item = num_tokens.item()
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
})
|
||||
print0(
|
||||
f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}"
|
||||
)
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
}
|
||||
)
|
||||
step += 1
|
||||
|
||||
# Save the model at the end of the run
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
**metrics,
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
},
|
||||
)
|
||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Chat SFT", data=[
|
||||
user_config, # CLI args
|
||||
{
|
||||
"Training rows": len(train_ds),
|
||||
"Number of iterations": num_iterations,
|
||||
"Training loss": train_loss_item,
|
||||
"Validation loss": val_loss,
|
||||
},
|
||||
])
|
||||
|
||||
get_report().log(
|
||||
section="Chat SFT",
|
||||
data=[
|
||||
user_config, # CLI args
|
||||
{
|
||||
"Training rows": len(train_ds),
|
||||
"Number of iterations": num_iterations,
|
||||
"Training loss": train_loss_item,
|
||||
"Validation loss": val_loss,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
wandb_run.finish()
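For the batch-size bookkeeping in the chat_sft hunks above (`target_examples_per_step`, `device_batch_size`, `ddp_world_size`), here is an illustrative sketch of how the gradient accumulation count falls out of the defaults shown in the diff; the helper name is hypothetical:

```python
def grad_accum_steps(target_examples_per_step, device_batch_size, ddp_world_size):
    # examples covered per optimizer step if every rank does one forward/backward
    examples_per_step = device_batch_size * ddp_world_size
    assert target_examples_per_step % examples_per_step == 0, \
        "Target examples per step must be divisible by examples per step"
    return target_examples_per_step // examples_per_step

# with the defaults shown above (target 32, device batch size 4):
assert grad_accum_steps(32, 4, ddp_world_size=8) == 1  # 8 ranks cover 32 examples in one pass
assert grad_accum_steps(32, 4, ddp_world_size=1) == 8  # a single GPU accumulates 8 micro-batches
```

Each micro-batch loss is divided by `grad_accum_steps` before `backward()`, so the accumulated gradient matches a single large-batch step.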
@@ -31,22 +31,23 @@ Abuse Prevention:
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import autodetect_device_type, compute_init
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Abuse prevention limits
|
||||
|
|
@@ -70,70 +71,70 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
|
|||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument(
|
||||
'--device-type',
|
||||
type=str,
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for conversation traffic
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
"""A worker with a model loaded on a specific GPU."""
|
||||
|
||||
gpu_id: int
|
||||
device: torch.device
|
||||
engine: Engine
|
||||
tokenizer: object
|
||||
autocast_ctx: torch.amp.autocast
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
def __init__(self, num_gpus: int | None = None):
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.workers: list[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
|
||||
async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None):
|
||||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
|
||||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
engine=engine,
|
||||
tokenizer=tokenizer,
|
||||
autocast_ctx=autocast_ctx
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx)
|
||||
self.workers.append(worker)
|
||||
await self.available_workers.put(worker)
|
||||
|
||||
|
|
@@ -147,15 +148,18 @@ class WorkerPool:
|
|||
"""Return a worker to the pool."""
|
||||
await self.available_workers.put(worker)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
messages: list[ChatMessage]
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
top_k: int | None = None
|
||||
|
||||
|
||||
def validate_chat_request(request: ChatRequest):
|
||||
"""Validate chat request to prevent abuse."""
|
||||
|
|
@@ -165,7 +169,7 @@ def validate_chat_request(request: ChatRequest):
|
|||
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
|
||||
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request",
|
||||
)
|
||||
|
||||
# Check individual message lengths and total conversation length
|
||||
|
|
@@ -178,48 +182,43 @@ def validate_chat_request(request: ChatRequest):
|
|||
if msg_length > MAX_MESSAGE_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
|
||||
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message",
|
||||
)
|
||||
total_length += msg_length
|
||||
|
||||
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
|
||||
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed",
|
||||
)
|
||||
|
||||
# Validate role values
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
)
|
||||
|
||||
# Validate temperature
|
||||
if request.temperature is not None:
|
||||
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
|
||||
status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
|
||||
)
|
||||
|
||||
# Validate top_k
|
||||
if request.top_k is not None:
|
||||
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
|
||||
|
||||
# Validate max_tokens
|
||||
if request.max_tokens is not None:
|
||||
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
||||
status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load models on all GPUs on startup."""
|
||||
|
|
@@ -229,6 +228,7 @@ async def lifespan(app: FastAPI):
|
|||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
|
|
@@ -239,16 +239,16 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Serve the chat UI."""
|
||||
ui_html_path = os.path.join("nanochat", "ui.html")
|
||||
with open(ui_html_path, "r", encoding="utf-8") as f:
|
||||
with open(ui_html_path, encoding="utf-8") as f:
|
||||
html_content = f.read()
|
||||
# Replace the API_URL to use the same origin
|
||||
html_content = html_content.replace(
|
||||
"const API_URL = `http://${window.location.hostname}:8000`;",
|
||||
"const API_URL = '';"
|
||||
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';"
|
||||
)
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
|
@@ -259,12 +259,9 @@ async def logo():
|
|||
logo_path = os.path.join("nanochat", "logo.svg")
|
||||
return FileResponse(logo_path, media_type="image/svg+xml")
|
||||
|
||||
|
||||
async def generate_stream(
|
||||
worker: Worker,
|
||||
tokens,
|
||||
temperature=None,
|
||||
max_new_tokens=None,
|
||||
top_k=None
|
||||
worker: Worker, tokens, temperature=None, max_new_tokens=None, top_k=None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate assistant response with streaming."""
|
||||
temperature = temperature if temperature is not None else args.temperature
|
||||
|
|
@@ -286,7 +283,7 @@ async def generate_stream(
|
|||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
seed=random.randint(0, 2**31 - 1),
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
|
|
@@ -303,13 +300,14 @@ async def generate_stream(
|
|||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
new_text = current_text[len(last_clean_text) :]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
|
||||
@app.post("/chat/completions")
|
||||
async def chat_completions(request: ChatRequest):
|
||||
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
|
||||
|
|
@@ -318,10 +316,10 @@ async def chat_completions(request: ChatRequest):
|
|||
validate_chat_request(request)
|
||||
|
||||
# Log incoming conversation to console
|
||||
logger.info("="*20)
|
||||
logger.info("=" * 20)
|
||||
for i, message in enumerate(request.messages):
|
||||
logger.info(f"[{message.role.upper()}]: {message.content}")
|
||||
logger.info("-"*20)
|
||||
logger.info("-" * 20)
|
||||
|
||||
# Acquire a worker from the pool (will wait if all are busy)
|
||||
worker_pool = app.state.worker_pool
|
||||
|
|
@@ -350,6 +348,7 @@ async def chat_completions(request: ChatRequest):
|
|||
|
||||
# Streaming response with worker release after completion
|
||||
response_tokens = []
|
||||
|
||||
async def stream_and_release():
|
||||
try:
|
||||
async for chunk in generate_stream(
|
||||
|
|
@@ -357,7 +356,7 @@ async def chat_completions(request: ChatRequest):
|
|||
conversation_tokens,
|
||||
temperature=request.temperature,
|
||||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
top_k=request.top_k,
|
||||
):
|
||||
# Accumulate response for logging
|
||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||
|
|
@@ -368,19 +367,17 @@ async def chat_completions(request: ChatRequest):
|
|||
# Log the assistant response to console
|
||||
full_response = "".join(response_tokens)
|
||||
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
|
||||
logger.info("="*20)
|
||||
logger.info("=" * 20)
|
||||
# Release worker back to pool after streaming is done
|
||||
await worker_pool.release_worker(worker)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_and_release(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
return StreamingResponse(stream_and_release(), media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# Make sure to release worker even on error
|
||||
await worker_pool.release_worker(worker)
|
||||
raise e
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
|
|
@@ -389,9 +386,10 @@ async def health():
|
|||
"status": "ok",
|
||||
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
|
||||
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
|
||||
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
|
||||
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def stats():
|
||||
"""Get worker pool statistics."""
|
||||
|
|
@@ -400,16 +398,13 @@ async def stats():
|
|||
"total_workers": len(worker_pool.workers),
|
||||
"available_workers": worker_pool.available_workers.qsize(),
|
||||
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
|
||||
"workers": [
|
||||
{
|
||||
"gpu_id": w.gpu_id,
|
||||
"device": str(w.device)
|
||||
} for w in worker_pool.workers
|
||||
]
|
||||
"workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print(f"Starting NanoChat Web Server")
|
||||
|
||||
print("Starting NanoChat Web Server")
|
||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||
uvicorn.run(app, host=args.host, port=args.port)
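The streaming loop in the chat_web hunks above holds output back while the running decode still ends in the U+FFFD replacement character, so multi-byte UTF-8 characters split across tokens are never sent half-decoded. A self-contained sketch of that rule, using raw byte chunks as a stand-in for tokenizer output (the names and byte-level framing are illustrative, not the server's API):

```python
def stream_clean_text(byte_chunks):
    """Yield only the newly completed text as bytes arrive, never a partial character."""
    buf = b""
    last_clean = ""
    for chunk in byte_chunks:
        buf += chunk
        current = buf.decode("utf-8", errors="replace")
        if not current.endswith("\ufffd"):        # decode is currently clean
            new_text = current[len(last_clean):]  # only the delta since the last clean decode
            if new_text:
                yield new_text
            last_clean = current

# "é" is two bytes (0xC3 0xA9); split across chunks it is held back until complete
chunks = [b"caf", b"\xc3", b"\xa9", b"!"]
assert "".join(stream_clean_text(chunks)) == "café!"
```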
@@ -9,55 +9,58 @@ Or torchrun for training:
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
import torch.distributed as dist
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from tasks.common import TaskMixture
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20*524288
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20 * 524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
|
|
@@ -69,13 +72,15 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
|
|||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
print0(
|
||||
f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?"
|
||||
)
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
|
|
@@ -84,48 +89,58 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
|
|||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
]) # total: 24K + 14K + 1.32K ~= 39K rows
|
||||
train_dataset = TaskMixture(
|
||||
[
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(
|
||||
subset="auxiliary_train", split="train"
|
||||
), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]
|
||||
) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture(
|
||||
[
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
]
|
||||
) # total: 24K + 14K + 1.32K ~= 39K rows
|
||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
# A big problem is that we don't know the final num_iterations in advance. So we create
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
last_step = False # we will toggle this to True when we reach the end of the dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
|
|
@@ -134,49 +149,55 @@ def mid_data_generator(split):
|
|||
token_buffer.extend(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(
|
||||
device=device, dtype=torch.int64, non_blocking=True
|
||||
)
|
||||
if split == "train":
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
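For reference, a worked example of this schedule (the function is copied from the line above, under the assumption that `progress` runs from 0 to 1 over the epoch): the multiplier stays at 1 for the first 80% of training, then ramps linearly to 0.

```python
def get_lr_multiplier(progress):
    # first 80% of training: no decay, then linearly ramp down to 0.
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

assert get_lr_multiplier(0.5) == 1                   # first 80%: full learning rate
assert abs(get_lr_multiplier(0.9) - 0.5) < 1e-12     # halfway through the final ramp
assert abs(get_lr_multiplier(1.0)) < 1e-12           # end of the epoch: decayed to zero
```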
|
||||
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
x, y = next(train_loader) # prefetch the very first batch of data
|
||||
x, y = next(train_loader) # prefetch the very first batch of data
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
step = 0
|
||||
while True:
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
|
@@ -197,26 +218,28 @@ while True:
|
|||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
}
|
||||
)
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
|
|
@@ -225,8 +248,8 @@ while True:
|
|||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
}
|
||||
"user_config": user_config, # inputs to the training script
|
||||
},
|
||||
)
|
||||
|
||||
if last_step:
|
||||
|
|
@@ -240,11 +263,11 @@ while True:
|
|||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
|
|
@@ -265,47 +288,55 @@ while True:
|
|||
step += 1
|
||||
|
||||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(
|
||||
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
|
||||
)
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
})
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
"train/lrm": lrm,
|
||||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
}
|
||||
)
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Total training time: {total_training_time / 60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
if not dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
}
|
||||
])
|
||||
|
||||
get_report().log(
|
||||
section="Midtraining",
|
||||
data=[
|
||||
user_config, # CLI args
|
||||
{ # stats about the training setup
|
||||
"Number of iterations": step,
|
||||
"DDP world size": ddp_world_size,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
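The logging in the mid_train hunks above smooths the training loss with an EMA started at 0 and divides by `(1 - ema_beta**(step + 1))` to undo the zero-initialization bias (the same correction Adam applies to its moment estimates). A small worked example, assuming a constant loss so the bias is easy to see:

```python
ema_beta = 0.9
smooth = 0.0
losses = [2.0, 2.0, 2.0]  # a constant loss makes the bias obvious
for step, loss in enumerate(losses):
    smooth = ema_beta * smooth + (1 - ema_beta) * loss        # raw EMA, biased toward 0 early on
    debiased = smooth / (1 - ema_beta ** (step + 1))          # debiasing factor grows toward 1
    print(f"step {step}: raw EMA {smooth:.3f} vs debiased {debiased:.3f}")
    assert abs(debiased - 2.0) < 1e-9  # the debiased estimate recovers the true loss level
```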
@@ -2,8 +2,8 @@
|
|||
Evaluate compression ratio of the tokenizer.
|
||||
"""
|
||||
|
||||
from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import RustBPETokenizer, get_tokenizer
|
||||
|
||||
# Random text I got from a random website this morning
|
||||
news_text = r"""
|
||||
|
|
@@ -165,11 +165,10 @@ tokenizer_results = {}
|
|||
vocab_sizes = {}
|
||||
|
||||
for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
||||
|
||||
if tokenizer_name == "gpt2":
|
||||
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
|
||||
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
|
||||
elif tokenizer_name == "gpt4":
|
||||
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
|
||||
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
|
||||
else:
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
|
|
@@ -183,11 +182,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
|||
|
||||
encoded_bytes = text.encode('utf-8')
|
||||
ratio = len(encoded_bytes) / len(encoded)
|
||||
tokenizer_results[tokenizer_name][name] = {
|
||||
'bytes': len(encoded_bytes),
|
||||
'tokens': len(encoded),
|
||||
'ratio': ratio
|
||||
}
|
||||
tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio}
|
||||
|
||||
# ANSI color codes
|
||||
GREEN = '\033[92m'
|
||||
|
|
@@ -195,11 +190,12 @@ RED = '\033[91m'
|
|||
RESET = '\033[0m'
|
||||
|
||||
# Print vocab sizes
|
||||
print(f"\nVocab sizes:")
|
||||
print("\nVocab sizes:")
|
||||
print(f"GPT-2: {vocab_sizes['gpt2']}")
|
||||
print(f"GPT-4: {vocab_sizes['gpt4']}")
|
||||
print(f"Ours: {vocab_sizes['ours']}")
|
||||
|
||||
|
||||
def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
||||
"""Print comparison table between baseline tokenizer and ours."""
|
||||
print(f"\nComparison with {baseline_name}:")
|
||||
|
|
@@ -230,13 +226,16 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
|||
better = "Tie"
|
||||
diff_color = ""
|
||||
|
||||
print(f"{name:<10} {baseline_data['bytes']:<8} "
|
||||
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
|
||||
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
|
||||
f"{ours_color}{ours_data['tokens']:<7}{RESET} "
|
||||
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
|
||||
f"{diff_color}{relative_diff:+7.1f}%{RESET} "
|
||||
f"{better:<10}")
|
||||
print(
|
||||
f"{name:<10} {baseline_data['bytes']:<8} "
|
||||
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
|
||||
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
|
||||
f"{ours_color}{ours_data['tokens']:<7}{RESET} "
|
||||
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
|
||||
f"{diff_color}{relative_diff:+7.1f}%{RESET} "
|
||||
f"{better:<10}"
|
||||
)
|
||||
|
||||
|
||||
# Print comparisons
|
||||
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
|
||||
|
|
@@ -244,6 +243,7 @@ print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'],
|
|||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
||||
lines = []
|
||||
for baseline_name in ["GPT-2", "GPT-4"]:
|
||||
baseline_key = baseline_name.lower().replace('-', '')
|
||||
|
|
@@ -251,15 +251,26 @@ for baseline_name in ["GPT-2", "GPT-4"]:
|
|||
ours_results = tokenizer_results['ours']
|
||||
lines.append(f"### Comparison with {baseline_name}")
|
||||
lines.append("")
|
||||
lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |")
|
||||
lines.append(
|
||||
"| Text Type | Bytes | "
|
||||
+ baseline_name
|
||||
+ " Tokens | "
|
||||
+ baseline_name
|
||||
+ " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |"
|
||||
)
|
||||
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
|
||||
for name, text in all_text:
|
||||
baseline_data = baseline_results[name]
|
||||
ours_data = ours_results[name]
|
||||
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
|
||||
lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |")
|
||||
lines.append(
|
||||
f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |"
|
||||
)
|
||||
lines.append("")
|
||||
report_markdown = "\n".join(lines)
|
||||
get_report().log(section="Tokenizer evaluation", data=[
|
||||
report_markdown,
|
||||
])
|
||||
get_report().log(
|
||||
section="Tokenizer evaluation",
|
||||
data=[
|
||||
report_markdown,
|
||||
],
|
||||
)
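The ratios reported by this script are bytes of UTF-8 text per token, so a higher ratio means the tokenizer packs more text into each token. A minimal sketch of the metric, with a hypothetical whitespace "tokenizer" standing in for a real BPE encode:

```python
def compression_ratio(text, encode):
    """Bytes of UTF-8 text divided by number of tokens produced by encode()."""
    encoded = encode(text)
    return len(text.encode("utf-8")) / len(encoded)

def toy_encode(s):
    # hypothetical tokenizer: one token per whitespace-separated word
    return s.split()

ratio = compression_ratio("the quick brown fox", toy_encode)
print(f"{ratio:.2f} bytes/token")  # 19 bytes / 4 tokens = 4.75
```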
@@ -2,19 +2,24 @@
Train a tokenizer using the HuggingFace Tokenizers library.
In the style of GPT-4 tokenizer.
"""

import argparse
import os
import time
import argparse

import torch
from nanochat.tokenizer import RustBPETokenizer

from nanochat.common import get_base_dir
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import RustBPETokenizer

# -----------------------------------------------------------------------------
# Parse command line arguments

parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument(
'--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)'
)
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
@@ -25,6 +30,7 @@ print(f"vocab_size: {args.vocab_size:,}")
# -----------------------------------------------------------------------------
# Text iterator


def text_iterator():
"""
1) Flatten the batches into a single iterator
@@ -36,11 +42,13 @@ def text_iterator():
for doc in batch:
doc_text = doc
if len(doc_text) > args.doc_cap:
doc_text = doc_text[:args.doc_cap]
doc_text = doc_text[: args.doc_cap]
nchars += len(doc_text)
yield doc_text
if nchars > args.max_chars:
return


text_iter = text_iterator()

# -----------------------------------------------------------------------------
@@ -78,11 +86,11 @@ special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
token_bytes = []
for token_id in range(vocab_size):
token_str = token_strings[token_id] # the Python string representation of this token
token_str = token_strings[token_id] # the Python string representation of this token
if token_str in special_set:
token_bytes.append(0) # special characters are not counted
token_bytes.append(0) # special characters are not counted
else:
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
token_bytes.append(id_bytes)
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
@@ -92,15 +100,19 @@ print(f"Saved token_bytes to {token_bytes_path}")

# Log to report
from nanochat.report import get_report

token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
get_report().log(section="Tokenizer training", data=[
vars(args), # argparse command line arguments
{"train_time": train_time},
{"num_special_tokens": len(special_set)},
{
"token_bytes_min": int(token_bytes_nonzero.min().item()),
"token_bytes_max": int(token_bytes_nonzero.max().item()),
"token_bytes_mean": token_bytes_nonzero.mean().item(),
"token_bytes_std": token_bytes_nonzero.std().item(),
}
])
get_report().log(
section="Tokenizer training",
data=[
vars(args), # argparse command line arguments
{"train_time": train_time},
{"num_special_tokens": len(special_set)},
{
"token_bytes_min": int(token_bytes_nonzero.min().item()),
"token_bytes_max": int(token_bytes_nonzero.max().item()),
"token_bytes_mean": token_bytes_nonzero.mean().item(),
"token_bytes_std": token_bytes_nonzero.std().item(),
},
],
)
26
tasks/arc.py
@@ -4,10 +4,11 @@ https://huggingface.co/datasets/allenai/ai2_arc
"""

from datasets import load_dataset

from tasks.common import Task, render_mc

class ARC(Task):

class ARC(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
@@ -23,26 +24,25 @@ class ARC(Task):

def get_example(self, index):
row = self.ds[index]
question = row["question"] # the question text
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
question = row["question"] # the question text
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
# create and return the Conversation object
user_message = render_mc(question, letters, choices)
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": answer_string}
]
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}]
conversation = {
"messages": messages,
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
}
return conversation

def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
assert assistant_response in conversation['letters'], (
f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message
|
|||
|
|
@@ -7,6 +7,7 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.

import random


class Task:
"""
Base class of a Task. Allows for lightweight slicing of the underlying dataset.
@@ -18,7 +19,7 @@ class Task:
assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
assert step >= 1, f"Step must be strictly positive, got {step}"
self.start = start
self.stop = stop # could be None here
self.stop = stop # could be None here
self.step = step

@property
@@ -37,8 +38,8 @@ class Task:
stop = self.num_examples() if self.stop is None else self.stop
step = self.step
span = stop - start
num = (span + step - 1) // step # ceil_div(span, step)
assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
num = (span + step - 1) // step # ceil_div(span, step)
assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
return num

def __getitem__(self, index: int):
@@ -81,7 +82,9 @@ class TaskMixture(Task):
Access conversations according to a deterministic shuffle of all examples.
This ensures tasks are mixed throughout training, regardless of dataset size.
"""
assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for mixture with {self.num_conversations} conversations"
)
task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx]

@@ -102,7 +105,9 @@ class TaskSequence(Task):
return self.num_conversations

def get_example(self, index):
assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for sequence with {self.num_conversations} conversations"
)
for task_idx, task_length in enumerate(self.lengths):
if index < task_length:
return self.tasks[task_idx][index]
|
|||
|
|
@@ -3,10 +3,12 @@ CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages.
"""

import os
import json
import os

from tasks.common import Task


class CustomJSON(Task):
"""
Load conversations from a JSONL file.
@@ -25,14 +27,18 @@ class CustomJSON(Task):
print("-" * 80)
print(f"Warning: File {filepath} does not exist")
print("HINT (Oct 21 2025)")
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
print(
"If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations"
)
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:")
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
print(
f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
)
print("-" * 80)

else:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line: # skip empty lines
@@ -46,7 +52,9 @@ class CustomJSON(Task):
assert "role" in message, f"Message {i} missing 'role' field"
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), f"Message {i} content must be a string"

self.conversations.append(messages)
@@ -62,4 +70,3 @@ class CustomJSON(Task):
"messages": messages,
}
return conversation
|
|||
|
|
@@ -15,11 +15,14 @@ Notice that GSM8K uses tool calls inside << >> tags.
"""

import re

from datasets import load_dataset

from tasks.common import Task


GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")


def extract_answer(completion):
"""
Extract the numerical answer after #### marker.
@@ -35,7 +38,6 @@ def extract_answer(completion):


class GSM8K(Task):

def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
@@ -50,10 +52,10 @@ class GSM8K(Task):
return len(self.ds)

def get_example(self, index):
""" Get a single problem from the dataset. """
"""Get a single problem from the dataset."""
row = self.ds[index]
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
# Create and return the Conversation object
# This is tricky because GSM8K uses tool calls, which we need to parse here.
assistant_message_parts = []
@@ -76,8 +78,8 @@ class GSM8K(Task):
assistant_message_parts.append({"type": "text", "text": part})
# No put it all together
messages = [
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
{"role": "user", "content": question}, # note: simple string
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
]
conversation = {
"messages": messages,
@@ -99,7 +101,7 @@ class GSM8K(Task):
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
|||
|
|
@ -5,10 +5,13 @@ It is a coding benchmark.
|
|||
"""
|
||||
|
||||
import re
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from nanochat.execution import execute_code
|
||||
from tasks.common import Task
|
||||
|
||||
|
||||
def extract_imports(prompt):
|
||||
"""Extract import statements from the beginning of a code block."""
|
||||
imports = []
|
||||
|
|
@ -21,6 +24,7 @@ def extract_imports(prompt):
|
|||
break
|
||||
return '\n'.join(imports)
|
||||
|
||||
|
||||
def extract_program(completion):
|
||||
"""
|
||||
Extract Python code from LLM completion.
|
||||
|
|
@ -44,8 +48,8 @@ def extract_program(completion):
|
|||
# No code blocks found, return the whole completion
|
||||
return completion.strip()
|
||||
|
||||
class HumanEval(Task):
|
||||
|
||||
class HumanEval(Task):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
|
||||
|
|
@ -58,12 +62,12 @@ class HumanEval(Task):
|
|||
return len(self.ds)
|
||||
|
||||
def get_example(self, index):
|
||||
""" Get a single problem from the dataset. """
|
||||
"""Get a single problem from the dataset."""
|
||||
row = self.ds[index]
|
||||
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
|
||||
solution = row['canonical_solution'] # the correct continuation of the program
|
||||
entry_point = row['entry_point'] # the function to check
|
||||
test = row['test'] # the test cases
|
||||
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
|
||||
solution = row['canonical_solution'] # the correct continuation of the program
|
||||
entry_point = row['entry_point'] # the function to check
|
||||
test = row['test'] # the test cases
|
||||
complete_solution = f"{prompt}\n{solution}"
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
|
@ -71,13 +75,13 @@ class HumanEval(Task):
|
|||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
"entry_point": entry_point, # needed during evaluation
|
||||
"test": test, # needed during evaluation
|
||||
"entry_point": entry_point, # needed during evaluation
|
||||
"test": test, # needed during evaluation
|
||||
}
|
||||
return conversation
|
||||
|
||||
def evaluate(self, conversation, completion):
|
||||
""" Given (conversation, completion), return boolean success of the completion. """
|
||||
"""Given (conversation, completion), return boolean success of the completion."""
|
||||
# the prompt will contain the imports and the function signature
|
||||
imports = extract_imports(conversation['messages'][0]['content'])
|
||||
# the completion will usually contain the whole function
|
||||
|
|
|
|||
|
|
@ -4,12 +4,71 @@ https://huggingface.co/datasets/cais/mmlu
|
|||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from tasks.common import Task, render_mc
|
||||
|
||||
class MMLU(Task):
|
||||
|
||||
class MMLU(Task):
|
||||
letters = ('A', 'B', 'C', 'D')
|
||||
groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions')
|
||||
groups = (
|
||||
'abstract_algebra',
|
||||
'anatomy',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'clinical_knowledge',
|
||||
'college_biology',
|
||||
'college_chemistry',
|
||||
'college_computer_science',
|
||||
'college_mathematics',
|
||||
'college_medicine',
|
||||
'college_physics',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'econometrics',
|
||||
'electrical_engineering',
|
||||
'elementary_mathematics',
|
||||
'formal_logic',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_computer_science',
|
||||
'high_school_european_history',
|
||||
'high_school_geography',
|
||||
'high_school_government_and_politics',
|
||||
'high_school_macroeconomics',
|
||||
'high_school_mathematics',
|
||||
'high_school_microeconomics',
|
||||
'high_school_physics',
|
||||
'high_school_psychology',
|
||||
'high_school_statistics',
|
||||
'high_school_us_history',
|
||||
'high_school_world_history',
|
||||
'human_aging',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'jurisprudence',
|
||||
'logical_fallacies',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'medical_genetics',
|
||||
'miscellaneous',
|
||||
'moral_disputes',
|
||||
'moral_scenarios',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'prehistory',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_studies',
|
||||
'sociology',
|
||||
'us_foreign_policy',
|
||||
'virology',
|
||||
'world_religions',
|
||||
)
|
||||
|
||||
def __init__(self, subset, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -33,28 +92,27 @@ class MMLU(Task):
|
|||
|
||||
def get_example(self, index):
|
||||
row = self.ds[index]
|
||||
question = row["question"] # the question text
|
||||
choices = row["choices"] # the text of each choice
|
||||
answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D)
|
||||
subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc.
|
||||
question = row["question"] # the question text
|
||||
choices = row["choices"] # the text of each choice
|
||||
answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D)
|
||||
subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc.
|
||||
assert len(choices) == 4, "MMLU should have 4 choices"
|
||||
# create and return the Conversation object
|
||||
user_message = render_mc(question, self.letters, choices)
|
||||
assistant_message = self.letters[answer]
|
||||
messages = [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": assistant_message}
|
||||
]
|
||||
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
"subject": subject, # might be useful later for grouping metrics by subject
|
||||
"letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
|
||||
"subject": subject, # might be useful later for grouping metrics by subject
|
||||
"letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
|
||||
}
|
||||
return conversation
|
||||
|
||||
def evaluate(self, conversation, assistant_response):
|
||||
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
|
||||
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
|
||||
assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
|
||||
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
|
||||
assert assistant_response in self.letters, (
|
||||
f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
|
||||
)
|
||||
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
|
||||
return assistant_response == assistant_message
|
||||
|
|
|
|||
|
|
@@ -5,10 +5,12 @@ We use the "smol" version, which is more appropriate for smaller models.
"""

from datasets import load_dataset

from tasks.common import Task


class SmolTalk(Task):
""" smol-smoltalk dataset. train is 460K rows, test is 24K rows. """
"""smol-smoltalk dataset. train is 460K rows, test is 24K rows."""

def __init__(self, split, **kwargs):
super().__init__(**kwargs)
@@ -29,14 +31,16 @@ class SmolTalk(Task):
assert len(messages) >= 1
first_message = messages[0]
if first_message["role"] == "system":
rest_messages = messages[1:] # optional system message is OK
rest_messages = messages[1:] # optional system message is OK
else:
rest_messages = messages
assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages"
for i, message in enumerate(rest_messages):
# user and assistant alternate as user,assistant,user,assistant,...
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), "Content must be a string"
# ---------------------------------------------------------------------
# create and return the Conversation object (ok to emit the system message too)
|
|||
|
|
@ -26,10 +26,11 @@ To preview a few example conversations, run:
|
|||
python -m tasks.spellingbee
|
||||
"""
|
||||
|
||||
import re
|
||||
import random
|
||||
from tasks.common import Task
|
||||
import re
|
||||
|
||||
from nanochat.common import download_file_with_lock
|
||||
from tasks.common import Task
|
||||
|
||||
# Letters of the alphabet
|
||||
LETTERS = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
|
@ -38,6 +39,8 @@ WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads
|
|||
|
||||
# Identical to gsm8k's answer extraction
|
||||
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
|
||||
|
||||
def extract_answer(completion):
|
||||
"""
|
||||
Extract the numerical answer after #### marker.
|
||||
|
|
@ -49,6 +52,7 @@ def extract_answer(completion):
|
|||
return match_str
|
||||
return None
|
||||
|
||||
|
||||
# User message templates for data augmentation
|
||||
USER_MSG_TEMPLATES = [
|
||||
"How many {letter} are in the word {word}",
|
||||
|
|
@ -110,8 +114,8 @@ USER_MSG_TEMPLATES = [
|
|||
"{word}に{letter}が何回出てくる",
|
||||
]
|
||||
|
||||
class SpellingBee(Task):
|
||||
|
||||
class SpellingBee(Task):
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
|
|
@ -119,7 +123,7 @@ class SpellingBee(Task):
|
|||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
with open(word_list_path, encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
self.words = words
|
||||
|
||||
|
|
@ -131,7 +135,7 @@ class SpellingBee(Task):
|
|||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||
rng = random.Random(seed)
|
||||
|
||||
# pick a random word
|
||||
|
|
@ -148,12 +152,12 @@ class SpellingBee(Task):
|
|||
if rng.random() < 0.3:
|
||||
template = template.lower()
|
||||
quote_options = ['', "'", '"']
|
||||
letter_quote = rng.choice(quote_options) # is the letter quoted?
|
||||
word_quote = rng.choice(quote_options) # is the word quoted?
|
||||
letter_quote = rng.choice(quote_options) # is the letter quoted?
|
||||
word_quote = rng.choice(quote_options) # is the word quoted?
|
||||
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
|
||||
word_wrapped = f"{word_quote}{word}{word_quote}"
|
||||
user_msg = template.format(letter=letter_wrapped, word=word_wrapped)
|
||||
if rng.random() < 0.5: # 50% of people don't even use question marks
|
||||
if rng.random() < 0.5: # 50% of people don't even use question marks
|
||||
user_msg += "?"
|
||||
|
||||
# Now create the ideal assistant response - build as parts (text + tool calls)
|
||||
|
|
@ -190,13 +194,12 @@ Then count the occurrences of '{letter}':
|
|||
# Part 4: Python output
|
||||
assistant_parts.append({"type": "python_output", "text": str(count)})
|
||||
# Part 5: Final answer
|
||||
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"})
|
||||
assistant_parts.append(
|
||||
{"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}
|
||||
)
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts}
|
||||
]
|
||||
messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
@ -222,7 +225,7 @@ Then count the occurrences of '{letter}':
|
|||
return is_correct
|
||||
|
||||
def reward(self, conversation, assistant_response):
|
||||
""" Use simple 0-1 reward just like gsm8k."""
|
||||
"""Use simple 0-1 reward just like gsm8k."""
|
||||
is_correct = self.evaluate(conversation, assistant_response)
|
||||
is_correct_float = float(is_correct)
|
||||
return is_correct_float
|
||||
|
|
@ -238,10 +241,10 @@ class SimpleSpelling(Task):
|
|||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, 'r', encoding='utf-8') as f:
|
||||
with open(word_list_path, encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
rng = random.Random(42)
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
|
|
@ -252,7 +255,7 @@ class SimpleSpelling(Task):
|
|||
return self.size
|
||||
|
||||
def get_example(self, index):
|
||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||
seed = index if self.split == "train" else -(index + 1) # avoid collision at 0
|
||||
rng = random.Random(seed)
|
||||
# pick a random word
|
||||
word = rng.choice(self.words)
|
||||
|
|
@ -260,7 +263,7 @@ class SimpleSpelling(Task):
|
|||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": f"Spell the word: {word}"},
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"}
|
||||
{"role": "assistant", "content": f"{word}:{word_letters}"},
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
|
|
@ -269,7 +272,6 @@ class SimpleSpelling(Task):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# preview the SpellingBee task, first 10 examples
|
||||
task = SpellingBee()
|
||||
for i in range(10):
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@ python -m pytest tests/test_engine.py -v
|
|||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def test_kv_cache_resize():
|
||||
"""
|
||||
The KV cache was not resized correctly, more information here:
|
||||
|
|
@ -21,11 +23,7 @@ def test_kv_cache_resize():
|
|||
num_layers = 6
|
||||
|
||||
kv_cache = KVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers
|
||||
)
|
||||
|
||||
# Insert a single token with a distinct fill value to all layers
|
||||
|
|
@ -47,7 +45,9 @@ def test_kv_cache_resize():
|
|||
insert_token(4)
|
||||
# Verify that the cache actually resized
|
||||
new_seq_len = kv_cache.kv_cache.shape[4]
|
||||
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
assert new_seq_len > original_seq_len, (
|
||||
f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
)
|
||||
|
||||
# Verify that the original 4 tokens are still intact after resize
|
||||
for layer_idx in range(num_layers):
|
||||
|
|
@ -57,8 +57,12 @@ def test_kv_cache_resize():
|
|||
expected_v = float(token_idx * 100)
|
||||
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
assert (actual_k == expected_k).all(), (
|
||||
f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
)
|
||||
assert (actual_v == expected_v).all(), (
|
||||
f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
)
|
||||
# And that the original cache matches resized cache
|
||||
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
|
|
|
|||
|
|
@ -18,18 +18,23 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
-v is verbose, -s is show prints
|
||||
"""
|
||||
|
||||
import regex as re
|
||||
from collections import Counter, defaultdict
|
||||
import time
|
||||
import rustbpe
|
||||
import tiktoken
|
||||
import pytest
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
import pytest
|
||||
import regex as re
|
||||
import tiktoken
|
||||
|
||||
import rustbpe
|
||||
|
||||
GPT4_SPLIT_PATTERN = (
|
||||
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
|
||||
|
||||
|
||||
def get_stats(ids, counts=None):
|
||||
"""
|
||||
Given a list of integers, return a dictionary of counts of consecutive pairs
|
||||
|
|
@ -37,10 +42,11 @@ def get_stats(ids, counts=None):
|
|||
Optionally allows to update an existing dictionary of counts
|
||||
"""
|
||||
counts = {} if counts is None else counts
|
||||
for pair in zip(ids, ids[1:]): # iterate consecutive elements
|
||||
for pair in zip(ids, ids[1:]): # iterate consecutive elements
|
||||
counts[pair] = counts.get(pair, 0) + 1
|
||||
return counts
|
||||
|
||||
|
||||
def merge(ids, pair, idx):
|
||||
"""
|
||||
In the list of integers (ids), replace all consecutive occurrences
|
||||
|
|
@ -51,7 +57,7 @@ def merge(ids, pair, idx):
|
|||
i = 0
|
||||
while i < len(ids):
|
||||
# if not at the very last position AND the pair matches, replace it
|
||||
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
|
||||
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
|
||||
newids.append(idx)
|
||||
i += 2
|
||||
else:
|
||||
|
|
@ -59,8 +65,8 @@ def merge(ids, pair, idx):
|
|||
i += 1
|
||||
return newids
|
||||
|
||||
class RegexTokenizer:
|
||||
|
||||
class RegexTokenizer:
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
|
|
@ -68,7 +74,7 @@ class RegexTokenizer:
|
|||
example: {'<|endoftext|>': 100257}
|
||||
"""
|
||||
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
|
||||
self.merges = {} # (int, int) -> int
|
||||
self.merges = {} # (int, int) -> int
|
||||
self.compiled_pattern = re.compile(self.pattern)
|
||||
self.special_tokens = {}
|
||||
self.inverse_special_tokens = {}
|
||||
|
|
@ -97,8 +103,8 @@ class RegexTokenizer:
|
|||
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
|
||||
|
||||
# iteratively merge the most common pairs to create new tokens
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
for i in range(num_merges):
|
||||
# count the number of times every consecutive pair appears
|
||||
stats = {}
|
||||
|
|
@ -125,11 +131,11 @@ class RegexTokenizer:
|
|||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
# prints
|
||||
if verbose:
|
||||
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
||||
print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
||||
|
||||
# save class variables
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
return ambiguous
|
||||
|
||||
def _encode_chunk(self, text_bytes):
|
||||
|
|
@ -145,7 +151,7 @@ class RegexTokenizer:
|
|||
# just the first pair in the list, arbitrarily
|
||||
# we can detect this terminating case by a membership check
|
||||
if pair not in self.merges:
|
||||
break # nothing else can be merged anymore
|
||||
break # nothing else can be merged anymore
|
||||
# otherwise let's merge the best pair (lowest merge index)
|
||||
idx = self.merges[pair]
|
||||
ids = merge(ids, pair, idx)
|
||||
|
|
@ -158,14 +164,16 @@ class RegexTokenizer:
|
|||
# all chunks of text are encoded separately, then results are joined
|
||||
ids = []
|
||||
for chunk in text_chunks:
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_ids = self._encode_chunk(chunk_bytes)
|
||||
ids.extend(chunk_ids)
|
||||
return ids
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Faster Python tokenizer, optimized version of the reference tokenizer
|
||||
|
||||
|
||||
def fast_merge_inplace(ids, pair, idx):
|
||||
"""
|
||||
In the list of integers (ids), replace all consecutive occurrences
|
||||
|
|
@ -175,16 +183,15 @@ def fast_merge_inplace(ids, pair, idx):
|
|||
# Find all positions where the pair occurs
|
||||
i = 0
|
||||
while i < len(ids) - 1:
|
||||
if ids[i] == pair[0] and ids[i+1] == pair[1]:
|
||||
if ids[i] == pair[0] and ids[i + 1] == pair[1]:
|
||||
ids[i] = idx
|
||||
ids.pop(i+1)
|
||||
ids.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
return ids
|
||||
|
||||
|
||||
class FastRegexTokenizer:
|
||||
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
|
|
@ -229,8 +236,8 @@ class FastRegexTokenizer:
|
|||
# input text preprocessing
|
||||
ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
|
||||
# iteratively merge the most common pairs to create new tokens
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
|
||||
# Initial count: build stats and position tracking
|
||||
stats = defaultdict(int)
|
||||
|
|
@ -262,31 +269,31 @@ class FastRegexTokenizer:
|
|||
chunk_count = chunk_counts[chunk_idx]
|
||||
ix = 0
|
||||
while ix < len(chunk_ids) - 1:
|
||||
if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
|
||||
if chunk_ids[ix] == pair[0] and chunk_ids[ix + 1] == pair[1]:
|
||||
# Track what pairs are being removed/added
|
||||
# Remove: (prev, A), (A, B), (B, next)
|
||||
if ix > 0:
|
||||
old_left = (chunk_ids[ix-1], chunk_ids[ix])
|
||||
old_left = (chunk_ids[ix - 1], chunk_ids[ix])
|
||||
count_changes[old_left] -= chunk_count
|
||||
|
||||
# The merged pair disappears
|
||||
count_changes[pair] -= chunk_count
|
||||
|
||||
if ix + 2 < len(chunk_ids):
|
||||
old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
|
||||
old_right = (chunk_ids[ix + 1], chunk_ids[ix + 2])
|
||||
count_changes[old_right] -= chunk_count
|
||||
|
||||
# Apply the merge
|
||||
chunk_ids[ix] = idx
|
||||
chunk_ids.pop(ix+1)
|
||||
chunk_ids.pop(ix + 1)
|
||||
|
||||
# Add: (prev, C), (C, next)
|
||||
if ix > 0:
|
||||
new_left = (chunk_ids[ix-1], chunk_ids[ix])
|
||||
new_left = (chunk_ids[ix - 1], chunk_ids[ix])
|
||||
count_changes[new_left] += chunk_count
|
||||
|
||||
if ix + 1 < len(chunk_ids):
|
||||
new_right = (chunk_ids[ix], chunk_ids[ix+1])
|
||||
new_right = (chunk_ids[ix], chunk_ids[ix + 1])
|
||||
count_changes[new_right] += chunk_count
|
||||
else:
|
||||
ix += 1
|
||||
|
|
@ -302,8 +309,9 @@ class FastRegexTokenizer:
|
|||
# Update positions for changed pairs - only check affected chunks
|
||||
for chunk_idx in affected_chunks:
|
||||
chunk_ids = ids[chunk_idx]
|
||||
contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair
|
||||
for j in range(len(chunk_ids) - 1))
|
||||
contains_pair = any(
|
||||
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1)
|
||||
)
|
||||
if contains_pair:
|
||||
positions[changed_pair].add(chunk_idx)
|
||||
else:
|
||||
|
|
@ -318,8 +326,8 @@ class FastRegexTokenizer:
|
|||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
|
||||
# save class variables
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
|
||||
def register_special_tokens(self, special_tokens):
|
||||
# special_tokens is a dictionary of str -> int
|
||||
|
|
@ -354,7 +362,7 @@ class FastRegexTokenizer:
|
|||
# just the first pair in the list, arbitrarily
|
||||
# we can detect this terminating case by a membership check
|
||||
if pair not in self.merges:
|
||||
break # nothing else can be merged anymore
|
||||
break # nothing else can be merged anymore
|
||||
# otherwise let's merge the best pair (lowest merge index)
|
||||
idx = self.merges[pair]
|
||||
ids = fast_merge_inplace(ids, pair, idx)
|
||||
|
|
@ -367,18 +375,20 @@ class FastRegexTokenizer:
|
|||
# all chunks of text are encoded separately, then results are joined
|
||||
ids = []
|
||||
for chunk in text_chunks:
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_ids = self._encode_chunk(chunk_bytes)
|
||||
ids.extend(chunk_ids)
|
||||
return ids
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace tokenizer
|
||||
from tokenizers import Regex, decoders, pre_tokenizers
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
from tokenizers import pre_tokenizers, decoders, Regex
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
|
||||
class HuggingFaceTokenizer:
|
||||
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
|
||||
|
||||
|
|
@ -389,19 +399,23 @@ class HuggingFaceTokenizer:
|
|||
def train_from_iterator(cls, text_iterator, vocab_size):
|
||||
# train from an iterator of text
|
||||
# Configure the HuggingFace Tokenizer
|
||||
tokenizer = HFTokenizer(BPE(
|
||||
byte_fallback=True, # needed!
|
||||
unk_token=None,
|
||||
fuse_unk=False,
|
||||
))
|
||||
tokenizer = HFTokenizer(
|
||||
BPE(
|
||||
byte_fallback=True, # needed!
|
||||
unk_token=None,
|
||||
fuse_unk=False,
|
||||
)
|
||||
)
|
||||
# Normalizer: None
|
||||
tokenizer.normalizer = None
|
||||
# Pre-tokenizer: GPT-4 style
|
||||
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
|
||||
])
|
||||
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
|
||||
]
|
||||
)
|
||||
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
# Post-processor: None
|
||||
|
|
@ -410,9 +424,9 @@ class HuggingFaceTokenizer:
|
|||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
show_progress=True,
|
||||
min_frequency=0, # no minimum frequency
|
||||
min_frequency=0, # no minimum frequency
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
||||
special_tokens=[], # no special tokens
|
||||
special_tokens=[], # no special tokens
|
||||
)
|
||||
# Kick off the training
|
||||
tokenizer.train_from_iterator(text_iterator, trainer)
|
||||
|
|
@ -422,15 +436,19 @@ class HuggingFaceTokenizer:
|
|||
ids = self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||
return ids
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test all of the above
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_path():
|
||||
"""Fixture to download and cache enwik8 dataset."""
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
base_dir = get_base_dir()
|
||||
# download and unzip enwik8 to .cache directory
|
||||
enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
|
||||
|
|
@ -439,6 +457,7 @@ def enwik8_path():
|
|||
if not os.path.exists(enwik8_local_path):
|
||||
print(f"Downloading enwik8 to {enwik8_local_path_zip}")
|
||||
import requests
|
||||
|
||||
response = requests.get(enwik8_url)
|
||||
with open(enwik8_local_path_zip, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
|
@ -455,15 +474,17 @@ def enwik8_path():
|
|||
@pytest.fixture(scope="module")
|
||||
def enwik8_small(enwik8_path):
|
||||
"""Fixture providing 100KB of enwik8 for quick tests."""
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
with open(enwik8_path, encoding="utf-8") as f:
|
||||
return f.read(100_000)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_large(enwik8_path):
|
||||
"""Fixture providing 10MB of enwik8 for performance tests."""
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
with open(enwik8_path, encoding="utf-8") as f:
|
||||
return f.read(10**7)
|
||||
|
||||
|
||||
def time_function(func, *args, **kwargs):
|
||||
"""Time a function call and return the result and elapsed time"""
|
||||
start_time = time.time()
|
||||
|
|
@ -472,6 +493,7 @@ def time_function(func, *args, **kwargs):
|
|||
elapsed = end_time - start_time
|
||||
return result, elapsed
|
||||
|
||||
|
||||
def test_correctness(enwik8_small):
|
||||
"""Test that all tokenizer implementations produce the same results."""
|
||||
text = enwik8_small
|
||||
|
|
@ -482,7 +504,9 @@ def test_correctness(enwik8_small):
|
|||
print("\nTraining slow reference...")
|
||||
slow_reference_tokenizer = RegexTokenizer()
|
||||
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
|
||||
slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
|
||||
slow_reference_ids, slow_reference_encode_time = time_function(
|
||||
slow_reference_tokenizer.encode_ordinary, encode_text
|
||||
)
|
||||
print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
|
||||
print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
|
||||
print(slow_reference_ids[:20])
|
||||
|
|
@ -497,7 +521,9 @@ def test_correctness(enwik8_small):
|
|||
print("\nTraining fast reference...")
|
||||
fast_reference_tokenizer = FastRegexTokenizer()
|
||||
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
|
||||
fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
|
||||
fast_reference_ids, fast_reference_encode_time = time_function(
|
||||
fast_reference_tokenizer.encode_ordinary, encode_text
|
||||
)
|
||||
print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
|
||||
print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
|
||||
print(fast_reference_ids[:20])
|
||||
|
|
@ -589,14 +615,16 @@ def test_training_performance(enwik8_large):
|
|||
assert hf_train_time > 0, "Training should take some time"
|
||||
|
||||
# Print comparison
|
||||
print(f"\n📊 Performance comparison:")
|
||||
print("\n📊 Performance comparison:")
|
||||
print(f" RustBPE: {rustbpe_train_time:.4f}s")
|
||||
print(f" HuggingFace: {hf_train_time:.4f}s")
|
||||
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
|
||||
print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x")
|
||||
|
||||
|
||||
def test_interface(enwik8_small):
|
||||
"""Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
|
||||
import tempfile
|
||||
|
||||
from nanochat.tokenizer import RustBPETokenizer
|
||||
|
||||
# Simple train test
|
||||
|
|
|
|||
69
uv.lock
|
|
@ -188,6 +188,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfgv"
|
||||
version = "3.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.3"
|
||||
|
|
@ -306,6 +315,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distlib"
|
||||
version = "0.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
|
|
@ -528,6 +546,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "identify"
|
||||
version = "2.6.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
|
|
@ -802,6 +829,7 @@ gpu = [
|
|||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "maturin" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
|
||||
|
|
@ -826,6 +854,7 @@ provides-extras = ["cpu", "gpu"]
|
|||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "maturin", specifier = ">=1.9.4" },
|
||||
{ name = "pre-commit", specifier = ">=3.8.0" },
|
||||
{ name = "pytest", specifier = ">=8.0.0" },
|
||||
]
|
||||
|
||||
|
|
@ -872,6 +901,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
|
|
@ -1131,6 +1169,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pre-commit"
|
||||
version = "4.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cfgv" },
|
||||
{ name = "identify" },
|
||||
{ name = "nodeenv" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "virtualenv" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/9b/6a4ffb4ed980519da959e1cf3122fc6cb41211daa58dbae1c73c0e519a37/pre_commit-4.5.0.tar.gz", hash = "sha256:dc5a065e932b19fc1d4c653c6939068fe54325af8e741e74e88db4d28a4dd66b", size = 198428, upload-time = "2025-11-22T21:02:42.304Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/c4/b2d28e9d2edf4f1713eb3c29307f1a63f3d67cf09bdda29715a36a68921a/pre_commit-4.5.0-py2.py3-none-any.whl", hash = "sha256:25e2ce09595174d9c97860a95609f9f852c0614ba602de3561e267547f2335e1", size = 226429, upload-time = "2025-11-22T21:02:40.836Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "propcache"
|
||||
version = "0.3.2"
|
||||
|
|
@ -2016,6 +2070,21 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/96/06/5cc0542b47c0338c1cb676b348e24a1c29acabc81000bced518231dded6f/uvicorn-0.36.0-py3-none-any.whl", hash = "sha256:6bb4ba67f16024883af8adf13aba3a9919e415358604ce46780d3f9bdc36d731", size = 67675, upload-time = "2025-09-20T01:07:12.984Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.35.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "distlib" },
|
||||
{ name = "filelock" },
|
||||
{ name = "platformdirs" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wandb"
|
||||
version = "0.21.3"
|
||||
|
|
|
|||