Fix (automatically) all pre-commit errors

Author: Eyal Frishman, 2025-12-05 18:33:00 +02:00
Commit: 449494c8b6 (parent 6587063479)
39 changed files with 1549 additions and 1048 deletions


@@ -28,24 +28,23 @@ NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the
 (obviously you can tune this arbitrarily to your liking)
 NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
 """
-import requests
+import copy
 import json
 import os
-import copy
 import random
 from concurrent.futures import ThreadPoolExecutor, as_completed
+import requests
 from nanochat.common import get_base_dir
-api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
+api_key = open("openroutertoken.txt", encoding="utf-8").read().strip()
 url = "https://openrouter.ai/api/v1/chat/completions"
-headers = {
-"Authorization": f"Bearer {api_key}",
-"Content-Type": "application/json"
-}
+headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-readme = open("README.md", "r", encoding="utf-8").read().strip()
+readme = open("README.md", encoding="utf-8").read().strip()
 prompt = r"""
 I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
@@ -291,22 +290,19 @@ response_format = {
 "properties": {
 "role": {
 "type": "string",
-"description": "The role of the speaker, either 'user' or 'assistant'"
+"description": "The role of the speaker, either 'user' or 'assistant'",
 },
-"content": {
-"type": "string",
-"description": "The message content"
-}
+"content": {"type": "string", "description": "The message content"},
 },
 "required": ["role", "content"],
-"additionalProperties": False
-}
+"additionalProperties": False,
+},
 }
 },
 "required": ["messages"],
-"additionalProperties": False
-}
-}
+"additionalProperties": False,
+},
+},
 }
 # Sadly it doesn't seem like Chat completions support `n`
@@ -318,6 +314,7 @@ base_payload = {
 "temperature": 1.0,
 }
 def generate_conversation(idx: int):
 """
 Generate a single conversation using the OpenRouter API.
@@ -357,7 +354,6 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers.
 completed_count = 0
 error_count = 0
 with ThreadPoolExecutor(max_workers=num_workers) as executor:
 # Submit all tasks
 futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
@@ -369,7 +365,9 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
 # Lightly validate the conversation structure
 for i, message in enumerate(messages):
 expected_role = "user" if i % 2 == 0 else "assistant"
-assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
+assert message['role'] == expected_role, (
+f"Message {i} has role {message['role']} but should be {expected_role}"
+)
 # If all looks good, write the messages to file
 with open(output_file, 'a') as f:
@@ -384,4 +382,3 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
 print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
 if error_count > 0:
 print(f"Encountered {error_count} errors during generation")
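
Illustrative sketch, not part of this commit: the request pattern the script above configures against the OpenRouter chat-completions endpoint, with the structured-output schema omitted. The model id and the seed message are placeholders, not values from the script.

import requests

api_key = open("openroutertoken.txt", encoding="utf-8").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {
    "model": "openai/gpt-4o-mini",  # placeholder model id
    "messages": [{"role": "user", "content": "Say hello."}],  # placeholder conversation seed
    "temperature": 1.0,
}
response = requests.post(url, headers=headers, json=payload, timeout=60)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])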


@@ -13,12 +13,13 @@ training latency.
 NOTE: This file is meant only as reference/documentation of the
 dataset preparation and it is not used during the project runtime.
 """
 import os
 import time
-from datasets import load_dataset
-import pyarrow.parquet as pq
 import pyarrow as pa
+import pyarrow.parquet as pq
+from datasets import load_dataset
 # Source dataset
 dataset_kwargs = {
@@ -73,15 +74,20 @@ for doc in ds:
 avg_time_per_doc = total_time_spent / total_docs_processed
 remaining_time = remaining_docs * avg_time_per_doc
 remaining_time_hours = remaining_time / 3600
-print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
+print(
+f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h"
+)
 shard_docs = []
 shard_characters = 0
 shard_index += 1
 # Demonstration of how the data was later uploaded to HuggingFace
 def upload():
 import os
 from huggingface_hub import HfApi
 token = os.getenv("HF_TOKEN")
 api = HfApi(token=token)
 api.upload_large_folder(
@@ -89,4 +95,6 @@ def upload():
 repo_id="karpathy/fineweb-edu-100b-shuffle",
 repo_type="dataset",
 )
 # upload()
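
For illustration only (not from this commit): the shard-writing loop above ultimately hands a list of document strings to pyarrow; a minimal sketch of writing one "text"-column shard, with placeholder documents and a placeholder filename.

import pyarrow as pa
import pyarrow.parquet as pq

shard_docs = ["first document text", "second document text"]  # placeholder documents
table = pa.table({"text": pa.array(shard_docs, type=pa.string())})
pq.write_table(table, "shard_00000.parquet", row_group_size=1024)  # placeholder shard path
print(pq.read_table("shard_00000.parquet").num_rows)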


@@ -2,6 +2,7 @@
 Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
 Not a general optimizer! But works for our specific use.
 """
 import torch
 import torch.distributed as dist
 from torch import Tensor
@@ -12,7 +13,15 @@ class DistAdamW(torch.optim.Optimizer):
 Distributed AdamW optimizer.
 In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
 """
-def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+def __init__(
+self,
+param_groups,
+lr: float = 1e-3,
+betas: tuple[float, float] = (0.9, 0.999),
+eps: float = 1e-8,
+weight_decay: float = 0.01,
+):
 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
 super().__init__(param_groups, defaults)
@@ -30,7 +39,9 @@ class DistAdamW(torch.optim.Optimizer):
 grad = params[base_i].grad
 rank_size = grad.shape[0] // world_size
 grad_slice = torch.empty_like(grad[:rank_size])
-reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+reduce_scatter_futures.append(
+dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
+)
 grad_slices.append(grad_slice)
 idx = 0
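
A sketch for orientation, not part of this commit: what the reduce_scatter call above computes, emulated with plain tensors because the real call needs a torchrun process group. Shapes and values are placeholders.

import torch

world_size = 4
grads = [torch.randn(8, 3) for _ in range(world_size)]  # one full gradient per simulated rank
rank_size = grads[0].shape[0] // world_size
averaged = torch.stack(grads).mean(dim=0)  # ReduceOp.AVG across ranks
# Each rank keeps only its own slice of rows, which is what lands in grad_slice above.
grad_slices = [averaged[r * rank_size:(r + 1) * rank_size] for r in range(world_size)]
print([tuple(s.shape) for s in grad_slices])  # four (2, 3) slices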


@@ -1,25 +1,29 @@
 """
 Utilities for saving and loading model/optim/state checkpoints.
 """
-import os
-import re
 import glob
 import json
 import logging
+import os
+import re
 import torch
-from nanochat.common import get_base_dir
+from nanochat.common import get_base_dir, setup_default_logging
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.tokenizer import get_tokenizer
-from nanochat.common import setup_default_logging
 # Set up logging
 setup_default_logging()
 logger = logging.getLogger(__name__)
 def log0(message):
 if int(os.environ.get('RANK', 0)) == 0:
 logger.info(message)
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
 if rank == 0:
 os.makedirs(checkpoint_dir, exist_ok=True)
@@ -38,6 +42,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
 torch.save(optimizer_data, optimizer_path)
 logger.info(f"Saved optimizer state to: {optimizer_path}")
 def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
 # Load the model state
 model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@@ -49,7 +54,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
 optimizer_data = torch.load(optimizer_path, map_location=device)
 # Load the metadata
 meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-with open(meta_path, "r", encoding="utf-8") as f:
+with open(meta_path, encoding="utf-8") as f:
 meta_data = json.load(f)
 return model_data, optimizer_data, meta_data
@@ -66,10 +71,7 @@ def build_model(checkpoint_dir, step, device, phase):
 model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
 if device.type in {"cpu", "mps"}:
 # Convert bfloat16 tensors to float for CPU inference
-model_data = {
-k: v.float() if v.dtype == torch.bfloat16 else v
-for k, v in model_data.items()
-}
+model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
 # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
 model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
 model_config_kwargs = meta_data["model_config"]
@@ -121,9 +123,11 @@ def find_last_step(checkpoint_dir):
 last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
 return last_step
 # -----------------------------------------------------------------------------
 # convenience functions that take into account nanochat's directory structure
 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
 if model_tag is None:
 # guess the model tag by defaulting to the largest model
@@ -139,6 +143,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
 model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
 return model, tokenizer, meta_data
 def load_model(source, *args, **kwargs):
 model_dir = {
 "base": "base_checkpoints",


@@ -2,16 +2,19 @@
 Common utilities for nanochat.
 """
+import logging
 import os
 import re
-import logging
 import urllib.request
 import torch
 import torch.distributed as dist
 from filelock import FileLock
 class ColoredFormatter(logging.Formatter):
 """Custom formatter that adds colors to log messages."""
 # ANSI color codes
 COLORS = {
 'DEBUG': '\033[36m', # Cyan
@@ -22,6 +25,7 @@ class ColoredFormatter(logging.Formatter):
 }
 RESET = '\033[0m'
 BOLD = '\033[1m'
 def format(self, record):
 # Add color to the level name
 levelname = record.levelname
@@ -36,17 +40,17 @@ class ColoredFormatter(logging.Formatter):
 message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
 return message
 def setup_default_logging():
 handler = logging.StreamHandler()
 handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-logging.basicConfig(
-level=logging.INFO,
-handlers=[handler]
-)
+logging.basicConfig(level=logging.INFO, handlers=[handler])
 setup_default_logging()
 logger = logging.getLogger(__name__)
 def get_base_dir():
 # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
 if os.environ.get("NANOCHAT_BASE_DIR"):
@@ -58,6 +62,7 @@ def get_base_dir():
 os.makedirs(nanochat_dir, exist_ok=True)
 return nanochat_dir
 def download_file_with_lock(url, filename, postprocess_fn=None):
 """
 Downloads a file from a URL to a local path in the base directory.
@@ -94,11 +99,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
 return file_path
 def print0(s="", **kwargs):
 ddp_rank = int(os.environ.get('RANK', 0))
 if ddp_rank == 0:
 print(s, **kwargs)
 def print_banner():
 # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
 banner = """
@@ -113,10 +120,12 @@ def print_banner():
 """
 print0(banner)
 def is_ddp():
 # TODO is there a proper way
 return int(os.environ.get('RANK', -1)) != -1
 def get_dist_info():
 if is_ddp():
 assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
@@ -127,6 +136,7 @@ def get_dist_info():
 else:
 return False, 0, 0, 1
 def autodetect_device_type():
 # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
 if torch.cuda.is_available():
@@ -138,14 +148,19 @@ def autodetect_device_type():
 print0(f"Autodetected device type: {device_type}")
 return device_type
 def compute_init(device_type="cuda"): # cuda|cpu|mps
 """Basic initialization that we keep doing over and over, so make common."""
 assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
 if device_type == "cuda":
-assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+assert torch.cuda.is_available(), (
+"Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+)
 if device_type == "mps":
-assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+assert torch.backends.mps.is_available(), (
+"Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+)
 # Reproducibility
 # Note that we set the global seeds here, but most of the code uses explicit rng objects.
@@ -175,16 +190,21 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
 return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 def compute_cleanup():
 """Companion function to compute_init, to clean things up before script exit"""
 if is_ddp():
 dist.destroy_process_group()
 class DummyWandb:
 """Useful if we wish to not use wandb but have all the same signatures"""
 def __init__(self):
 pass
 def log(self, *args, **kwargs):
 pass
 def finish(self):
 pass


@@ -18,11 +18,13 @@ import os
 import sys
 from ast import literal_eval
 def print0(s="", **kwargs):
 ddp_rank = int(os.environ.get('RANK', 0))
 if ddp_rank == 0:
 print(s, **kwargs)
 for arg in sys.argv[1:]:
 if '=' not in arg:
 # assume it's the name of a config file
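
Not part of this commit: a simplified approximation of the key=value override loop this configurator implements, using a plain dict instead of module globals. The defaults and command-line arguments below are placeholders.

import sys
from ast import literal_eval

config = {"depth": 6, "lr": 1e-3, "run_name": "default"}  # placeholder defaults
sys.argv[1:] = ["depth=12", "lr=3e-4", "run_name=demo"]   # as if passed on the command line
for arg in sys.argv[1:]:
    if "=" in arg:
        key, val = arg.split("=", 1)
        try:
            config[key] = literal_eval(val)  # numbers, booleans, tuples, ...
        except (ValueError, SyntaxError):
            config[key] = val  # fall back to the raw string
print(config)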


@@ -5,15 +5,17 @@ https://arxiv.org/abs/2406.11794
 TODOs:
 - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
 """
 import random
-from jinja2 import Template
 import torch
 import torch.distributed as dist
+from jinja2 import Template
 # -----------------------------------------------------------------------------
 # Prompt rendering utilities
 def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
 """Render complete prompts for a multiple choice question"""
 template_str = """
@@ -24,11 +26,7 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
 {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
 template = Template(template_str)
 fewshot_examples = fewshot_examples or []
-context = {
-'fewshot_examples': fewshot_examples,
-'continuation_delimiter': continuation_delimiter,
-'item': item
-}
+context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
 prompts = [template.render(choice=choice, **context) for choice in item['choices']]
 return prompts
@@ -43,13 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
 {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
 template = Template(template_str)
 fewshot_examples = fewshot_examples or []
-context = {
-'fewshot_examples': fewshot_examples,
-'continuation_delimiter': continuation_delimiter,
-'item': item
-}
-prompts = [template.render(context=context_option, **context)
-for context_option in item['context_options']]
+context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
+prompts = [template.render(context=context_option, **context) for context_option in item['context_options']]
 return prompts
@@ -67,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
 {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
 template = Template(template_str)
 fewshot_examples = fewshot_examples or []
-context = {
-'fewshot_examples': fewshot_examples,
-'continuation_delimiter': continuation_delimiter,
-'item': item
-}
+context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
 # Return two prompts: without and with the continuation
 prompt_without = template.render(include_continuation=False, **context)
 prompt_with = template.render(include_continuation=True, **context)
@@ -89,10 +78,7 @@ def find_common_length(token_sequences, direction='left'):
 - direction: 'left' for prefix, 'right' for suffix
 """
 min_len = min(len(seq) for seq in token_sequences)
-indices = {
-'left': range(min_len),
-'right': range(-1, -min_len-1, -1)
-}[direction]
+indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction]
 # Find the first position where the token sequences differ
 for i, idx in enumerate(indices):
 token = token_sequences[0][idx]
@@ -153,9 +139,7 @@ def forward_model(model, input_ids):
 target_ids = torch.roll(input_ids, shifts=-1, dims=1)
 # Calculate cross entropy at all positions
 losses = torch.nn.functional.cross_entropy(
-outputs.view(batch_size * seq_len, -1),
-target_ids.view(batch_size * seq_len),
-reduction='none'
+outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none'
 ).view(batch_size, seq_len)
 # Set the last column to be nan because there is no autoregressive loss there
 losses[:, -1] = float('nan')
@@ -231,8 +215,7 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
 is_correct = torch.all(predicted_tokens == actual_tokens).item()
 elif task_type in ['multiple_choice', 'schema']:
 # For MC/schema: find the option with lowest average loss
-mean_losses = [losses[i, si-1:ei-1].mean().item()
-for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
+mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
 pred_idx = mean_losses.index(min(mean_losses))
 is_correct = pred_idx == item['gold']
 else:
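
Illustrative sketch, not from this commit: the renderers above build one prompt per answer choice with Jinja2; a minimal version with a placeholder item and no few-shot examples.

from jinja2 import Template

item = {"query": "The sky is", "choices": ["green", "blue"], "gold": 1}  # placeholder item
continuation_delimiter = " "
template = Template("{{ item.query }}{{ continuation_delimiter }}{{ choice }}")
prompts = [
    template.render(item=item, continuation_delimiter=continuation_delimiter, choice=choice)
    for choice in item["choices"]
]
print(prompts)  # ['The sky is green', 'The sky is blue']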


@@ -1,13 +1,16 @@
 from collections import deque
-import torch
 import pyarrow.parquet as pq
+import torch
 from nanochat.common import get_dist_info
 from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer
-def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+def tokenizing_distributed_data_loader_with_state(
+B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
+):
 """
 Stream pretraining text from parquet files, tokenize, yield training batches.
@@ -24,6 +27,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
 # infinite iterator over document batches (list of text strings)
 ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
 def document_batches():
 parquet_paths = list_parquet_files()
 parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
@@ -51,6 +55,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
 yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx)
 rg_idx += ddp_world_size # advance to the next row group (in DDP)
 pq_idx += 1 # advance to the next parquet file
 batches = document_batches()
 # Now emit batches of tokens.
@@ -78,9 +83,13 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
 # Reshape to 2D and move to GPU async
 inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
 targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
-state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
+state_dict = {
+"pq_idx": pq_idx,
+"rg_idx": rg_idx,
+} # we need this in case we wish to approximately resume training
 yield inputs, targets, state_dict
 def tokenizing_distributed_data_loader(*args, **kwargs):
 # helper function that only emits the inputs/targets and not the state_dict
 for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
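
For orientation, not part of this commit: the inputs/targets yielded above are a flat token stream reshaped to (B, T), with the targets shifted by one position; a toy illustration with placeholder token ids.

import torch

B, T = 2, 4
stream = torch.arange(B * T + 1)  # placeholder token ids, one extra for the shifted targets
inputs = stream[:-1].view(B, T)
targets = stream[1:].view(B, T)
print(inputs.tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(targets.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]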


@@ -7,13 +7,14 @@ This file contains utilities for:
 For details of how the dataset was prepared, see `repackage_data_reference.py`.
 """
-import os
 import argparse
+import os
 import time
-import requests
-import pyarrow.parquet as pq
 from multiprocessing import Pool
+import pyarrow.parquet as pq
+import requests
 from nanochat.common import get_base_dir
 # -----------------------------------------------------------------------------
@@ -30,16 +31,15 @@ os.makedirs(DATA_DIR, exist_ok=True)
 # -----------------------------------------------------------------------------
 # These functions are useful utilities to other modules, can/should be imported
 def list_parquet_files(data_dir=None):
 """Looks into a data dir and returns full paths to all parquet files."""
 data_dir = DATA_DIR if data_dir is None else data_dir
-parquet_files = sorted([
-f for f in os.listdir(data_dir)
-if f.endswith('.parquet') and not f.endswith('.tmp')
-])
+parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')])
 parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
 return parquet_paths
 def parquets_iter_batched(split, start=0, step=1):
 """
 Iterate through the dataset, in batches of underlying row_groups for efficiency.
@@ -56,6 +56,7 @@ def parquets_iter_batched(split, start=0, step=1):
 texts = rg.column('text').to_pylist()
 yield texts
 # -----------------------------------------------------------------------------
 def download_single_file(index):
 """Downloads a single file index, with some backoff"""
@@ -78,7 +79,7 @@ def download_single_file(index):
 response = requests.get(url, stream=True, timeout=30)
 response.raise_for_status()
 # Write to temporary file first
-temp_path = filepath + f".tmp"
+temp_path = filepath + ".tmp"
 with open(temp_path, 'wb') as f:
 for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
 if chunk:
@@ -88,10 +89,10 @@ def download_single_file(index):
 print(f"Successfully downloaded {filename}")
 return True
-except (requests.RequestException, IOError) as e:
+except (OSError, requests.RequestException) as e:
 print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
 # Clean up any partial files
-for path in [filepath + f".tmp", filepath]:
+for path in [filepath + ".tmp", filepath]:
 if os.path.exists(path):
 try:
 os.remove(path)
@@ -111,8 +112,12 @@ def download_single_file(index):
 if __name__ == "__main__":
 parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
-parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
-parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+parser.add_argument(
+"-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable"
+)
+parser.add_argument(
+"-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)"
+)
 args = parser.parse_args()
 num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
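
A standalone sketch, not from this commit: the download-then-rename pattern used above, with a placeholder URL and destination path. The .tmp file is only promoted once the download finishes, so readers never see a partial shard.

import os
import requests

url = "https://example.com/shard_00000.parquet"  # placeholder URL
filepath = "/tmp/shard_00000.parquet"            # placeholder destination
temp_path = filepath + ".tmp"
try:
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    with open(temp_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            if chunk:
                f.write(chunk)
    os.rename(temp_path, filepath)  # publish the finished file
except (OSError, requests.RequestException) as e:
    print(f"download failed: {e}")
    if os.path.exists(temp_path):
        os.remove(temp_path)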


@@ -11,15 +11,17 @@ Notes:
 The whole thing is made as efficient as possible.
 """
-import torch
-import torch.nn.functional as F
 import signal
 import warnings
-from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init, autodetect_device_type
+from contextlib import contextmanager, nullcontext
+import torch
+import torch.nn.functional as F
 from nanochat.checkpoint_manager import load_model
-from contextlib import nullcontext
+from nanochat.common import autodetect_device_type, compute_init
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -33,17 +35,19 @@ def timeout(duration, formula):
 yield
 signal.alarm(0)
 def eval_with_timeout(formula, max_time=3):
 try:
 with timeout(max_time, formula):
 with warnings.catch_warnings():
 warnings.simplefilter("ignore", SyntaxWarning)
 return eval(formula, {"__builtins__": {}}, {})
-except Exception as e:
+except Exception:
 signal.alarm(0)
 # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
 return None
 def use_calculator(expr):
 """
 Evaluate a Python expression safely.
@@ -65,9 +69,25 @@ def use_calculator(expr):
 return None
 # Disallow dangerous patterns
-dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
-'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
-'getattr', 'setattr', 'delattr', 'hasattr']
+dangerous_patterns = [
+'__',
+'import',
+'exec',
+'eval',
+'compile',
+'open',
+'file',
+'input',
+'raw_input',
+'globals',
+'locals',
+'vars',
+'dir',
+'getattr',
+'setattr',
+'delattr',
+'hasattr',
+]
 expr_lower = expr.lower()
 if any(pattern in expr_lower for pattern in dangerous_patterns):
 return None
@@ -79,6 +99,7 @@ def use_calculator(expr):
 # Evaluate with timeout
 return eval_with_timeout(expr)
 # -----------------------------------------------------------------------------
 class KVCache:
 """
@@ -173,8 +194,10 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
 probs = F.softmax(logits, dim=-1)
 return torch.multinomial(probs, num_samples=1, generator=rng)
 # -----------------------------------------------------------------------------
 class RowState:
 # Per-row state tracking during generation
 def __init__(self, current_tokens=None):
@@ -184,8 +207,8 @@ class RowState:
 self.python_expr_tokens = [] # Tokens of the current python expression
 self.completed = False # Whether this row has completed generation
-class Engine:
+class Engine:
 def __init__(self, model, tokenizer):
 self.model = model
 self.tokenizer = tokenizer # needed for tool use
@@ -327,10 +350,13 @@ if __name__ == "__main__":
 is equivalent to the faster Engine.generate function here.
 """
 import time
 # init compute
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
 device_type = autodetect_device_type()
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+autocast_ctx = (
+torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+)
 # load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="eval")
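
Not part of this commit: the calculator helpers above combine a SIGALRM-based timeout with eval stripped of builtins; a standalone sketch of that idea (Unix-only, since it relies on signal.alarm). The expression is a placeholder.

import signal
from contextlib import contextmanager

@contextmanager
def timeout(duration):
    def handler(signum, frame):
        raise TimeoutError("evaluation took too long")
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        signal.alarm(0)

expr = "1 + 2 * 3"  # placeholder expression
try:
    with timeout(3):
        print(eval(expr, {"__builtins__": {}}, {}))  # -> 7, with builtins removed
except TimeoutError:
    print(None)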


@@ -30,17 +30,18 @@ import platform
 import signal
 import tempfile
 from dataclasses import dataclass
-from typing import Optional
 # -----------------------------------------------------------------------------
 @dataclass
 class ExecutionResult:
 """Result of executing Python code in a sandbox."""
 success: bool
 stdout: str
 stderr: str
-error: Optional[str] = None
+error: str | None = None
 timeout: bool = False
 memory_exceeded: bool = False
@@ -101,13 +102,13 @@ class WriteOnlyStringIO(io.StringIO):
 """StringIO that throws an exception when it's read from"""
 def read(self, *args, **kwargs):
-raise IOError
+raise OSError
 def readline(self, *args, **kwargs):
-raise IOError
+raise OSError
 def readlines(self, *args, **kwargs):
-raise IOError
+raise OSError
 def readable(self, *args, **kwargs):
 """Returns True if the IO object can be read."""
@@ -131,7 +132,7 @@ def chdir(root):
 os.chdir(cwd)
-def reliability_guard(maximum_memory_bytes: Optional[int] = None):
+def reliability_guard(maximum_memory_bytes: int | None = None):
 """
 This disables various destructive functions and prevents the generated code
 from interfering with the test (e.g. fork bomb, killing other processes,
@@ -147,6 +148,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
 if platform.uname().system != "Darwin":
 # These resource limit calls seem to fail on macOS (Darwin), skip?
 import resource
 resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
 resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
 resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
@@ -211,10 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
 sys.modules["tkinter"] = None
-def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
+def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict):
 """Execute code in a subprocess with safety guards. Results are written to result_dict."""
 with create_tempdir():
 # These system calls are needed when cleaning up tempdir.
 import os
 import shutil
@@ -228,14 +229,16 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
 reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
 # Default to failure
-result_dict.update({
+result_dict.update(
+{
 "success": False,
 "stdout": "",
 "stderr": "",
 "timeout": False,
 "memory_exceeded": False,
 "error": None,
-})
+}
+)
 try:
 exec_globals = {}
@@ -253,28 +256,36 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
 # uncomment the following line and proceed at your own risk:
 exec(code, exec_globals)
-result_dict.update({
+result_dict.update(
+{
 "success": True,
 "stdout": stdout_capture.getvalue(),
 "stderr": stderr_capture.getvalue(),
-})
+}
+)
 except TimeoutException:
-result_dict.update({
+result_dict.update(
+{
 "timeout": True,
 "error": "Execution timed out",
-})
+}
+)
 except MemoryError as e:
-result_dict.update({
+result_dict.update(
+{
 "memory_exceeded": True,
 "error": f"Memory limit exceeded: {e}",
-})
+}
+)
 except BaseException as e:
-result_dict.update({
+result_dict.update(
+{
 "error": f"{type(e).__name__}: {e}",
-})
+}
+)
 # Needed for cleaning up.
 shutil.rmtree = rmtree
@@ -286,7 +297,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
 def execute_code(
 code: str,
 timeout: float = 5.0, # 5 seconds default
-maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
+maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default
 ) -> ExecutionResult:
 """
 Execute Python code in a sandboxed environment.
@@ -310,10 +321,7 @@ def execute_code(
 manager = multiprocessing.Manager()
 result_dict = manager.dict()
-p = multiprocessing.Process(
-target=_unsafe_execute,
-args=(code, timeout, maximum_memory_bytes, result_dict)
-)
+p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict))
 p.start()
 p.join(timeout=timeout + 1)
@@ -346,4 +354,3 @@ def execute_code(
 timeout=result_dict["timeout"],
 memory_exceeded=result_dict["memory_exceeded"],
 )
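
A sketch for readers, not from this commit: the subprocess plus Manager-dict plumbing above with the sandboxing stripped out, showing how results travel back to the parent. The code string and timeouts are placeholders.

import multiprocessing

def run_snippet(code, result_dict):
    # Greatly simplified stand-in for _unsafe_execute: no reliability_guard, no output capture.
    try:
        exec(code, {})
        result_dict.update({"success": True, "error": None})
    except BaseException as e:
        result_dict.update({"success": False, "error": f"{type(e).__name__}: {e}"})

if __name__ == "__main__":
    manager = multiprocessing.Manager()
    result_dict = manager.dict()
    p = multiprocessing.Process(target=run_snippet, args=("print('hi')", result_dict))
    p.start()
    p.join(timeout=5 + 1)
    if p.is_alive():
        p.kill()
    print(dict(result_dict))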


@@ -12,16 +12,17 @@ Notable features:
 """
 import math
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info
+from nanochat.muon import DistMuon, Muon
 @dataclass
 class GPTConfig:
@@ -48,6 +49,7 @@ def apply_rotary_emb(x, cos, sin):
 out = out.to(x.dtype) # ensure input/output dtypes match
 return out
 class CausalSelfAttention(nn.Module):
 def __init__(self, config, layer_idx):
 super().__init__()
@@ -75,7 +77,11 @@ class CausalSelfAttention(nn.Module):
 cos, sin = cos_sin
 q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
 q, k = norm(q), norm(k) # QK norm
-q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+q, k, v = (
+q.transpose(1, 2),
+k.transpose(1, 2),
+v.transpose(1, 2),
+) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
 # Apply KV cache: insert current k,v into cache, get the full view so far
 if kv_cache is not None:
@@ -84,7 +90,9 @@ class CausalSelfAttention(nn.Module):
 Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
 # Attention: queries attend to keys/values autoregressively. A few cases to handle:
-enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
+enable_gqa = (
+self.n_head != self.n_kv_head
+) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
 if kv_cache is None or Tq == Tk:
 # During training (no KV cache), attend as usual with causal attention
 # And even if there is KV cache, we can still use this simple version when Tq == Tk
@@ -139,10 +147,12 @@ class GPT(nn.Module):
 def __init__(self, config):
 super().__init__()
 self.config = config
-self.transformer = nn.ModuleDict({
+self.transformer = nn.ModuleDict(
+{
 "wte": nn.Embedding(config.vocab_size, config.n_embd),
 "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
-})
+}
+)
 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 # To support meta device initialization, we init the rotary embeddings here, but it's fake
 # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
@@ -206,7 +216,12 @@ class GPT(nn.Module):
 """Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
 nparams = sum(p.numel() for p in self.parameters())
 nparams_embedding = self.transformer.wte.weight.numel()
-l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
+l, h, q, t = (
+self.config.n_layer,
+self.config.n_head,
+self.config.n_embd // self.config.n_head,
+self.config.sequence_len,
+)
 num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
 return num_flops_per_token
@@ -245,8 +260,12 @@ class GPT(nn.Module):
 B, T = idx.size()
 # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
-assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
-assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
+assert T <= self.cos.size(1), (
+f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
+)
+assert idx.device == self.cos.device, (
+f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
+)
 assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
 # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
 T0 = 0 if kv_cache is None else kv_cache.get_pos()
@@ -267,7 +286,9 @@ class GPT(nn.Module):
 logits = self.lm_head(x)
 logits = softcap * torch.tanh(logits / softcap) # logits softcap
 logits = logits.float() # use tf32/fp32 for logits
-loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
+loss = F.cross_entropy(
+logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction
+)
 return loss
 else:
 # inference mode: compute and return the logits
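
Worked example, not part of this commit: the FLOPs-per-token estimate above spelled out with made-up numbers. The config and parameter counts below are placeholders, not a shipped nanochat model.

n_layer, n_head, n_embd, sequence_len = 12, 6, 384, 1024  # placeholder config
head_dim = n_embd // n_head
nparams = 30_000_000            # placeholder total parameter count
nparams_embedding = 19_000_000  # placeholder embedding parameter count
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * n_layer * n_head * head_dim * sequence_len
print(f"{num_flops_per_token:,} FLOPs per token")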


@@ -1,10 +1,13 @@
 """
 A number of functions that help with evaluating a base model.
 """
 import math
 import torch
 import torch.distributed as dist
 @torch.no_grad()
 def evaluate_bpb(model, batches, steps, token_bytes):
 """
@@ -39,11 +42,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
 valid = y >= 0
 y_safe = torch.where(valid, y, torch.zeros_like(y))
 # map valid targets to their byte length; ignored targets contribute 0 bytes
-num_bytes2d = torch.where(
-valid,
-token_bytes[y_safe],
-torch.zeros_like(y, dtype=token_bytes.dtype)
-)
+num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype))
 total_nats += (loss2d * (num_bytes2d > 0)).sum()
 total_bytes += num_bytes2d.sum()
 else:
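
Not from this commit: once total_nats and total_bytes are accumulated as above, bits-per-byte is the standard nats-to-bits conversion normalized by byte count. The totals below are placeholders.

import math

total_nats = 1_250_000.0  # placeholder: summed cross-entropy over valid tokens, in nats
total_bytes = 600_000.0   # placeholder: summed byte lengths of those tokens
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")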


@ -2,9 +2,11 @@
Muon optimizer from Keller et al. Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt. Also a lot of borrowing of ideas from modded-nanogpt.
""" """
import torch import torch
from torch import Tensor
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor
@torch.compile @torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
@ -17,7 +19,9 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD. performance at all relative to UV^T, where USV^T = G is the SVD.
""" """
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng assert (
G.ndim >= 2
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315) a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16() X = G.bfloat16()
if G.size(-2) > G.size(-1): if G.size(-2) > G.size(-1):
@ -28,13 +32,16 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
# Perform the NS iterations # Perform the NS iterations
for _ in range(steps): for _ in range(steps):
A = X @ X.mT A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng B = (
b * A + c * A @ A
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X X = a * X + B @ X
if G.size(-2) > G.size(-1): if G.size(-2) > G.size(-1):
X = X.mT X = X.mT
return X return X
class Muon(torch.optim.Optimizer): class Muon(torch.optim.Optimizer):
""" """
Muon - MomentUm Orthogonalized by Newton-schulz Muon - MomentUm Orthogonalized by Newton-schulz
@ -57,6 +64,7 @@ class Muon(torch.optim.Optimizer):
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use. ns_steps: The number of Newton-Schulz iteration steps to use.
""" """
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params] params: list[Tensor] = [*params]
@ -104,8 +112,8 @@ class DistMuon(torch.optim.Optimizer):
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization ns_steps: number of NewtonSchulz iterations for the orthogonalization
""" """
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params) params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
@ -129,7 +137,9 @@ class DistMuon(torch.optim.Optimizer):
world_size = dist.get_world_size() world_size = dist.get_world_size()
# Ensure all grads exist # Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), (
    "All params must have grads"
)
# Kick off all the reduce scatter operations to average up the gradients across all ranks # Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = [] all_reduce_futures = []
@ -174,7 +184,7 @@ class DistMuon(torch.optim.Optimizer):
buf.lerp_(g, 1.0 - group["momentum"]) buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = max(1.0, p.size(-2) / p.size(-1)) ** 0.5
p.add_(g, alpha=-group["lr"] * scale) p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks # Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
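Aside (hypothetical, single-process): the per-matrix update that the distributed loop above applies to each owned parameter, written out on a toy tensor. `p`, `grad` and the hyperparameters are stand-ins, and `zeropower_via_newtonschulz5` is assumed from earlier in this file.

import torch

lr, momentum, nesterov, ns_steps = 0.02, 0.95, True, 5
p = torch.randn(384, 256)
grad = torch.randn_like(p)
buf = torch.zeros_like(p)

buf.lerp_(grad, 1.0 - momentum)                     # update the momentum buffer
g = grad.lerp_(buf, momentum) if nesterov else buf  # Nesterov-style blend
g = zeropower_via_newtonschulz5(g, steps=ns_steps)  # orthogonalize the update
scale = max(1.0, p.size(-2) / p.size(-1)) ** 0.5    # compensate for tall matrices
p.add_(g, alpha=-lr * scale)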

View File

@ -2,16 +2,18 @@
Utilities for generating training report cards. More messy code than usual, will fix. Utilities for generating training report cards. More messy code than usual, will fix.
""" """
import datetime
import os
import platform
import re
import shutil
import socket
import subprocess

import psutil
import torch
def run_command(cmd): def run_command(cmd):
"""Run a shell command and return output, or None if it fails.""" """Run a shell command and return output, or None if it fails."""
try: try:
@ -22,6 +24,7 @@ def run_command(cmd):
except: except:
return None return None
def get_git_info(): def get_git_info():
"""Get current git commit, branch, and dirty status.""" """Get current git commit, branch, and dirty status."""
info = {} info = {}
@ -38,18 +41,14 @@ def get_git_info():
return info return info
def get_gpu_info(): def get_gpu_info():
"""Get GPU information.""" """Get GPU information."""
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return {"available": False} return {"available": False}
num_devices = torch.cuda.device_count() num_devices = torch.cuda.device_count()
info = {"available": True, "count": num_devices, "names": [], "memory_gb": []}
for i in range(num_devices): for i in range(num_devices):
props = torch.cuda.get_device_properties(i) props = torch.cuda.get_device_properties(i)
@ -61,6 +60,7 @@ def get_gpu_info():
return info return info
def get_system_info(): def get_system_info():
"""Get system information.""" """Get system information."""
info = {} info = {}
@ -83,6 +83,7 @@ def get_system_info():
return info return info
def estimate_cost(gpu_info, runtime_hours=None): def estimate_cost(gpu_info, runtime_hours=None):
"""Estimate training cost based on GPU type and runtime.""" """Estimate training cost based on GPU type and runtime."""
@ -111,9 +112,10 @@ def estimate_cost(gpu_info, runtime_hours=None):
return { return {
"hourly_rate": hourly_rate, "hourly_rate": hourly_rate,
"gpu_type": gpu_name, "gpu_type": gpu_name,
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None "estimated_total": hourly_rate * runtime_hours if runtime_hours else None,
} }
def generate_header(): def generate_header():
"""Generate the header for a training report.""" """Generate the header for a training report."""
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -170,7 +172,7 @@ Generated: {timestamp}
# count dependencies via uv.lock # count dependencies via uv.lock
uv_lock_lines = 0 uv_lock_lines = 0
if os.path.exists('uv.lock'): if os.path.exists('uv.lock'):
with open('uv.lock', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines()) uv_lock_lines = len(f.readlines())
header += f""" header += f"""
@ -184,12 +186,15 @@ Generated: {timestamp}
""" """
return header return header
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def slugify(text): def slugify(text):
"""Slugify a text string.""" """Slugify a text string."""
return text.lower().replace(" ", "-") return text.lower().replace(" ", "-")
# the expected files and their order # the expected files and their order
EXPECTED_FILES = [ EXPECTED_FILES = [
"tokenizer-training.md", "tokenizer-training.md",
@ -207,6 +212,7 @@ EXPECTED_FILES = [
# the metrics we're currently interested in # the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"] chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
def extract(section, keys): def extract(section, keys):
"""simple def to extract a single key from a section""" """simple def to extract a single key from a section"""
if not isinstance(keys, list): if not isinstance(keys, list):
@ -218,6 +224,7 @@ def extract(section, keys):
out[key] = line.split(":")[1].strip() out[key] = line.split(":")[1].strip()
return out return out
def extract_timestamp(content, prefix): def extract_timestamp(content, prefix):
"""Extract timestamp from content with given prefix.""" """Extract timestamp from content with given prefix."""
for line in content.split('\n'): for line in content.split('\n'):
@ -229,6 +236,7 @@ def extract_timestamp(content, prefix):
pass pass
return None return None
class Report: class Report:
"""Maintains a bunch of logs, generates a final markdown report.""" """Maintains a bunch of logs, generates a final markdown report."""
@ -276,7 +284,7 @@ class Report:
# write the header first # write the header first
header_file = os.path.join(report_dir, "header.md") header_file = os.path.join(report_dir, "header.md")
if os.path.exists(header_file): if os.path.exists(header_file):
with open(header_file, encoding="utf-8") as f:
header_content = f.read() header_content = f.read()
out_file.write(header_content) out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:") start_time = extract_timestamp(header_content, "Run started:")
@ -293,7 +301,7 @@ class Report:
if not os.path.exists(section_file): if not os.path.exists(section_file):
print(f"Warning: {section_file} does not exist, skipping") print(f"Warning: {section_file} does not exist, skipping")
continue continue
with open(section_file, encoding="utf-8") as in_file:
section = in_file.read() section = in_file.read()
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time) # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
if "rl" not in file_name: if "rl" not in file_name:
@ -354,7 +362,7 @@ class Report:
else: else:
out_file.write("Total wall clock time: unknown\n") out_file.write("Total wall clock time: unknown\n")
# also cp the report.md file to current directory # also cp the report.md file to current directory
print(f"Copying report.md to current directory for convenience") print("Copying report.md to current directory for convenience")
shutil.copy(report_file, "report.md") shutil.copy(report_file, "report.md")
return report_file return report_file
@ -378,18 +386,23 @@ class Report:
f.write(f"Run started: {start_time}\n\n---\n\n") f.write(f"Run started: {start_time}\n\n---\n\n")
print(f"Reset report and wrote header to {header_file}") print(f"Reset report and wrote header to {header_file}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# nanochat-specific convenience functions # nanochat-specific convenience functions
class DummyReport: class DummyReport:
def log(self, *args, **kwargs): def log(self, *args, **kwargs):
pass pass
def reset(self, *args, **kwargs): def reset(self, *args, **kwargs):
pass pass
def get_report(): def get_report():
# just for convenience, only rank 0 logs to report # just for convenience, only rank 0 logs to report
from nanochat.common import get_base_dir, get_dist_info from nanochat.common import get_base_dir, get_dist_info
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp_rank == 0: if ddp_rank == 0:
report_dir = os.path.join(get_base_dir(), "report") report_dir = os.path.join(get_base_dir(), "report")
@ -397,10 +410,18 @@ def get_report():
else: else:
return DummyReport() return DummyReport()
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.") parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)") parser.add_argument(
"command",
nargs="?",
default="generate",
choices=["generate", "reset"],
help="Operation to perform (default: generate)",
)
args = parser.parse_args() args = parser.parse_args()
if args.command == "generate": if args.command == "generate":
get_report().generate() get_report().generate()
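Aside (illustrative only): the `extract` helper above pulls `key: value` lines out of a report section. A toy re-implementation with made-up section text shows the intended behavior; `extract_keys` and the values below are hypothetical, not the repo's.

section = """\
Model: toy-model
CORE metric: 0.1234
Total training time: 42.00m
"""

def extract_keys(section, keys):
    out = {}
    for line in section.splitlines():
        for key in keys:
            if line.startswith(key + ":"):
                out[key] = line.split(":", 1)[1].strip()
    return out

print(extract_keys(section, ["CORE metric", "Total training time"]))
# {'CORE metric': '0.1234', 'Total training time': '42.00m'}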

View File

@ -6,8 +6,8 @@ Two implementations are available:
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
""" """
import copy
import os
from functools import lru_cache from functools import lru_cache
SPECIAL_TOKENS = [ SPECIAL_TOKENS = [
@ -27,15 +27,18 @@ SPECIAL_TOKENS = [
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3} # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO. # I haven't validated that this is actually a good idea, TODO.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
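Aside (not part of the diff): the split pattern uses `\p{...}` Unicode classes and possessive quantifiers, so for a quick Python-side experiment the third-party `regex` module is handy (the standard `re` module lacks `\p{...}`). A small sketch, assuming `SPLIT_PATTERN` as defined above:

import regex  # third-party `regex` package

pat = regex.compile(SPLIT_PATTERN)
print(pat.findall("Hello world, it's 12345 tokens!"))
# expected (roughly): ['Hello', ' world', ',', ' it', "'s", ' ', '12', '34', '5', ' tokens', '!']
# note how "12345" is chunked into at most 2 digits because of \p{N}{1,2}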
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.models import BPE from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer: class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities""" """Light wrapper around HuggingFace Tokenizer for some utilities"""
@ -59,11 +62,13 @@ class HuggingFaceTokenizer:
def train_from_iterator(cls, text_iterator, vocab_size): def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text # train from an iterator of text
# Configure the HuggingFace Tokenizer # Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(
    BPE(
        byte_fallback=True,  # needed!
        unk_token=None,
        fuse_unk=False,
    )
)
# Normalizer: None # Normalizer: None
tokenizer.normalizer = None tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style # Pre-tokenizer: GPT-4 style
@ -72,10 +77,12 @@ class HuggingFaceTokenizer:
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space. # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
# (but I haven't validated this! TODO) # (but I haven't validated this! TODO)
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
    ]
)
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel() tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None # Post-processor: None
@ -146,12 +153,16 @@ class HuggingFaceTokenizer:
self.tokenizer.save(tokenizer_path) self.tokenizer.save(tokenizer_path)
print(f"Saved tokenizer to {tokenizer_path}") print(f"Saved tokenizer to {tokenizer_path}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo # Tokenizer based on rustbpe + tiktoken combo
import pickle import pickle
import rustbpe
import tiktoken
class RustBPETokenizer: class RustBPETokenizer:
"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe""" """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
@ -264,6 +275,7 @@ class RustBPETokenizer:
""" """
# ids, masks that we will return and a helper function to help build them up. # ids, masks that we will return and a helper function to help build them up.
ids, mask = [], [] ids, mask = [], []
def add_tokens(token_ids, mask_val): def add_tokens(token_ids, mask_val):
if isinstance(token_ids, int): if isinstance(token_ids, int):
token_ids = [token_ids] token_ids = [token_ids]
@ -286,17 +298,21 @@ class RustBPETokenizer:
# fetch all the special tokens we need # fetch all the special tokens we need
bos = self.get_bos_token_id() bos = self.get_bos_token_id()
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>") user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
assistant_start, assistant_end = (
    self.encode_special("<|assistant_start|>"),
    self.encode_special("<|assistant_end|>"),
)
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>") python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>") output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
# now we can tokenize the conversation # now we can tokenize the conversation
add_tokens(bos, 0) add_tokens(bos, 0)
for i, message in enumerate(messages): for i, message in enumerate(messages):
# some sanity checking here around assumptions, to prevent footguns # some sanity checking here around assumptions, to prevent footguns
must_be_from = "user" if i % 2 == 0 else "assistant" must_be_from = "user" if i % 2 == 0 else "assistant"
assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}" assert message["role"] == must_be_from, (
f"Message {i} is from {message['role']} but should be from {must_be_from}"
)
# content can be either a simple string or a list of parts (e.g. containing tool calls) # content can be either a simple string or a list of parts (e.g. containing tool calls)
content = message["content"] content = message["content"]
@ -376,23 +392,31 @@ class RustBPETokenizer:
ids.append(assistant_start) ids.append(assistant_start)
return ids return ids
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# nanochat-specific convenience functions # nanochat-specific convenience functions
def get_tokenizer(): def get_tokenizer():
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
base_dir = get_base_dir() base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer") tokenizer_dir = os.path.join(base_dir, "tokenizer")
# return HuggingFaceTokenizer.from_directory(tokenizer_dir) # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
return RustBPETokenizer.from_directory(tokenizer_dir) return RustBPETokenizer.from_directory(tokenizer_dir)
def get_token_bytes(device="cpu"): def get_token_bytes(device="cpu"):
import torch import torch
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
base_dir = get_base_dir() base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer") tokenizer_dir = os.path.join(base_dir, "tokenizer")
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
assert os.path.exists(token_bytes_path), (
    f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
)
with open(token_bytes_path, "rb") as f: with open(token_bytes_path, "rb") as f:
token_bytes = torch.load(f, map_location=device) token_bytes = torch.load(f, map_location=device)
return token_bytes return token_bytes
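Aside (hypothetical sketch, not the repository's implementation): the chat-rendering hunks above build `ids` and `mask` so that only assistant tokens are supervised. A toy version with fake special-token ids and a stand-in `encode` shows the shape of that construction; which structural tokens are masked in is an assumption here.

def render_toy_conversation(messages, encode):
    BOS, USER_START, USER_END, ASSISTANT_START, ASSISTANT_END = range(5)  # toy special ids
    ids, mask = [], []

    def add(tokens, mask_val):
        ids.extend(tokens)
        mask.extend([mask_val] * len(tokens))

    add([BOS], 0)
    for i, message in enumerate(messages):
        must_be_from = "user" if i % 2 == 0 else "assistant"
        assert message["role"] == must_be_from
        if message["role"] == "user":
            add([USER_START], 0)
            add(encode(message["content"]), 0)
            add([USER_END], 0)
        else:
            add([ASSISTANT_START], 0)
            add(encode(message["content"]), 1)  # supervised
            add([ASSISTANT_END], 1)             # supervised (assumption)
    return ids, mask

toy_encode = lambda s: [1000 + ord(c) for c in s]  # stand-in for tokenizer.encode
ids, mask = render_toy_conversation(
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "yo"}], toy_encode
)
print(ids)
print(mask)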

View File

@ -9,23 +9,31 @@ torchrun --nproc_per_node=8 -m scripts.base_eval
The script will print the CORE metric to the console. The script will print the CORE metric to the console.
""" """
import csv
import json
import os
import random
import shutil
import tempfile
import time
import zipfile
from contextlib import nullcontext

import torch
import yaml

from nanochat.checkpoint_manager import load_model
from nanochat.common import (
    autodetect_device_type,
    compute_cleanup,
    compute_init,
    download_file_with_lock,
    get_base_dir,
    print0,
)
from nanochat.core_eval import evaluate_task
from nanochat.tokenizer import HuggingFaceTokenizer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# nanochat specific function dealing with I/O etc. # nanochat specific function dealing with I/O etc.
@ -33,6 +41,7 @@ from nanochat.core_eval import evaluate_task
# ~162MB of data needed to evaluate the CORE metric # ~162MB of data needed to evaluate the CORE metric
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
def place_eval_bundle(file_path): def place_eval_bundle(file_path):
# here file_path is the path to the eval_bundle.zip file # here file_path is the path to the eval_bundle.zip file
# we need to unzip it and place it in the base directory # we need to unzip it and place it in the base directory
@ -45,6 +54,7 @@ def place_eval_bundle(file_path):
shutil.move(extracted_bundle_dir, eval_bundle_dir) shutil.move(extracted_bundle_dir, eval_bundle_dir)
print0(f"Placed eval_bundle directory at {eval_bundle_dir}") print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
def evaluate_model(model, tokenizer, device, max_per_task=-1): def evaluate_model(model, tokenizer, device, max_per_task=-1):
""" """
Evaluate a base model on the CORE benchmark. Evaluate a base model on the CORE benchmark.
@ -59,13 +69,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
config_path = os.path.join(eval_bundle_dir, "core.yaml") config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data") data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, encoding='utf-8') as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
tasks = config['icl_tasks'] tasks = config['icl_tasks']
# Load random baseline values from eval metadata # Load random baseline values from eval metadata
random_baselines = {} random_baselines = {}
with open(eval_meta_data, encoding='utf-8') as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
for row in reader: for row in reader:
task_name = row['Eval Task'] task_name = row['Eval Task']
@ -82,13 +92,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
'task_type': task['icl_task_type'], 'task_type': task['icl_task_type'],
'dataset_uri': task['dataset_uri'], 'dataset_uri': task['dataset_uri'],
'num_fewshot': task['num_fewshot'][0], 'num_fewshot': task['num_fewshot'][0],
'continuation_delimiter': task.get('continuation_delimiter', ' '),
} }
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='') print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
# Load data for this task # Load data for this task
data_path = os.path.join(data_base_path, task_meta['dataset_uri']) data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f] data = [json.loads(line.strip()) for line in f]
# shuffle the data because in many cases it appears ordered but we want # shuffle the data because in many cases it appears ordered but we want
@ -109,18 +119,17 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s") print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
core_metric = sum(centered_results.values()) / len(centered_results) core_metric = sum(centered_results.values()) / len(centered_results)
out = {"results": results, "centered_results": centered_results, "core_metric": core_metric}
return out return out
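Aside (a sketch of the metric, assuming the standard definition): each task's accuracy is centered against its random baseline before being averaged into the CORE metric above, so 0 corresponds to guessing and 1 to perfect accuracy.

def centered(accuracy: float, random_baseline: float) -> float:
    return (accuracy - random_baseline) / (1.0 - random_baseline)

# e.g. 4-way multiple choice (random baseline 0.25), model accuracy 0.55:
print(centered(0.55, 0.25))  # 0.4
# the CORE metric is then the mean of these centered scores across tasks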
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HuggingFace loading utilities and light wrappers for a model # HuggingFace loading utilities and light wrappers for a model
class ModelWrapper: class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model""" """Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None): def __init__(self, model, max_seq_len=None):
self.model = model self.model = model
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
@ -130,10 +139,12 @@ class ModelWrapper:
logits = outputs.logits logits = outputs.logits
return logits return logits
def load_hf_model(hf_path: str, device): def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}") print0(f"Loading model from: {hf_path}")
# Load the model # Load the model
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path) model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device) model.to(device)
model.eval() model.eval()
@ -143,9 +154,11 @@ def load_hf_model(hf_path: str, device):
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path) tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer return model, tokenizer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
@ -154,7 +167,9 @@ def main():
# distributed / precision setup # distributed / precision setup
device_type = autodetect_device_type() device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Load model and tokenizer from command line or from file system # Load model and tokenizer from command line or from file system
if args.hf_path is not None: if args.hf_path is not None:
@ -193,20 +208,25 @@ def main():
print0("=" * 80) print0("=" * 80)
print0(f"Model: {model_name}") print0(f"Model: {model_name}")
print0("=" * 80) print0("=" * 80)
with open(output_csv_path, encoding='utf-8') as f:
print0(f.read()) print0(f.read())
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Base model evaluation", data=[
get_report().log(
section="Base model evaluation",
data=[
{ {
"Model": model_name, "Model": model_name,
"CORE metric": core_metric, "CORE metric": core_metric,
}, },
centered_results, # the full table centered_results, # the full table
]) ],
)
compute_cleanup() compute_cleanup()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -6,15 +6,18 @@ Loads a checkpoint, and:
Example run as: Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
""" """
import os import os
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.engine import Engine
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes
# Configuration # Configuration
device_batch_size = 32 device_batch_size = 32
@ -29,7 +32,9 @@ device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Evaluate the loss on each split # Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -67,13 +72,17 @@ if ddp_rank == 0:
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
get_report().log(
section="Base model loss",
data=[
{ {
"train bpb": bpb_results["train"], "train bpb": bpb_results["train"],
"val bpb": bpb_results["val"], "val bpb": bpb_results["val"],
}, },
{f"sample {i}": sample for i, sample in enumerate(samples)}, {f"sample {i}": sample for i, sample in enumerate(samples)},
]) ],
)
# Cleanup # Cleanup
compute_cleanup() compute_cleanup()

View File

@ -12,21 +12,31 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
""" """
import os import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time import time
from contextlib import nullcontext from contextlib import nullcontext
import torch
import wandb

from nanochat.checkpoint_manager import load_checkpoint, save_checkpoint
from nanochat.common import (
    DummyWandb,
    autodetect_device_type,
    compute_cleanup,
    compute_init,
    get_base_dir,
    print0,
    print_banner,
)
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.engine import Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes, get_tokenizer
from scripts.base_eval import evaluate_model
print_banner() print_banner()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -39,8 +49,12 @@ depth = 20 # the depth of the Transformer model to train, rest of the kwargs are
max_seq_len = 2048 # max context length max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence. # Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = (
    -1.0
)  # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = (
    20  # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
)
# Optimization # Optimization
device_batch_size = 32 # per-device batch size (set to not OOM) device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens total_batch_size = 524288 # total desired batch size, in #tokens
@ -72,7 +86,9 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
device_type = autodetect_device_type() if device_type == "" else device_type device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
@ -110,7 +126,14 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
# Initialize the Model # Initialize the Model
# Create a new model with random weights # Create a new model with random weights
model_config_kwargs = dict(
    sequence_len=max_seq_len,
    vocab_size=vocab_size,
    n_layer=num_layers,
    n_head=num_heads,
    n_kv_head=num_kv_heads,
    n_embd=model_dim,
)
with torch.device("meta"): with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs) model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config) model = GPT(model_config)
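Aside (minimal sketch of the pattern, not the repo's code): constructing the module under `torch.device("meta")` allocates no real storage; the weights are materialized later (for example via `to_empty`) before loading a checkpoint or initializing them.

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4096, 4096, bias=False)  # parameters exist but own no real storage
print(layer.weight.device)                      # meta
layer = layer.to_empty(device="cpu")            # allocate (uninitialized) storage on a real device
print(layer.weight.device, tuple(layer.weight.shape))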
@ -124,7 +147,9 @@ checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1 resuming = resume_from_step != -1
if resuming: if resuming:
print0(f"Resuming optimization from step {resume_from_step}") print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(
    checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank
)
model.load_state_dict(model_data, strict=True, assign=True) model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy del model_data # free up this memory after the copy
@ -157,7 +182,9 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers adamw_optimizer, muon_optimizer = optimizers
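Aside (hypothetical sketch): the optimizer setup above pairs Muon (for 2D weight matrices) with AdamW (for embedding and unembedding parameters). The toy split below uses made-up module names and illustrative learning rates, and reuses the `Muon` class defined earlier in this diff.

import torch
import torch.nn as nn

wte = nn.Embedding(1000, 64)               # hypothetical embedding
block = nn.Linear(64, 64, bias=False)      # hypothetical transformer matrix
lm_head = nn.Linear(64, 1000, bias=False)  # hypothetical unembedding

adamw = torch.optim.AdamW(
    [
        {"params": [wte.weight], "lr": 0.2},        # embedding_lr (illustrative value)
        {"params": [lm_head.weight], "lr": 0.004},  # unembedding_lr (illustrative value)
    ],
    weight_decay=0.0,
)
muon = Muon([block.weight], lr=0.02, momentum=0.95)  # Muon class from earlier in this diff
optimizers = [adamw, muon]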
if resuming: if resuming:
@ -169,13 +196,18 @@ if resuming:
# Initialize the DataLoaders for train/val # Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data") tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(
    device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict
)
build_val_loader = lambda: tokenizing_distributed_data_loader(
    device_batch_size, max_seq_len, split="val", device=device
)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Set up hyperparameter schedulers # Set up hyperparameter schedulers
# Learning rate scheduler # Learning rate scheduler
def get_lr_multiplier(it): def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations) warmup_iters = round(warmup_ratio * num_iterations)
@ -188,12 +220,14 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac return progress * 1.0 + (1 - progress) * final_lr_frac
# Momentum scheduler for Muon optimizer # Momentum scheduler for Muon optimizer
def get_muon_momentum(it): def get_muon_momentum(it):
frac = min(it / 300, 1) frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95 momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum return momentum
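Aside (standalone sketch with made-up hyperparameters): the two schedulers above combine a trapezoidal LR multiplier (warmup, flat middle, linear warmdown to `final_lr_frac`) with a Muon momentum ramp from 0.85 to 0.95 over the first 300 steps. The warmup branch is not shown in this hunk, so the one below is an assumption about its shape.

num_iterations, warmup_ratio, warmdown_ratio, final_lr_frac = 1000, 0.0, 0.2, 0.0

def lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters  # assumed linear warmup
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

def muon_momentum(it):
    frac = min(it / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

for it in (0, 500, 900, 1000):
    print(it, round(lr_multiplier(it), 3), round(muon_momentum(it), 3))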
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop) # Loop state (variables updated by the training loop)
@ -225,12 +259,14 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb: if val_bpb < min_val_bpb:
min_val_bpb = val_bpb min_val_bpb = val_bpb
wandb_run.log(
    {
        "step": step,
        "total_training_flops": flops_so_far,
        "total_training_time": total_training_time,
        "val/bpb": val_bpb,
    }
)
model.train() model.train()
# once in a while: estimate the CORE metric (all ranks participate) # once in a while: estimate the CORE metric (all ranks participate)
@ -241,12 +277,14 @@ while True:
with autocast_ctx: with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task) results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log(
    {
        "step": step,
        "total_training_flops": flops_so_far,
        "core_metric": results["core_metric"],
        "centered_results": results["centered_results"],
    }
)
model.train() model.train()
# once in a while: sample from the model (only on master process) # once in a while: sample from the model (only on master process)
@ -309,7 +347,9 @@ while True:
train_loss = loss.detach() # for logging train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() loss.backward()
x, y, dataloader_state_dict = next(
    train_loader
)  # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping # gradient clipping
grad_clip_enabled = grad_clip > 0.0 grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled: if grad_clip_enabled:
@ -343,7 +383,9 @@ while True:
if step > 10: if step > 10:
total_training_time += dt # only count the time after the first 10 steps total_training_time += dt # only count the time after the first 10 steps
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else "" print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") print0(
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 100 == 0: if step % 100 == 0:
log_data = { log_data = {
"step": step, "step": step,
@ -369,7 +411,10 @@ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Base model training", data=[
get_report().log(
section="Base model training",
data=[
user_config, # CLI args user_config, # CLI args
{ # stats about the training setup { # stats about the training setup
"Number of parameters": num_params, "Number of parameters": num_params,
@ -390,8 +435,9 @@ get_report().log(section="Base model training", data=[
"Total training flops": f"{flops_so_far:e}", "Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time / 60:.2f}m", "Total training time": f"{total_training_time / 60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
},
],
)
# cleanup # cleanup
wandb_run.finish() # wandb run finish wandb_run.finish() # wandb run finish

View File

@ -4,12 +4,15 @@ New and upgraded chat mode because a lot of the code has changed since the last
Intended to be run single GPU only atm: Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid python -m scripts.chat_cli -i mid
""" """
import argparse
from contextlib import nullcontext

import torch

from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description='Chat with the model') parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
@ -18,7 +21,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument(
    '--device-type',
    type=str,
    default='',
    choices=['cuda', 'cpu', 'mps'],
    help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args() args = parser.parse_args()
@ -33,7 +42,10 @@ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag
# Special tokens for the chat state machine # Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id() bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>") user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = (
    tokenizer.encode_special("<|assistant_start|>"),
    tokenizer.encode_special("<|assistant_end|>"),
)
# Create Engine for efficient generation # Create Engine for efficient generation
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
@ -47,7 +59,6 @@ print("-" * 50)
conversation_tokens = [bos] conversation_tokens = [bos]
while True: while True:
if args.prompt: if args.prompt:
# Get the prompt from the launch command # Get the prompt from the launch command
user_input = args.prompt user_input = args.prompt

View File

@ -9,27 +9,28 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
""" """
import argparse
from contextlib import nullcontext
from functools import partial

import torch
import torch.distributed as dist

from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0
from nanochat.engine import Engine
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.spellingbee import SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate) # Generative evaluation loop (we go one problem at a time, sample, evaluate)
def run_generative_eval(
    task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None
):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device() device = model.get_device()
@ -82,13 +83,14 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_
# Return the accuracy # Return the accuracy
return num_passed / total return num_passed / total
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Categorical evaluation loop # Categorical evaluation loop
# A lot easier because we don't have to sample. Therefore, we can actually go # A lot easier because we don't have to sample. Therefore, we can actually go
# batches at a time and just check the logits for correct answer choices. # batches at a time and just check the logits for correct answer choices.
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device() device = model.get_device()
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
@ -106,9 +108,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
# Prepare the batch of problems. They might all be of different length, so we pad/collate them. # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
conversations = [task_object[ii] for ii in range(i0, i1)] conversations = [task_object[ii] for ii in range(i0, i1)]
prompt_ids = [
    tokenizer.render_for_completion(conversation) for conversation in conversations
]  # TODO: remake the way this works
max_length = max(len(ids) for ids in prompt_ids)
answer_time_positions = [
    len(ids) - 1 for ids in prompt_ids
]  # where the last token is (and the predicted answer)
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids] padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device) prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
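Aside (toy illustration): the padding/collation above right-pads prompts with the BOS token and remembers each prompt's last real position so the answer can be read from the logits at that position. The tensors below are stand-ins for the model's output.

import torch

bos = 0
prompt_ids = [[5, 6, 7], [8, 9], [3, 4, 5, 6]]
max_length = max(len(ids) for ids in prompt_ids)
answer_time_positions = [len(ids) - 1 for ids in prompt_ids]
padded = torch.tensor([ids + [bos] * (max_length - len(ids)) for ids in prompt_ids])

vocab_size = 16
logits = torch.randn(len(prompt_ids), max_length, vocab_size)  # stand-in for model output
answer_logits = logits[torch.arange(len(prompt_ids)), answer_time_positions]  # (B, vocab)
predictions = answer_logits.argmax(dim=-1)
print(padded.shape, predictions.tolist())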
@ -154,11 +160,22 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)") print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)")
return average return average
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def run_chat_eval(
    task_name,
    model,
    tokenizer,
    engine,
    batch_size=1,
    num_samples=1,
    max_new_tokens=512,
    temperature=0.0,
    top_k=50,
    max_problems=None,
):
# Create the evaluation object # Create the evaluation object
task_module = { task_module = {
'HumanEval': HumanEval, 'HumanEval': HumanEval,
@ -171,20 +188,36 @@ def run_chat_eval(task_name, model, tokenizer, engine,
task_object = task_module() task_object = task_module()
# Run the evaluation # Run the evaluation
if task_object.eval_type == 'generative': if task_object.eval_type == 'generative':
acc = run_generative_eval(
    task_object,
    tokenizer,
    model,
    engine,
    num_samples,
    max_new_tokens,
    temperature,
    top_k,
    max_problems=max_problems,
)
elif task_object.eval_type == 'categorical': elif task_object.eval_type == 'categorical':
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems) acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
else: else:
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}") raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
return acc return acc
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == "__main__": if __name__ == "__main__":
# Parse command-line arguments # Parse command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl") parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
parser.add_argument(
    '-a',
    '--task-name',
    type=str,
    default=None,
    help="Task name. Default = all tasks. Use | to split multiple tasks.",
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0) parser.add_argument('-t', '--temperature', type=float, default=0.0)
parser.add_argument('-m', '--max-new-tokens', type=int, default=512) parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
@ -194,13 +227,21 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument(
    '--device-type',
    type=str,
    default='',
    choices=['cuda', 'cpu', 'mps'],
    help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
args = parser.parse_args() args = parser.parse_args()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
@ -223,7 +264,9 @@ if __name__ == "__main__":
with autocast_ctx: with autocast_ctx:
acc = run_chat_eval( acc = run_chat_eval(
task_name, task_name,
model, tokenizer, engine, model,
tokenizer,
engine,
batch_size=args.batch_size, batch_size=args.batch_size,
num_samples=args.num_samples, num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens, max_new_tokens=args.max_new_tokens,
@ -236,6 +279,7 @@ if __name__ == "__main__":
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks) all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
# calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy) # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
# this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance) # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
@ -248,10 +292,13 @@ if __name__ == "__main__":
centered_mean += centered_acc centered_mean += centered_acc
chatcore_metric = centered_mean / len(results) chatcore_metric = centered_mean / len(results)
chatcore_metric_dict = {"ChatCORE metric": chatcore_metric} chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
get_report().log(section="Chat evaluation " + args.source, data=[ get_report().log(
section="Chat evaluation " + args.source,
data=[
vars(args), # CLI args vars(args), # CLI args
results, results,
chatcore_metric_dict, chatcore_metric_dict,
]) ],
)
compute_cleanup() compute_cleanup()
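The ChatCORE aggregation above rescales each task's accuracy so that a random-guess baseline maps to 0 and perfect accuracy maps to 1, then averages across tasks. A minimal stand-alone sketch of that idea, with purely illustrative baseline values (how the script obtains each task's baseline is not shown in this hunk):

def chatcore(accuracies: dict, baselines: dict) -> float:
    """Mean of per-task accuracies rescaled so random guessing -> 0 and perfect -> 1."""
    centered = [(acc - baselines[t]) / (1.0 - baselines[t]) for t, acc in accuracies.items()]
    return sum(centered) / len(centered)

# illustrative numbers only (0.25 = random chance on 4-way multiple choice)
print(chatcore({"ARC-Easy": 0.52, "MMLU": 0.35}, {"ARC-Easy": 0.25, "MMLU": 0.25}))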
View File
@ -16,15 +16,15 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
""" """
import os
import itertools import itertools
import re import os
import wandb
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine from nanochat.engine import Engine
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
@ -75,12 +75,16 @@ val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs num_steps = (len(train_task) // examples_per_step) * num_epochs
print0(f"Calculated number of steps: {num_steps}") print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad() @torch.no_grad()
def get_batch(): def get_batch():
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss. assistant_end = tokenizer.encode_special(
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data "<|assistant_end|>"
) # ok to use this token, it's only for padding and isn't used in the loss.
rank_indices = range(
ddp_rank, len(train_task), ddp_world_size
) # each rank is responsible for different examples in the training data
for example_idx in itertools.cycle(rank_indices): for example_idx in itertools.cycle(rank_indices):
# First get the full conversation of both user and assistant messages # First get the full conversation of both user and assistant messages
conversation = train_task[example_idx] conversation = train_task[example_idx]
@ -121,7 +125,9 @@ def get_batch():
# Pad the sequences so that their lengths (in time) match # Pad the sequences so that their lengths (in time) match
max_length = max(len(seq) for seq in generated_token_sequences) max_length = max(len(seq) for seq in generated_token_sequences)
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences] padded_generated_token_sequences = [
seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences
]
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks] padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
# Stack up the sequences and masks into PyTorch tensors # Stack up the sequences and masks into PyTorch tensors
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device) ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
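The padding step above can be illustrated in isolation: ragged token lists are right-padded to a common length with a designated pad id, masks are extended with zeros, and both are stacked into (B, T) tensors. A small self-contained sketch (the pad id 0 here is arbitrary):

import torch

def pad_and_stack(seqs, masks, pad_id=0):
    # right-pad every sequence/mask to the longest length in the batch
    max_len = max(len(s) for s in seqs)
    padded_seqs = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    padded_masks = [m + [0] * (max_len - len(m)) for m in masks]
    ids = torch.tensor(padded_seqs, dtype=torch.long)    # (B, T)
    mask = torch.tensor(padded_masks, dtype=torch.long)  # (B, T), 0 where padded
    return ids, mask

ids, mask = pad_and_stack([[5, 6, 7], [8, 9]], [[1, 1, 1], [1, 1]])
print(ids.shape, mask.tolist())  # torch.Size([2, 3]) [[1, 1, 1], [1, 1, 0]]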
@ -139,14 +145,11 @@ def get_batch():
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
yield generated_token_sequences, inputs, targets, rewards, advantages yield generated_token_sequences, inputs, targets, rewards, advantages
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Simple evaluation loop for GSM8K pass@k # Simple evaluation loop for GSM8K pass@k
def run_gsm8k_eval(task, tokenizer, engine, def run_gsm8k_eval(
max_examples=None, task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50
num_samples=1,
max_completion_tokens=256,
temperature=0.0,
top_k=50
): ):
""" """
Evaluates GSM8K task and returns a list of records of evaluation outcomes. Evaluates GSM8K task and returns a list of records of evaluation outcomes.
@ -162,11 +165,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Generate k samples using batched generation inside the Engine # Generate k samples using batched generation inside the Engine
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch( generated_token_sequences, masks = engine.generate_batch(
tokens, tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k
num_samples=num_samples,
max_tokens=max_completion_tokens,
temperature=temperature,
top_k=top_k
) )
# Check each sample for correctness # Check each sample for correctness
outcomes = [] outcomes = []
@ -174,9 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
generated_tokens = sample_tokens[prefix_length:] generated_tokens = sample_tokens[prefix_length:]
generated_text = tokenizer.decode(generated_tokens) generated_text = tokenizer.decode(generated_tokens)
is_correct = task.evaluate(conversation, generated_text) is_correct = task.evaluate(conversation, generated_text)
outcomes.append({ outcomes.append({"is_correct": is_correct})
"is_correct": is_correct
})
# A bit bloated because I wanted to do more complex logging at one point. # A bit bloated because I wanted to do more complex logging at one point.
record = { record = {
"idx": idx, "idx": idx,
@ -184,6 +181,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
} }
yield record yield record
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Training loop # Training loop
@ -201,11 +199,13 @@ for opt in optimizers:
group["lr"] = group["lr"] * init_lr_frac group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps # Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it): def get_lr_multiplier(it):
lrm = 1.0 - it / num_steps lrm = 1.0 - it / num_steps
return lrm return lrm
# Calculate the number of examples each rank handles to achieve the desired examples_per_step # Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks" assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
@ -215,13 +215,14 @@ print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop # Kick off the training loop
batch_iterator = get_batch() batch_iterator = get_batch()
for step in range(num_steps): for step in range(num_steps):
# Evaluate the model once in a while and log to wandb # Evaluate the model once in a while and log to wandb
if step % eval_every == 0: if step % eval_every == 0:
model.eval() model.eval()
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx: with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0) records_iter = run_gsm8k_eval(
val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0
)
records = list(records_iter) # collect all records records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1): for k in range(1, device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
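The pass@k numbers computed above are the simple empirical estimate: a problem counts as solved at k if any of its first k samples is correct, averaged over problems (the plain counting version, not the unbiased combinatorial estimator). A standalone sketch:

def empirical_pass_at_k(records, k):
    """records: list of dicts with an 'outcomes' list of {'is_correct': bool} per sample."""
    solved = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
    return solved / len(records)

records = [
    {"outcomes": [{"is_correct": False}, {"is_correct": True}]},
    {"outcomes": [{"is_correct": False}, {"is_correct": False}]},
]
print(empirical_pass_at_k(records, 1), empirical_pass_at_k(records, 2))  # 0.0 0.5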
@ -233,10 +234,12 @@ for step in range(num_steps):
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)] print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}") print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)} log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
**log_passk, **log_passk,
}) }
)
# Forward/Backward on rollouts over multiple examples in the dataset # Forward/Backward on rollouts over multiple examples in the dataset
rewards_list = [] rewards_list = []
@ -268,7 +271,9 @@ for step in range(num_steps):
# Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize) # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
loss = -pg_obj loss = -pg_obj
loss.backward() loss.backward()
print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}") print0(
f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}"
)
# For logging # For logging
rewards_list.append(rewards_all.mean().item()) rewards_list.append(rewards_all.mean().item())
sequence_lengths.extend(len(seq) for seq in sequences_all) sequence_lengths.extend(len(seq) for seq in sequences_all)
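The loss = -pg_obj line turns a policy-gradient objective into something to minimize. As a generic illustration only, here is a REINFORCE-style objective with one scalar advantage per sampled sequence and a token mask over the assistant positions; the exact formulation used by this script may differ:

import torch
import torch.nn.functional as F

def policy_gradient_loss(logits, targets, mask, advantages):
    # logits: (B, T, V), targets: (B, T), mask: (B, T) with 1 on trained tokens,
    # advantages: (B,) one scalar advantage per sampled sequence
    logp = -F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none"
    ).view_as(targets)                                  # per-token log-probs of the taken tokens
    weighted = logp * mask * advantages[:, None]        # scale each sequence's tokens by its advantage
    pg_obj = weighted.sum() / mask.sum().clamp(min=1)   # objective we would like to maximize
    return -pg_obj                                      # minimize the negative objective

B, T, V = 2, 5, 11
logits = torch.randn(B, T, V, requires_grad=True)
loss = policy_gradient_loss(logits, torch.randint(0, V, (B, T)),
                            torch.ones(B, T), torch.tensor([0.5, -0.5]))
loss.backward()
print(loss.item(), logits.grad.shape)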
@ -283,12 +288,16 @@ for step in range(num_steps):
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG) dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
mean_reward = mean_reward_tensor.item() mean_reward = mean_reward_tensor.item()
mean_sequence_length = mean_sequence_length_tensor.item() mean_sequence_length = mean_sequence_length_tensor.item()
print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}") print0(
wandb_run.log({ f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}"
)
wandb_run.log(
{
"step": step, "step": step,
"reward": mean_reward, "reward": mean_reward,
"sequence_length": mean_sequence_length, "sequence_length": mean_sequence_length,
}) }
)
# Update the model parameters # Update the model parameters
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
@ -298,10 +307,12 @@ for step in range(num_steps):
for opt in optimizers: # then step the optimizers for opt in optimizers: # then step the optimizers
opt.step() opt.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
"lrm": lrm, "lrm": lrm,
}) }
)
# Master process saves the model once in a while. Skip first step. Save last step. # Master process saves the model once in a while. Skip first step. Save last step.
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1): if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
@ -317,15 +328,19 @@ for step in range(num_steps):
None, # note: we don't bother to save the optimizer state None, # note: we don't bother to save the optimizer state
{ {
"model_config": model_config_kwargs, "model_config": model_config_kwargs,
} },
) )
print(f"✅ Saved model checkpoint to {checkpoint_dir}") print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Chat RL", data=[
get_report().log(
section="Chat RL",
data=[
user_config, # CLI args user_config, # CLI args
]) ],
)
wandb_run.finish() # wandb run finish wandb_run.finish() # wandb run finish
compute_cleanup() compute_cleanup()
View File
@ -10,24 +10,24 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
""" """
import os import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type import torch
from nanochat.checkpoint_manager import load_model import torch.distributed as dist
from nanochat.checkpoint_manager import save_checkpoint import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.arc import ARC from tasks.arc import ARC
from tasks.common import TaskMixture
from tasks.customjson import CustomJSON
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -70,7 +70,11 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
# wandb logging init # wandb logging init
use_dummy_wandb = run == "dummy" or not master_process use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
)
# Load the model and tokenizer # Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
@ -81,7 +85,8 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Task data mixture we'll train on # Task data mixture we'll train on
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([ train_ds = TaskMixture(
[
ARC(subset="ARC-Easy", split="train"), # 2.3K rows ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows GSM8K(subset="main", split="train"), # 8K rows
@ -89,14 +94,19 @@ train_ds = TaskMixture([
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows ]
) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
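The training mixture above combines several conversation datasets (with the approximate row counts noted in the comments) into one pool. How the repo's TaskMixture actually orders or interleaves examples is not shown here, so the following is only a concatenation-style stand-in for the indexing idea:

class ConcatMixture:
    """Index into several list-like datasets as if they were one (illustrative only)."""
    def __init__(self, tasks):
        self.tasks = tasks

    def __len__(self):
        return sum(len(t) for t in self.tasks)

    def __getitem__(self, idx):
        for task in self.tasks:
            if idx < len(task):
                return task[idx]
            idx -= len(task)
        raise IndexError(idx)

mix = ConcatMixture([["a0", "a1"], ["b0"]])
print(len(mix), mix[2])  # 3 b0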
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# DataLoader # DataLoader
def sft_data_generator(dataset, batch_size): def sft_data_generator(dataset, batch_size):
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss pad_token_id = tokenizer.encode_special(
"<|assistant_end|>"
) # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
# prepares a list of tokenized conversations into a batch and yields # prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch): def collate_and_yield(batch):
nrows = len(batch) nrows = len(batch)
@ -116,6 +126,7 @@ def sft_data_generator(dataset, batch_size):
inputs = inputs.to(device) # move to device inputs = inputs.to(device) # move to device
targets = targets.to(device) targets = targets.to(device)
return inputs, targets return inputs, targets
# iterates over the dataset in epochs, tokenizes # iterates over the dataset in epochs, tokenizes
batch = [] batch = []
while True: while True:
@ -127,11 +138,14 @@ def sft_data_generator(dataset, batch_size):
yield collate_and_yield(batch) yield collate_and_yield(batch)
batch = [] batch = []
examples_per_step = device_batch_size * ddp_world_size examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}") print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}") print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" assert target_examples_per_step % examples_per_step == 0, (
"Target examples per step must be divisible by examples per step"
)
grad_accum_steps = target_examples_per_step // examples_per_step grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}") print0(f"=> Setting grad accum steps: {grad_accum_steps}")
@ -160,11 +174,13 @@ for opt in optimizers:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Training loop # Training loop
# Learning rate scheduler # Learning rate scheduler
def get_lr_multiplier(it): def get_lr_multiplier(it):
lrm = 1.0 - it / num_iterations lrm = 1.0 - it / num_iterations
return lrm return lrm
# Go! # Go!
step = 0 step = 0
train_iter = iter(train_loader) train_iter = iter(train_loader)
@ -186,10 +202,12 @@ for step in range(num_iterations):
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item() val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
"val_loss": val_loss, "val_loss": val_loss,
}) }
)
model.train() model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run) # evaluate accuracy of the multiple choice tasks (which are quick to run)
@ -198,14 +216,30 @@ for step in range(num_iterations):
metrics = {} metrics = {}
with torch.no_grad(), autocast_ctx: with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) metrics["mmlu_acc"] = run_chat_eval(
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) "MMLU",
model,
tokenizer,
engine,
batch_size=device_batch_size * 2,
max_problems=eval_metrics_max_problems,
)
metrics["arc_easy_acc"] = run_chat_eval(
"ARC-Easy",
model,
tokenizer,
engine,
batch_size=device_batch_size * 2,
max_problems=eval_metrics_max_problems,
)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}") print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
**metrics, **metrics,
}) }
)
model.train() model.train()
if last_step: if last_step:
@ -238,13 +272,17 @@ for step in range(num_iterations):
# logging # logging
train_loss_item = train_loss.item() train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item() num_tokens_item = num_tokens.item()
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") print0(
wandb_run.log({ f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}"
)
wandb_run.log(
{
"step": step, "step": step,
"lrm": lrm, "lrm": lrm,
"train_loss": train_loss_item, "train_loss": train_loss_item,
"num_tokens": num_tokens_item, "num_tokens": num_tokens_item,
}) }
)
step += 1 step += 1
# Save the model at the end of the run # Save the model at the end of the run
@ -264,13 +302,16 @@ if master_process:
"val_loss": val_loss, "val_loss": val_loss,
**metrics, **metrics,
"model_config": model_config_kwargs, "model_config": model_config_kwargs,
} },
) )
print(f"✅ Saved model checkpoint to {checkpoint_dir}") print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Chat SFT", data=[
get_report().log(
section="Chat SFT",
data=[
user_config, # CLI args user_config, # CLI args
{ {
"Training rows": len(train_ds), "Training rows": len(train_ds),
@ -278,7 +319,8 @@ get_report().log(section="Chat SFT", data=[
"Training loss": train_loss_item, "Training loss": train_loss_item,
"Validation loss": val_loss, "Validation loss": val_loss,
}, },
]) ],
)
# Cleanup # Cleanup
wandb_run.finish() wandb_run.finish()
View File
@ -31,22 +31,23 @@ Abuse Prevention:
""" """
import argparse import argparse
import json
import os
import torch
import asyncio import asyncio
import json
import logging import logging
import os
import random import random
from contextlib import asynccontextmanager from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, nullcontext
from dataclasses import dataclass
import torch
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine from nanochat.engine import Engine
# Abuse prevention limits # Abuse prevention limits
@ -70,52 +71,56 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on') parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument(
'--device-type',
type=str,
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args() args = parser.parse_args()
# Configure logging for conversation traffic # Configure logging for conversation traffic
logging.basicConfig( logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass @dataclass
class Worker: class Worker:
"""A worker with a model loaded on a specific GPU.""" """A worker with a model loaded on a specific GPU."""
gpu_id: int gpu_id: int
device: torch.device device: torch.device
engine: Engine engine: Engine
tokenizer: object tokenizer: object
autocast_ctx: torch.amp.autocast autocast_ctx: torch.amp.autocast
class WorkerPool: class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU.""" """Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None): def __init__(self, num_gpus: int | None = None):
if num_gpus is None: if num_gpus is None:
if device_type == "cuda": if device_type == "cuda":
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
else: else:
num_gpus = 1 # e.g. cpu|mps num_gpus = 1 # e.g. cpu|mps
self.num_gpus = num_gpus self.num_gpus = num_gpus
self.workers: List[Worker] = [] self.workers: list[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue() self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None): async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None):
"""Load model on each GPU.""" """Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...") print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1: if self.num_gpus > 1:
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not." assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus): for gpu_id in range(self.num_gpus):
if device_type == "cuda": if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}") device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...") print(f"Loading model on GPU {gpu_id}...")
@ -125,15 +130,11 @@ class WorkerPool:
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step) model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
) )
worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx)
self.workers.append(worker) self.workers.append(worker)
await self.available_workers.put(worker) await self.available_workers.put(worker)
@ -147,15 +148,18 @@ class WorkerPool:
"""Return a worker to the pool.""" """Return a worker to the pool."""
await self.available_workers.put(worker) await self.available_workers.put(worker)
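The queue-based acquire/release pattern used by the pool can be shown in isolation. DummyWorker and MiniPool below are hypothetical stand-ins for the Engine-backed Worker and WorkerPool; the point is that a request blocks until a worker is free and always returns it, even on error:

import asyncio

class DummyWorker:
    def __init__(self, gpu_id):
        self.gpu_id = gpu_id

class MiniPool:
    def __init__(self, n):
        self.available = asyncio.Queue()
        for i in range(n):
            self.available.put_nowait(DummyWorker(i))

    async def acquire(self):
        return await self.available.get()   # blocks until a worker is free

    async def release(self, worker):
        await self.available.put(worker)    # hand the worker back to the pool

async def handle_request(pool, name):
    worker = await pool.acquire()
    try:
        await asyncio.sleep(0.01)           # pretend to generate on worker.gpu_id
        return f"{name} served by gpu {worker.gpu_id}"
    finally:
        await pool.release(worker)          # always return the worker, even on error

async def main():
    pool = MiniPool(2)
    print(await asyncio.gather(*(handle_request(pool, f"req{i}") for i in range(4))))

asyncio.run(main())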
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: str
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
messages: List[ChatMessage] messages: list[ChatMessage]
temperature: Optional[float] = None temperature: float | None = None
max_tokens: Optional[int] = None max_tokens: int | None = None
top_k: Optional[int] = None top_k: int | None = None
def validate_chat_request(request: ChatRequest): def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse.""" """Validate chat request to prevent abuse."""
@ -165,7 +169,7 @@ def validate_chat_request(request: ChatRequest):
if len(request.messages) > MAX_MESSAGES_PER_REQUEST: if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request" detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request",
) )
# Check individual message lengths and total conversation length # Check individual message lengths and total conversation length
@ -178,48 +182,43 @@ def validate_chat_request(request: ChatRequest):
if msg_length > MAX_MESSAGE_LENGTH: if msg_length > MAX_MESSAGE_LENGTH:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message" detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message",
) )
total_length += msg_length total_length += msg_length
if total_length > MAX_TOTAL_CONVERSATION_LENGTH: if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed" detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed",
) )
# Validate role values # Validate role values
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]: if message.role not in ["user", "assistant"]:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
) )
# Validate temperature # Validate temperature
if request.temperature is not None: if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE): if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
) )
# Validate top_k # Validate top_k
if request.top_k is not None: if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K): if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException( raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
)
# Validate max_tokens # Validate max_tokens
if request.max_tokens is not None: if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS): if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
) )
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup.""" """Load models on all GPUs on startup."""
@ -229,6 +228,7 @@ async def lifespan(app: FastAPI):
print(f"Server ready at http://localhost:{args.port}") print(f"Server ready at http://localhost:{args.port}")
yield yield
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(
@ -239,16 +239,16 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
@app.get("/") @app.get("/")
async def root(): async def root():
"""Serve the chat UI.""" """Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html") ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r", encoding="utf-8") as f: with open(ui_html_path, encoding="utf-8") as f:
html_content = f.read() html_content = f.read()
# Replace the API_URL to use the same origin # Replace the API_URL to use the same origin
html_content = html_content.replace( html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';"
"const API_URL = '';"
) )
return HTMLResponse(content=html_content) return HTMLResponse(content=html_content)
@ -259,12 +259,9 @@ async def logo():
logo_path = os.path.join("nanochat", "logo.svg") logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml") return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream( async def generate_stream(
worker: Worker, worker: Worker, tokens, temperature=None, max_new_tokens=None, top_k=None
tokens,
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Generate assistant response with streaming.""" """Generate assistant response with streaming."""
temperature = temperature if temperature is not None else args.temperature temperature = temperature if temperature is not None else args.temperature
@ -286,7 +283,7 @@ async def generate_stream(
max_tokens=max_new_tokens, max_tokens=max_new_tokens,
temperature=temperature, temperature=temperature,
top_k=top_k, top_k=top_k,
seed=random.randint(0, 2**31 - 1) seed=random.randint(0, 2**31 - 1),
): ):
token = token_column[0] token = token_column[0]
@ -310,6 +307,7 @@ async def generate_stream(
yield f"data: {json.dumps({'done': True})}\n\n" yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions") @app.post("/chat/completions")
async def chat_completions(request: ChatRequest): async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU.""" """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
@ -350,6 +348,7 @@ async def chat_completions(request: ChatRequest):
# Streaming response with worker release after completion # Streaming response with worker release after completion
response_tokens = [] response_tokens = []
async def stream_and_release(): async def stream_and_release():
try: try:
async for chunk in generate_stream( async for chunk in generate_stream(
@ -357,7 +356,7 @@ async def chat_completions(request: ChatRequest):
conversation_tokens, conversation_tokens,
temperature=request.temperature, temperature=request.temperature,
max_new_tokens=request.max_tokens, max_new_tokens=request.max_tokens,
top_k=request.top_k top_k=request.top_k,
): ):
# Accumulate response for logging # Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip()) chunk_data = json.loads(chunk.replace("data: ", "").strip())
@ -372,15 +371,13 @@ async def chat_completions(request: ChatRequest):
# Release worker back to pool after streaming is done # Release worker back to pool after streaming is done
await worker_pool.release_worker(worker) await worker_pool.release_worker(worker)
return StreamingResponse( return StreamingResponse(stream_and_release(), media_type="text/event-stream")
stream_and_release(),
media_type="text/event-stream"
)
except Exception as e: except Exception as e:
# Make sure to release worker even on error # Make sure to release worker even on error
await worker_pool.release_worker(worker) await worker_pool.release_worker(worker)
raise e raise e
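On the client side, the text/event-stream response produced above can be consumed by reading lines, stripping the "data: " prefix and stopping at the done sentinel. A hedged sketch using requests; the host, port and the exact shape of each event payload are assumptions:

import json
import requests

def stream_chat(messages, url="http://localhost:8000/chat/completions"):
    payload = {"messages": messages}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank separators between SSE events
            event = json.loads(line[len("data: "):])
            if event.get("done"):
                break
            yield event  # a dict carrying the next streamed chunk

# for chunk in stream_chat([{"role": "user", "content": "hi"}]):
#     print(chunk)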
@app.get("/health") @app.get("/health")
async def health(): async def health():
"""Health check endpoint.""" """Health check endpoint."""
@ -389,9 +386,10 @@ async def health():
"status": "ok", "status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0, "ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0, "num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0 "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0,
} }
@app.get("/stats") @app.get("/stats")
async def stats(): async def stats():
"""Get worker pool statistics.""" """Get worker pool statistics."""
@ -400,16 +398,13 @@ async def stats():
"total_workers": len(worker_pool.workers), "total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(), "available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(), "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [ "workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers],
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
} }
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
print(f"Starting NanoChat Web Server")
print("Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)
View File
@ -9,25 +9,26 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
""" """
from collections import deque
import os import os
from collections import deque
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time import time
import wandb
import torch
from contextlib import nullcontext from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
import torch
import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes
from tasks.common import TaskMixture from tasks.common import TaskMixture
from tasks.customjson import CustomJSON
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee from tasks.spellingbee import SimpleSpelling, SpellingBee
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -57,7 +58,9 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
device_type = autodetect_device_type() if device_type == "" else device_type device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
@ -69,7 +72,9 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
pretrain_batch_size = meta.get("device_batch_size", None) pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") print0(
f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?"
)
orig_model = model orig_model = model
model = torch.compile(model, dynamic=False) model = torch.compile(model, dynamic=False)
depth = model.config.n_layer depth = model.config.n_layer
@ -84,7 +89,9 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
token_bytes = get_token_bytes(device=device) token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate # Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers: for opt in optimizers:
@ -95,25 +102,33 @@ for opt in optimizers:
# Midtraining data mixture and DataLoader # Midtraining data mixture and DataLoader
base_dir = get_base_dir() base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([ train_dataset = TaskMixture(
[
SmolTalk(split="train"), # 460K rows of general conversations SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE MMLU(
subset="auxiliary_train", split="train"
), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows ]
val_dataset = TaskMixture([ ) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture(
[
SmolTalk(split="test"), # 24K rows in test set SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows ]
) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create # A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator. # these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the dataset last_step = False # we will toggle this to True when we reach the end of the dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split): def mid_data_generator(split):
global last_step, approx_progress global last_step, approx_progress
assert split in {"train", "val"}, "split must be 'train' or 'val'" assert split in {"train", "val"}, "split must be 'train' or 'val'"
@ -147,7 +162,9 @@ def mid_data_generator(split):
inputs_cpu = scratch[:-1].to(dtype=torch.int32) inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:] targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True) inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True) targets = targets_cpu.view(device_batch_size, max_seq_len).to(
device=device, dtype=torch.int64, non_blocking=True
)
if split == "train": if split == "train":
if num_iterations > 0: if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations approx_progress = it / num_iterations # calculate progress from the max number of iterations
@ -155,21 +172,25 @@ def mid_data_generator(split):
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets yield inputs, targets
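The inputs/targets construction above is the usual next-token shift over a flat scratch buffer of B*T+1 tokens: drop the last token for inputs, drop the first for targets, and view both as (B, T). In isolation:

import torch

B, T = 2, 4
scratch = torch.arange(B * T + 1)   # pretend token ids: [0, 1, ..., 8]
inputs = scratch[:-1].view(B, T)    # tokens fed to the model
targets = scratch[1:].view(B, T)    # the same stream shifted left by one
print(inputs.tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(targets.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]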
train_loader = mid_data_generator("train") train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val") build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler # Learning rate scheduler
def get_lr_multiplier(progress): def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0. # first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer # Momentum scheduler for Muon optimizer
def get_muon_momentum(it): def get_muon_momentum(it):
frac = min(it / 300, 1) frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95 momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum return momentum
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Training loop # Training loop
x, y = next(train_loader) # prefetch the very first batch of data x, y = next(train_loader) # prefetch the very first batch of data
@ -197,12 +218,14 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb: if val_bpb < min_val_bpb:
min_val_bpb = val_bpb min_val_bpb = val_bpb
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
"total_training_flops": flops_so_far, "total_training_flops": flops_so_far,
"total_training_time": total_training_time, "total_training_time": total_training_time,
"val/bpb": val_bpb, "val/bpb": val_bpb,
}) }
)
model.train() model.train()
# save checkpoint at the end of the run (only on master process) # save checkpoint at the end of the run (only on master process)
@ -226,7 +249,7 @@ while True:
"n_embd": model.config.n_embd, "n_embd": model.config.n_embd,
}, },
"user_config": user_config, # inputs to the training script "user_config": user_config, # inputs to the training script
} },
) )
if last_step: if last_step:
@ -274,9 +297,12 @@ while True:
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10: if step > 10:
total_training_time += dt # only count the time after the first 10 steps total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") print0(
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 10 == 0: if step % 10 == 0:
wandb_run.log({ wandb_run.log(
{
"step": step, "step": step,
"total_training_flops": flops_so_far, "total_training_flops": flops_so_far,
"total_training_time": total_training_time, "total_training_time": total_training_time,
@ -285,7 +311,8 @@ while True:
"train/dt": dt, "train/dt": dt,
"train/tok_per_sec": tok_per_sec, "train/tok_per_sec": tok_per_sec,
"train/mfu": mfu, "train/mfu": mfu,
}) }
)
# print a few more stats # print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
@ -295,7 +322,10 @@ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report # Log to report
if not dry_run: if not dry_run:
from nanochat.report import get_report from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
get_report().log(
section="Midtraining",
data=[
user_config, # CLI args user_config, # CLI args
{ # stats about the training setup { # stats about the training setup
"Number of iterations": step, "Number of iterations": step,
@ -303,8 +333,9 @@ if not dry_run:
}, },
{ # stats about training outcomes { # stats about training outcomes
"Minimum validation bpb": min_val_bpb, "Minimum validation bpb": min_val_bpb,
} },
]) ],
)
# cleanup # cleanup
wandb_run.finish() # wandb run finish wandb_run.finish() # wandb run finish
View File
@ -2,8 +2,8 @@
Evaluate compression ratio of the tokenizer. Evaluate compression ratio of the tokenizer.
""" """
from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
from nanochat.dataset import parquets_iter_batched from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import RustBPETokenizer, get_tokenizer
# Random text I got from a random website this morning # Random text I got from a random website this morning
news_text = r""" news_text = r"""
@ -165,7 +165,6 @@ tokenizer_results = {}
vocab_sizes = {} vocab_sizes = {}
for tokenizer_name in ["gpt2", "gpt4", "ours"]: for tokenizer_name in ["gpt2", "gpt4", "ours"]:
if tokenizer_name == "gpt2": if tokenizer_name == "gpt2":
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
elif tokenizer_name == "gpt4": elif tokenizer_name == "gpt4":
@ -183,11 +182,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
encoded_bytes = text.encode('utf-8') encoded_bytes = text.encode('utf-8')
ratio = len(encoded_bytes) / len(encoded) ratio = len(encoded_bytes) / len(encoded)
tokenizer_results[tokenizer_name][name] = { tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio}
'bytes': len(encoded_bytes),
'tokens': len(encoded),
'ratio': ratio
}
# ANSI color codes # ANSI color codes
GREEN = '\033[92m' GREEN = '\033[92m'
@ -195,11 +190,12 @@ RED = '\033[91m'
RESET = '\033[0m' RESET = '\033[0m'
# Print vocab sizes # Print vocab sizes
print(f"\nVocab sizes:") print("\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}") print(f"GPT-2: {vocab_sizes['gpt2']}")
print(f"GPT-4: {vocab_sizes['gpt4']}") print(f"GPT-4: {vocab_sizes['gpt4']}")
print(f"Ours: {vocab_sizes['ours']}") print(f"Ours: {vocab_sizes['ours']}")
def print_comparison(baseline_name, baseline_results, ours_results, all_text): def print_comparison(baseline_name, baseline_results, ours_results, all_text):
"""Print comparison table between baseline tokenizer and ours.""" """Print comparison table between baseline tokenizer and ours."""
print(f"\nComparison with {baseline_name}:") print(f"\nComparison with {baseline_name}:")
@ -230,13 +226,16 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
better = "Tie" better = "Tie"
diff_color = "" diff_color = ""
print(f"{name:<10} {baseline_data['bytes']:<8} " print(
f"{name:<10} {baseline_data['bytes']:<8} "
f"{baseline_color}{baseline_data['tokens']:<7}{RESET} " f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} " f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
f"{ours_color}{ours_data['tokens']:<7}{RESET} " f"{ours_color}{ours_data['tokens']:<7}{RESET} "
f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} " f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
f"{diff_color}{relative_diff:+7.1f}%{RESET} " f"{diff_color}{relative_diff:+7.1f}%{RESET} "
f"{better:<10}") f"{better:<10}"
)
# Print comparisons # Print comparisons
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text) print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
@ -244,6 +243,7 @@ print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'],
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
lines = [] lines = []
for baseline_name in ["GPT-2", "GPT-4"]: for baseline_name in ["GPT-2", "GPT-4"]:
baseline_key = baseline_name.lower().replace('-', '') baseline_key = baseline_name.lower().replace('-', '')
@ -251,15 +251,26 @@ for baseline_name in ["GPT-2", "GPT-4"]:
ours_results = tokenizer_results['ours'] ours_results = tokenizer_results['ours']
lines.append(f"### Comparison with {baseline_name}") lines.append(f"### Comparison with {baseline_name}")
lines.append("") lines.append("")
lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |") lines.append(
"| Text Type | Bytes | "
+ baseline_name
+ " Tokens | "
+ baseline_name
+ " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |"
)
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|") lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
for name, text in all_text: for name, text in all_text:
baseline_data = baseline_results[name] baseline_data = baseline_results[name]
ours_data = ours_results[name] ours_data = ours_results[name]
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100 relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |") lines.append(
f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |"
)
lines.append("") lines.append("")
report_markdown = "\n".join(lines) report_markdown = "\n".join(lines)
get_report().log(section="Tokenizer evaluation", data=[ get_report().log(
section="Tokenizer evaluation",
data=[
report_markdown, report_markdown,
]) ],
)
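The 'ratio' column in the tables above is simply UTF-8 bytes divided by token count, i.e. average bytes per token. A standalone sketch of that computation, using tiktoken's gpt2 encoding as a stand-in (the script itself goes through the repo's tokenizer wrappers):

import tiktoken

enc = tiktoken.get_encoding("gpt2")        # stand-in tokenizer for this sketch
text = "Tokenizers trade vocabulary size for sequence length."
encoded = enc.encode(text)
encoded_bytes = text.encode("utf-8")
ratio = len(encoded_bytes) / len(encoded)  # bytes per token; higher = better compression
print(f"bytes={len(encoded_bytes)} tokens={len(encoded)} ratio={ratio:.2f}")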

View File

@ -2,19 +2,24 @@
Train a tokenizer using the HuggingFace Tokenizers library. Train a tokenizer using the HuggingFace Tokenizers library.
In the style of GPT-4 tokenizer. In the style of GPT-4 tokenizer.
""" """
import argparse
import os import os
import time import time
import argparse
import torch import torch
from nanochat.tokenizer import RustBPETokenizer
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
from nanochat.dataset import parquets_iter_batched from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import RustBPETokenizer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer') parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') parser.add_argument(
'--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)'
)
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args() args = parser.parse_args()
@ -25,6 +30,7 @@ print(f"vocab_size: {args.vocab_size:,}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Text iterator # Text iterator
def text_iterator(): def text_iterator():
""" """
1) Flatten the batches into a single iterator 1) Flatten the batches into a single iterator
@ -41,6 +47,8 @@ def text_iterator():
yield doc_text yield doc_text
if nchars > args.max_chars: if nchars > args.max_chars:
return return
text_iter = text_iterator() text_iter = text_iterator()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -92,8 +100,11 @@ print(f"Saved token_bytes to {token_bytes_path}")
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32) token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
get_report().log(section="Tokenizer training", data=[ get_report().log(
section="Tokenizer training",
data=[
vars(args), # argparse command line arguments vars(args), # argparse command line arguments
{"train_time": train_time}, {"train_time": train_time},
{"num_special_tokens": len(special_set)}, {"num_special_tokens": len(special_set)},
@ -102,5 +113,6 @@ get_report().log(section="Tokenizer training", data=[
"token_bytes_max": int(token_bytes_nonzero.max().item()), "token_bytes_max": int(token_bytes_nonzero.max().item()),
"token_bytes_mean": token_bytes_nonzero.mean().item(), "token_bytes_mean": token_bytes_nonzero.mean().item(),
"token_bytes_std": token_bytes_nonzero.std().item(), "token_bytes_std": token_bytes_nonzero.std().item(),
} },
]) ],
)
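The text_iterator above flattens document batches, clips each document to doc_cap characters, and stops once max_chars have been yielded. A toy version of that pattern with made-up inputs (the real iterator reads its batches from parquets_iter_batched):

def capped_text_iterator(batches, max_chars, doc_cap):
    nchars = 0
    for batch in batches:
        for doc_text in batch:
            doc_text = doc_text[:doc_cap]  # cap very long documents
            nchars += len(doc_text)
            yield doc_text
            if nchars > max_chars:         # stop once the character budget is spent
                return

batches = [["a" * 300, "b" * 50], ["c" * 700]]
docs = list(capped_text_iterator(batches, max_chars=200, doc_cap=100))
print([len(d) for d in docs])              # -> [100, 50, 100]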

View File

@ -4,10 +4,11 @@ https://huggingface.co/datasets/allenai/ai2_arc
""" """
from datasets import load_dataset from datasets import load_dataset
from tasks.common import Task, render_mc from tasks.common import Task, render_mc
class ARC(Task):
class ARC(Task):
def __init__(self, subset, split, **kwargs): def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge" assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
@ -30,10 +31,7 @@ class ARC(Task):
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
# create and return the Conversation object # create and return the Conversation object
user_message = render_mc(question, letters, choices) user_message = render_mc(question, letters, choices)
messages = [ messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}]
{"role": "user", "content": user_message},
{"role": "assistant", "content": answer_string}
]
conversation = { conversation = {
"messages": messages, "messages": messages,
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
@ -43,6 +41,8 @@ class ARC(Task):
def evaluate(self, conversation, assistant_response): def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" assert assistant_response in conversation['letters'], (
f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A" assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message return assistant_response == assistant_message
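The evaluation contract for these multiple-choice tasks is deliberately simple: the reference letter is stored as the final assistant message and grading is exact string match. A self-contained illustration (the question below is invented; the exact prompt format produced by render_mc is not shown in this diff):

conversation = {
    "letters": ("A", "B", "C", "D"),
    "messages": [
        {"role": "user", "content": "Which gas do plants absorb for photosynthesis?\nA. Oxygen\nB. Carbon dioxide\nC. Nitrogen\nD. Helium"},
        {"role": "assistant", "content": "B"},  # reference answer letter
    ],
}

def evaluate(conversation, assistant_response):
    assert assistant_response in conversation["letters"]
    return assistant_response == conversation["messages"][-1]["content"]

print(evaluate(conversation, "B"))  # -> True
print(evaluate(conversation, "C"))  # -> False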

View File

@ -7,6 +7,7 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
import random import random
class Task: class Task:
""" """
Base class of a Task. Allows for lightweight slicing of the underlying dataset. Base class of a Task. Allows for lightweight slicing of the underlying dataset.
@ -81,7 +82,9 @@ class TaskMixture(Task):
Access conversations according to a deterministic shuffle of all examples. Access conversations according to a deterministic shuffle of all examples.
This ensures tasks are mixed throughout training, regardless of dataset size. This ensures tasks are mixed throughout training, regardless of dataset size.
""" """
assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations" assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for mixture with {self.num_conversations} conversations"
)
task_idx, local_idx = self.index_map[index] task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx] return self.tasks[task_idx][local_idx]
@ -102,7 +105,9 @@ class TaskSequence(Task):
return self.num_conversations return self.num_conversations
def get_example(self, index): def get_example(self, index):
assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations" assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for sequence with {self.num_conversations} conversations"
)
for task_idx, task_length in enumerate(self.lengths): for task_idx, task_length in enumerate(self.lengths):
if index < task_length: if index < task_length:
return self.tasks[task_idx][index] return self.tasks[task_idx][index]
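Both classes above reduce indexing across several datasets to a precomputed (task_idx, local_idx) lookup. A hedged sketch of the deterministic-shuffle idea behind TaskMixture's index_map; the repo's actual construction (seed, shuffle scheme) may differ:

import random

def build_index_map(lengths, seed=42):
    # enumerate every (task_idx, local_idx) pair, then shuffle with a fixed seed
    # so the mixture order is identical on every run
    index_map = [(t, i) for t, n in enumerate(lengths) for i in range(n)]
    random.Random(seed).shuffle(index_map)
    return index_map

lengths = [3, 2]                 # two toy tasks with 3 and 2 conversations
index_map = build_index_map(lengths)
print(index_map)                 # same interleaved order on every run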

View File

@ -3,10 +3,12 @@ CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages. Each line in the JSONL file should be a JSON array of messages.
""" """
import os
import json import json
import os
from tasks.common import Task from tasks.common import Task
class CustomJSON(Task): class CustomJSON(Task):
""" """
Load conversations from a JSONL file. Load conversations from a JSONL file.
@ -25,14 +27,18 @@ class CustomJSON(Task):
print("-" * 80) print("-" * 80)
print(f"Warning: File {filepath} does not exist") print(f"Warning: File {filepath} does not exist")
print("HINT (Oct 21 2025)") print("HINT (Oct 21 2025)")
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations") print(
"If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations"
)
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139") print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:") print("Quick fix: simply run the following command to download the file and you're done:")
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl") print(
f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
)
print("-" * 80) print("-" * 80)
else: else:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, encoding='utf-8') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: # skip empty lines if not line: # skip empty lines
@ -46,7 +52,9 @@ class CustomJSON(Task):
assert "role" in message, f"Message {i} missing 'role' field" assert "role" in message, f"Message {i} missing 'role' field"
assert "content" in message, f"Message {i} missing 'content' field" assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant" expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), f"Message {i} content must be a string" assert isinstance(message["content"], str), f"Message {i} content must be a string"
self.conversations.append(messages) self.conversations.append(messages)
@ -62,4 +70,3 @@ class CustomJSON(Task):
"messages": messages, "messages": messages,
} }
return conversation return conversation
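The validation loop above enforces strict user/assistant alternation on every conversation read from the JSONL file. A compact standalone version of that check (the sample conversation is invented):

def validate_messages(messages):
    for i, message in enumerate(messages):
        assert "role" in message and "content" in message
        expected_role = "user" if i % 2 == 0 else "assistant"
        assert message["role"] == expected_role, (
            f"Message {i} has role {message['role']} but should be {expected_role}"
        )
        assert isinstance(message["content"], str)

validate_messages([
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant", "content": "I'm nanochat."},
])  # passes silently; a mis-ordered conversation would raise an AssertionError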

View File

@ -15,11 +15,14 @@ Notice that GSM8K uses tool calls inside << >> tags.
""" """
import re import re
from datasets import load_dataset from datasets import load_dataset
from tasks.common import Task from tasks.common import Task
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)") GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion): def extract_answer(completion):
""" """
Extract the numerical answer after #### marker. Extract the numerical answer after #### marker.
@ -35,7 +38,6 @@ def extract_answer(completion):
class GSM8K(Task): class GSM8K(Task):
def __init__(self, subset, split, **kwargs): def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic" assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
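The "#### <number>" marker matched by GSM_RE above is the only convention the grader relies on. A quick standalone check of that extraction (the completion string is made up):

import re

GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

completion = "She buys 3 boxes of 4 apples, so 3 * 4 = <<3*4=12>>12 apples.\n#### 12"
match = GSM_RE.search(completion)
print(match.group(1) if match else None)  # -> 12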

View File

@ -5,10 +5,13 @@ It is a coding benchmark.
""" """
import re import re
from datasets import load_dataset from datasets import load_dataset
from nanochat.execution import execute_code from nanochat.execution import execute_code
from tasks.common import Task from tasks.common import Task
def extract_imports(prompt): def extract_imports(prompt):
"""Extract import statements from the beginning of a code block.""" """Extract import statements from the beginning of a code block."""
imports = [] imports = []
@ -21,6 +24,7 @@ def extract_imports(prompt):
break break
return '\n'.join(imports) return '\n'.join(imports)
def extract_program(completion): def extract_program(completion):
""" """
Extract Python code from LLM completion. Extract Python code from LLM completion.
@ -44,8 +48,8 @@ def extract_program(completion):
# No code blocks found, return the whole completion # No code blocks found, return the whole completion
return completion.strip() return completion.strip()
class HumanEval(Task):
class HumanEval(Task):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
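extract_program above pulls runnable Python out of a free-form completion and falls back to the whole completion when no code block is found. A hedged sketch of that idea; the regex and fallback below are illustrative, not the script's exact rules:

import re

CODE_BLOCK_RE = re.compile(r"```(?:python)?\n(.*?)```", re.DOTALL)

def extract_first_code_block(completion: str) -> str:
    match = CODE_BLOCK_RE.search(completion)
    # fall back to the raw completion if the model didn't emit a fenced block
    return match.group(1).strip() if match else completion.strip()

completion = "Here is a solution:\n```python\ndef add(a, b):\n    return a + b\n```\nHope it helps!"
print(extract_first_code_block(completion))  # -> the two-line add function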

View File

@ -4,12 +4,71 @@ https://huggingface.co/datasets/cais/mmlu
""" """
from datasets import load_dataset from datasets import load_dataset
from tasks.common import Task, render_mc from tasks.common import Task, render_mc
class MMLU(Task):
class MMLU(Task):
letters = ('A', 'B', 'C', 'D') letters = ('A', 'B', 'C', 'D')
groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') groups = (
'abstract_algebra',
'anatomy',
'astronomy',
'business_ethics',
'clinical_knowledge',
'college_biology',
'college_chemistry',
'college_computer_science',
'college_mathematics',
'college_medicine',
'college_physics',
'computer_security',
'conceptual_physics',
'econometrics',
'electrical_engineering',
'elementary_mathematics',
'formal_logic',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_computer_science',
'high_school_european_history',
'high_school_geography',
'high_school_government_and_politics',
'high_school_macroeconomics',
'high_school_mathematics',
'high_school_microeconomics',
'high_school_physics',
'high_school_psychology',
'high_school_statistics',
'high_school_us_history',
'high_school_world_history',
'human_aging',
'human_sexuality',
'international_law',
'jurisprudence',
'logical_fallacies',
'machine_learning',
'management',
'marketing',
'medical_genetics',
'miscellaneous',
'moral_disputes',
'moral_scenarios',
'nutrition',
'philosophy',
'prehistory',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_studies',
'sociology',
'us_foreign_policy',
'virology',
'world_religions',
)
def __init__(self, subset, split, **kwargs): def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -41,10 +100,7 @@ class MMLU(Task):
# create and return the Conversation object # create and return the Conversation object
user_message = render_mc(question, self.letters, choices) user_message = render_mc(question, self.letters, choices)
assistant_message = self.letters[answer] assistant_message = self.letters[answer]
messages = [ messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}]
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message}
]
conversation = { conversation = {
"messages": messages, "messages": messages,
"subject": subject, # might be useful later for grouping metrics by subject "subject": subject, # might be useful later for grouping metrics by subject
@ -55,6 +111,8 @@ class MMLU(Task):
def evaluate(self, conversation, assistant_response): def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" assert assistant_response in self.letters, (
f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A" assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message return assistant_response == assistant_message

View File

@ -5,8 +5,10 @@ We use the "smol" version, which is more appropriate for smaller models.
""" """
from datasets import load_dataset from datasets import load_dataset
from tasks.common import Task from tasks.common import Task
class SmolTalk(Task): class SmolTalk(Task):
"""smol-smoltalk dataset. train is 460K rows, test is 24K rows.""" """smol-smoltalk dataset. train is 460K rows, test is 24K rows."""
@ -36,7 +38,9 @@ class SmolTalk(Task):
for i, message in enumerate(rest_messages): for i, message in enumerate(rest_messages):
# user and assistant alternate as user,assistant,user,assistant,... # user and assistant alternate as user,assistant,user,assistant,...
expected_role = "user" if i % 2 == 0 else "assistant" expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), "Content must be a string" assert isinstance(message["content"], str), "Content must be a string"
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
# create and return the Conversation object (ok to emit the system message too) # create and return the Conversation object (ok to emit the system message too)

View File

@ -26,10 +26,11 @@ To preview a few example conversations, run:
python -m tasks.spellingbee python -m tasks.spellingbee
""" """
import re
import random import random
from tasks.common import Task import re
from nanochat.common import download_file_with_lock from nanochat.common import download_file_with_lock
from tasks.common import Task
# Letters of the alphabet # Letters of the alphabet
LETTERS = "abcdefghijklmnopqrstuvwxyz" LETTERS = "abcdefghijklmnopqrstuvwxyz"
@ -38,6 +39,8 @@ WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads
# Identical to gsm8k's answer extraction # Identical to gsm8k's answer extraction
ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)") ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion): def extract_answer(completion):
""" """
Extract the numerical answer after #### marker. Extract the numerical answer after #### marker.
@ -49,6 +52,7 @@ def extract_answer(completion):
return match_str return match_str
return None return None
# User message templates for data augmentation # User message templates for data augmentation
USER_MSG_TEMPLATES = [ USER_MSG_TEMPLATES = [
"How many {letter} are in the word {word}", "How many {letter} are in the word {word}",
@ -110,8 +114,8 @@ USER_MSG_TEMPLATES = [
"{word}{letter}が何回出てくる", "{word}{letter}が何回出てくる",
] ]
class SpellingBee(Task):
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs): def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test" assert split in ["train", "test"], "SpellingBee split must be train|test"
@ -119,7 +123,7 @@ class SpellingBee(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, 'r', encoding='utf-8') as f: with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
self.words = words self.words = words
@ -190,13 +194,12 @@ Then count the occurrences of '{letter}':
# Part 4: Python output # Part 4: Python output
assistant_parts.append({"type": "python_output", "text": str(count)}) assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer # Part 5: Final answer
assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}) assistant_parts.append(
{"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}
)
# return the full conversation # return the full conversation
messages = [ messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}]
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts}
]
conversation = { conversation = {
"messages": messages, "messages": messages,
} }
@ -238,7 +241,7 @@ class SimpleSpelling(Task):
self.split = split self.split = split
filename = WORD_LIST_URL.split("/")[-1] filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename) word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, 'r', encoding='utf-8') as f: with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f] words = [line.strip() for line in f]
rng = random.Random(42) rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task rng.shuffle(words) # use a different word order than the SpellingBee task
@ -260,7 +263,7 @@ class SimpleSpelling(Task):
# return the full conversation # return the full conversation
messages = [ messages = [
{"role": "user", "content": f"Spell the word: {word}"}, {"role": "user", "content": f"Spell the word: {word}"},
{"role": "assistant", "content": f"{word}:{word_letters}"} {"role": "assistant", "content": f"{word}:{word_letters}"},
] ]
conversation = { conversation = {
"messages": messages, "messages": messages,
@ -269,7 +272,6 @@ class SimpleSpelling(Task):
if __name__ == "__main__": if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples # preview the SpellingBee task, first 10 examples
task = SpellingBee() task = SpellingBee()
for i in range(10): for i in range(10):
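Every assistant turn built by SpellingBee ends with the same "#### <count>" marker that extract_answer looks for. A heavily simplified construction of one example (the real task also mixes in python tool-call parts and the user message templates above):

word, letter = "strawberry", "r"
count = word.count(letter)

user_msg = f"How many {letter} are in the word {word}"
assistant_msg = (
    f"Let's spell it out first: {'-'.join(word)}\n\n"
    f"My final answer is:\n\n#### {count}"
)
print(user_msg)
print(assistant_msg)  # ends with "#### 3"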

View File

@ -5,8 +5,10 @@ python -m pytest tests/test_engine.py -v
""" """
import torch import torch
from nanochat.engine import KVCache from nanochat.engine import KVCache
def test_kv_cache_resize(): def test_kv_cache_resize():
""" """
The KV cache was not resized correctly, more information here: The KV cache was not resized correctly, more information here:
@ -21,11 +23,7 @@ def test_kv_cache_resize():
num_layers = 6 num_layers = 6
kv_cache = KVCache( kv_cache = KVCache(
batch_size=batch_size, batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
) )
# Insert a single token with a distinct fill value to all layers # Insert a single token with a distinct fill value to all layers
@ -47,7 +45,9 @@ def test_kv_cache_resize():
insert_token(4) insert_token(4)
# Verify that the cache actually resized # Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4] new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}" assert new_seq_len > original_seq_len, (
f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
)
# Verify that the original 4 tokens are still intact after resize # Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
@ -57,8 +57,12 @@ def test_kv_cache_resize():
expected_v = float(token_idx * 100) expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :] actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :] actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" assert (actual_k == expected_k).all(), (
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
)
assert (actual_v == expected_v).all(), (
f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
)
# And that the original cache matches resized cache # And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :] original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :] original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
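The tensor being indexed here has shape [num_layers, 2 (k/v), batch, heads, seq, head_dim], and the property under test is that growing the seq dimension preserves previously written tokens. A toy illustration of that invariant, independent of the repo's KVCache class:

import torch

def grow_cache(cache: torch.Tensor, new_seq_len: int) -> torch.Tensor:
    layers, two, batch, heads, seq, head_dim = cache.shape
    grown = torch.zeros(layers, two, batch, heads, new_seq_len, head_dim, dtype=cache.dtype)
    grown[..., :seq, :] = cache  # copy the existing tokens into the front of the new buffer
    return grown

cache = torch.arange(2 * 2 * 1 * 1 * 4 * 2, dtype=torch.float32).reshape(2, 2, 1, 1, 4, 2)
bigger = grow_cache(cache, new_seq_len=8)
assert bigger.shape[4] == 8
assert (bigger[..., :4, :] == cache).all()  # original 4 tokens intact after the resize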

View File

@ -18,18 +18,23 @@ python -m pytest tests/test_rustbpe.py -v -s
-v is verbose, -s is show prints -v is verbose, -s is show prints
""" """
import regex as re
from collections import Counter, defaultdict
import time import time
import rustbpe from collections import Counter, defaultdict
import tiktoken
import pytest
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" import pytest
import regex as re
import tiktoken
import rustbpe
GPT4_SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe # Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
def get_stats(ids, counts=None): def get_stats(ids, counts=None):
""" """
Given a list of integers, return a dictionary of counts of consecutive pairs Given a list of integers, return a dictionary of counts of consecutive pairs
@ -41,6 +46,7 @@ def get_stats(ids, counts=None):
counts[pair] = counts.get(pair, 0) + 1 counts[pair] = counts.get(pair, 0) + 1
return counts return counts
def merge(ids, pair, idx): def merge(ids, pair, idx):
""" """
In the list of integers (ids), replace all consecutive occurrences In the list of integers (ids), replace all consecutive occurrences
@ -59,8 +65,8 @@ def merge(ids, pair, idx):
i += 1 i += 1
return newids return newids
class RegexTokenizer:
class RegexTokenizer:
def __init__(self, pattern=None): def __init__(self, pattern=None):
""" """
- pattern: optional string to override the default (GPT-4 split pattern) - pattern: optional string to override the default (GPT-4 split pattern)
@ -163,9 +169,11 @@ class RegexTokenizer:
ids.extend(chunk_ids) ids.extend(chunk_ids)
return ids return ids
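The reference tokenizer above is built entirely from the two BPE primitives defined earlier in this file: counting adjacent pairs (get_stats) and replacing the chosen pair with a new id (merge). A tiny worked example of those primitives:

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # count consecutive pairs
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)      # replace the pair with the new token id
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

ids = [1, 2, 3, 1, 2]
print(get_stats(ids))               # -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
print(merge(ids, (1, 2), 4))        # -> [4, 3, 4]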
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Faster Python tokenizer, optimized version of the reference tokenizer # Faster Python tokenizer, optimized version of the reference tokenizer
def fast_merge_inplace(ids, pair, idx): def fast_merge_inplace(ids, pair, idx):
""" """
In the list of integers (ids), replace all consecutive occurrences In the list of integers (ids), replace all consecutive occurrences
@ -184,7 +192,6 @@ def fast_merge_inplace(ids, pair, idx):
class FastRegexTokenizer: class FastRegexTokenizer:
def __init__(self, pattern=None): def __init__(self, pattern=None):
""" """
- pattern: optional string to override the default (GPT-4 split pattern) - pattern: optional string to override the default (GPT-4 split pattern)
@ -302,8 +309,9 @@ class FastRegexTokenizer:
# Update positions for changed pairs - only check affected chunks # Update positions for changed pairs - only check affected chunks
for chunk_idx in affected_chunks: for chunk_idx in affected_chunks:
chunk_ids = ids[chunk_idx] chunk_ids = ids[chunk_idx]
contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair contains_pair = any(
for j in range(len(chunk_ids) - 1)) (chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1)
)
if contains_pair: if contains_pair:
positions[changed_pair].add(chunk_idx) positions[changed_pair].add(chunk_idx)
else: else:
@ -372,13 +380,15 @@ class FastRegexTokenizer:
ids.extend(chunk_ids) ids.extend(chunk_ids)
return ids return ids
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HuggingFace tokenizer # HuggingFace tokenizer
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer: class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities""" """Light wrapper around HuggingFace Tokenizer for some utilities"""
@ -389,19 +399,23 @@ class HuggingFaceTokenizer:
def train_from_iterator(cls, text_iterator, vocab_size): def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text # train from an iterator of text
# Configure the HuggingFace Tokenizer # Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(BPE( tokenizer = HFTokenizer(
BPE(
byte_fallback=True, # needed! byte_fallback=True, # needed!
unk_token=None, unk_token=None,
fuse_unk=False, fuse_unk=False,
)) )
)
# Normalizer: None # Normalizer: None
tokenizer.normalizer = None tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style # Pre-tokenizer: GPT-4 style
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
]) ]
)
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel() tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None # Post-processor: None
@ -422,15 +436,19 @@ class HuggingFaceTokenizer:
ids = self.tokenizer.encode(text, add_special_tokens=False).ids ids = self.tokenizer.encode(text, add_special_tokens=False).ids
return ids return ids
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Test all of the above # Test all of the above
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_path(): def enwik8_path():
"""Fixture to download and cache enwik8 dataset.""" """Fixture to download and cache enwik8 dataset."""
import os import os
import zipfile import zipfile
from nanochat.common import get_base_dir from nanochat.common import get_base_dir
base_dir = get_base_dir() base_dir = get_base_dir()
# download and unzip enwik8 to .cache directory # download and unzip enwik8 to .cache directory
enwik8_url = "https://mattmahoney.net/dc/enwik8.zip" enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
@ -439,6 +457,7 @@ def enwik8_path():
if not os.path.exists(enwik8_local_path): if not os.path.exists(enwik8_local_path):
print(f"Downloading enwik8 to {enwik8_local_path_zip}") print(f"Downloading enwik8 to {enwik8_local_path_zip}")
import requests import requests
response = requests.get(enwik8_url) response = requests.get(enwik8_url)
with open(enwik8_local_path_zip, "wb") as f: with open(enwik8_local_path_zip, "wb") as f:
f.write(response.content) f.write(response.content)
@ -455,15 +474,17 @@ def enwik8_path():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_small(enwik8_path): def enwik8_small(enwik8_path):
"""Fixture providing 100KB of enwik8 for quick tests.""" """Fixture providing 100KB of enwik8 for quick tests."""
with open(enwik8_path, "r", encoding="utf-8") as f: with open(enwik8_path, encoding="utf-8") as f:
return f.read(100_000) return f.read(100_000)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def enwik8_large(enwik8_path): def enwik8_large(enwik8_path):
"""Fixture providing 10MB of enwik8 for performance tests.""" """Fixture providing 10MB of enwik8 for performance tests."""
with open(enwik8_path, "r", encoding="utf-8") as f: with open(enwik8_path, encoding="utf-8") as f:
return f.read(10**7) return f.read(10**7)
def time_function(func, *args, **kwargs): def time_function(func, *args, **kwargs):
"""Time a function call and return the result and elapsed time""" """Time a function call and return the result and elapsed time"""
start_time = time.time() start_time = time.time()
@ -472,6 +493,7 @@ def time_function(func, *args, **kwargs):
elapsed = end_time - start_time elapsed = end_time - start_time
return result, elapsed return result, elapsed
def test_correctness(enwik8_small): def test_correctness(enwik8_small):
"""Test that all tokenizer implementations produce the same results.""" """Test that all tokenizer implementations produce the same results."""
text = enwik8_small text = enwik8_small
@ -482,7 +504,9 @@ def test_correctness(enwik8_small):
print("\nTraining slow reference...") print("\nTraining slow reference...")
slow_reference_tokenizer = RegexTokenizer() slow_reference_tokenizer = RegexTokenizer()
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size) ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text) slow_reference_ids, slow_reference_encode_time = time_function(
slow_reference_tokenizer.encode_ordinary, encode_text
)
print(f"Slow reference train time: {slow_reference_train_time:.4f}s") print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s") print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
print(slow_reference_ids[:20]) print(slow_reference_ids[:20])
@ -497,7 +521,9 @@ def test_correctness(enwik8_small):
print("\nTraining fast reference...") print("\nTraining fast reference...")
fast_reference_tokenizer = FastRegexTokenizer() fast_reference_tokenizer = FastRegexTokenizer()
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size) _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text) fast_reference_ids, fast_reference_encode_time = time_function(
fast_reference_tokenizer.encode_ordinary, encode_text
)
print(f"Fast reference train time: {fast_reference_train_time:.4f}s") print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s") print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
print(fast_reference_ids[:20]) print(fast_reference_ids[:20])
@ -589,14 +615,16 @@ def test_training_performance(enwik8_large):
assert hf_train_time > 0, "Training should take some time" assert hf_train_time > 0, "Training should take some time"
# Print comparison # Print comparison
print(f"\n📊 Performance comparison:") print("\n📊 Performance comparison:")
print(f" RustBPE: {rustbpe_train_time:.4f}s") print(f" RustBPE: {rustbpe_train_time:.4f}s")
print(f" HuggingFace: {hf_train_time:.4f}s") print(f" HuggingFace: {hf_train_time:.4f}s")
print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x") print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x")
def test_interface(enwik8_small): def test_interface(enwik8_small):
"""Test the RustBPETokenizer interface for training, encoding, decoding, and serialization.""" """Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
import tempfile import tempfile
from nanochat.tokenizer import RustBPETokenizer from nanochat.tokenizer import RustBPETokenizer
# Simple train test # Simple train test