Previous MoE to_hf implementation

This commit is contained in:
Muheng 2026-01-01 12:03:33 +00:00
parent aef3893f44
commit 08f3255ff5
32 changed files with 722 additions and 4625 deletions

5
.gitignore vendored
View File

@@ -15,4 +15,7 @@ hf-export/**/*.pkl
hf-export/**/__pycache__/
*.log
.cache/
agent.md
agent.md
lm_eval_sample_*/
*.pt
hf-export/**/*

View File

@@ -1,32 +0,0 @@
{
"architectures": [
"NanoChatHFForCausalLM"
],
"auto_map": {
"AutoConfig": "configuration_nanochat.NanoChatHFConfig",
"AutoModelForCausalLM": "modeling_nanochat.NanoChatHFForCausalLM",
"AutoTokenizer": [
"tokenization_nanochat.NanoChatTokenizer",
null
]
},
"dtype": "float32",
"model_type": "nanochat",
"n_embd": 1280,
"n_head": 10,
"n_kv_head": 10,
"n_layer": 20,
"num_attention_heads": 10,
"num_hidden_layers": 20,
"num_key_value_heads": 10,
"hidden_size": 1280,
"max_position_embeddings": 2048,
"sequence_len": 2048,
"tie_word_embeddings": false,
"tokenizer_class": "NanoChatTokenizer",
"bos_token_id": 65527,
"eos_token_id": 65531,
"pad_token_id": 65527,
"transformers_version": "4.57.3",
"vocab_size": 65536
}

View File

@@ -1,20 +0,0 @@
from transformers import PretrainedConfig
class NanoChatHFConfig(PretrainedConfig):
model_type = "nanochat"
def __init__(self, sequence_len=1024, vocab_size=50304, n_layer=12, n_head=6, n_kv_head=6, n_embd=768, **kwargs):
kwargs.setdefault("tie_word_embeddings", False)
super().__init__(**kwargs)
self.sequence_len = sequence_len
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_kv_head = n_kv_head
self.n_embd = n_embd
# HF compatibility aliases
self.num_hidden_layers = n_layer
self.hidden_size = n_embd
self.num_attention_heads = n_head
self.num_key_value_heads = n_kv_head
self.max_position_embeddings = sequence_len
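
A minimal usage sketch (not part of this commit) of the config class above, using the d20 values from the deleted config.json; it assumes configuration_nanochat.py is importable from the export directory.

# Sketch: construct the config and check the HF-compatibility aliases it sets.
from configuration_nanochat import NanoChatHFConfig

cfg = NanoChatHFConfig(sequence_len=2048, vocab_size=65536, n_layer=20,
                       n_head=10, n_kv_head=10, n_embd=1280)
assert cfg.num_hidden_layers == cfg.n_layer == 20
assert cfg.hidden_size == cfg.n_embd == 1280
assert cfg.num_key_value_heads == cfg.n_kv_head == 10
assert cfg.max_position_embeddings == cfg.sequence_len == 2048
print(cfg.to_json_string())  # serializes to roughly the shape of the config.json above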

View File

@@ -1,4 +0,0 @@
{
"_from_model_config": true,
"transformers_version": "4.57.3"
}

View File

@@ -1,172 +0,0 @@
"""
Minimal GPT implementation for HF export (inference-only utilities).
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up the last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, _ = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0:
attn_mask[:, :prefix_len] = True
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Precompute rotary embeddings (small overhead, avoids realloc each forward)
self.rotary_seq_len = config.sequence_len * 10
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
def get_device(self):
return self.transformer.wte.weight.device
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
if device is None:
device = self.transformer.wte.weight.device
# When transformers initializes the model under `init_empty_weights`, parameters live on the
# meta device. Buffers created on meta cannot be moved with `.to()`. To keep HF/accelerate
# happy, build these buffers on CPU in that case.
if getattr(device, "type", None) == "meta":
device = torch.device("cpu")
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
t = torch.arange(seq_len, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16()
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
return cos, sin
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
B, T = idx.size()
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = norm(x)
softcap = 15
if targets is not None:
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap)
logits = logits.float()
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap)
return logits
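
A quick smoke-test sketch (not part of this commit) for the export-only GPT above: build a tiny config, forward random token ids on CPU, and check the logits shape. The config values are illustrative.

import torch
from gpt import GPT, GPTConfig  # gpt.py is the deleted file above

cfg = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2, n_head=2, n_kv_head=2, n_embd=32)
model = GPT(cfg).eval()
idx = torch.randint(0, cfg.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(idx)      # no targets and no kv_cache -> returns softcapped logits
print(logits.shape)          # expected: torch.Size([1, 16, 256])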

View File

@@ -1,48 +0,0 @@
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_nanochat import NanoChatHFConfig
from .gpt import GPT, GPTConfig
class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
config_class = NanoChatHFConfig
def __init__(self, config: NanoChatHFConfig):
super().__init__(config)
gpt_cfg = GPTConfig(
sequence_len=config.sequence_len,
vocab_size=config.vocab_size,
n_layer=config.n_layer,
n_head=config.n_head,
n_kv_head=config.n_kv_head,
n_embd=config.n_embd,
)
self.model = GPT(gpt_cfg)
def get_input_embeddings(self):
return self.model.transformer.wte
def set_input_embeddings(self, value):
self.model.transformer.wte = value
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
return
def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
if input_ids is None:
raise ValueError("input_ids must be provided")
logits = self.model(input_ids)
loss = None
if labels is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}

View File

@@ -1,76 +0,0 @@
"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
"""
import torch
import torch.distributed as dist
from torch import Tensor
class DistAdamW(torch.optim.Optimizer):
"""
Distributed AdamW optimizer.
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_scatter_futures: list[torch.Future] = []
all_reduce_futures: list[torch.Future] = []
grad_slices = []
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for base in range(len(params)):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
t = state['step']
# weight decay
if wd != 0:
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
p_slice.mul_(1 - eff_weight_decay)
# update running averages
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
# bias corrections
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
# compute step
denom = exp_avg_sq.sqrt().add_(eps)
step_size = lr * (torch.sqrt(bias2) / bias1)
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
idx += 1
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
torch.futures.collect_all(all_reduce_futures).wait()
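
A usage sketch of the optimizer above (assumptions: launched with torchrun, NCCL available, and every parameter's dim-0 divisible by the world size, which step() requires). The import path is assumed.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from nanochat.adamw import DistAdamW  # module path assumed for the file above

dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)

model = nn.Linear(1024, 1024, bias=False).to(device)  # dim-0 divisible by world size
opt = DistAdamW([{"params": list(model.parameters())}], lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(8, 1024, device=device)).square().mean()
loss.backward()
opt.step()                    # reduce-scatter grads, update local shard, all-gather params
opt.zero_grad(set_to_none=True)
dist.destroy_process_group()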

View File

@@ -1,151 +0,0 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import torch
from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging
# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state parameters
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
logger.info(f"Saved model parameters to: {model_path}")
# Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Saved metadata to: {meta_path}")
# Note that optimizer state is sharded across ranks, so each rank must save its own.
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "r", encoding="utf-8") as f:
meta_data = json.load(f)
return model_data, optimizer_data, meta_data
def build_model(checkpoint_dir, step, device, phase):
"""
A bunch of repetitive code to build a model from a given checkpoint.
Returns:
- base model - uncompiled, not wrapped in DDP
- tokenizer
- meta data saved during base model training
"""
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type in {"cpu", "mps"}:
# Convert bfloat16 tensors to float for CPU inference
model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v
for k, v in model_data.items()
}
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
log0(f"Building model with config: {model_config_kwargs}")
model_config = GPTConfig(**model_config_kwargs)
with torch.device("meta"):
model = GPT(model_config)
# Load the model state
model.to_empty(device=device)
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.load_state_dict(model_data, strict=True, assign=True)
# Put the model in the right training phase / mode
if phase == "eval":
model.eval()
else:
model.train()
# Load the Tokenizer
tokenizer = get_tokenizer()
# Sanity check: compatibility between model and tokenizer
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
return model, tokenizer, meta_data
def find_largest_model(checkpoint_dir):
# attempt to guess the model tag: take the biggest model available
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
if not model_tags:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
# 1) normally all model tags are of the form d<number>, try that first:
candidates = []
for model_tag in model_tags:
match = re.match(r"d(\d+)", model_tag)
if match:
model_depth = int(match.group(1))
candidates.append((model_depth, model_tag))
if candidates:
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates[0][1]
# 2) if that failed, take the most recently updated model:
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
return model_tags[0]
def find_last_step(checkpoint_dir):
# Look into checkpoint_dir and find model_<step>.pt with the highest step
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
return last_step
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
if model_tag is None:
# guess the model tag by defaulting to the largest model
model_tag = find_largest_model(checkpoints_dir)
log0(f"No model tag provided, guessing model tag: {model_tag}")
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
if step is None:
# guess the step by defaulting to the last step
step = find_last_step(checkpoint_dir)
assert step is not None, f"No checkpoints found in {checkpoint_dir}"
# build the model
log0(f"Loading model from {checkpoint_dir} with step {step}")
model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
model_dir = {
"base": "base_checkpoints",
"mid": "mid_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
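
A convenience-path sketch (assumes checkpoints exist under ~/.cache/nanochat/base_checkpoints and that this module is importable as nanochat.checkpoint_manager): load_model guesses the largest d<depth> tag and the latest step, and returns an eval-mode model plus tokenizer.

import torch
from nanochat.checkpoint_manager import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, meta = load_model("base", device, phase="eval")  # largest tag, last step
print(meta["model_config"])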

View File

@@ -1,190 +0,0 @@
"""
Common utilities for nanochat.
"""
import os
import re
import logging
import urllib.request
import torch
import torch.distributed as dist
from filelock import FileLock
class ColoredFormatter(logging.Formatter):
"""Custom formatter that adds colors to log messages."""
# ANSI color codes
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
}
RESET = '\033[0m'
BOLD = '\033[1m'
def format(self, record):
# Add color to the level name
levelname = record.levelname
if levelname in self.COLORS:
record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
# Format the message
message = super().format(record)
# Add color to specific parts of the message
if levelname == 'INFO':
# Highlight numbers and percentages
message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
return message
def setup_default_logging():
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logging.basicConfig(
level=logging.INFO,
handlers=[handler]
)
setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir():
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"):
nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
else:
home_dir = os.path.expanduser("~")
cache_dir = os.path.join(home_dir, ".cache")
nanochat_dir = os.path.join(cache_dir, "nanochat")
os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir
def download_file_with_lock(url, filename, postprocess_fn=None):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
if os.path.exists(file_path):
return file_path
with FileLock(lock_path):
# Only a single rank can acquire this lock
# All other ranks block until it is released
# Recheck after acquiring lock
if os.path.exists(file_path):
return file_path
# Download the content as bytes
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read() # bytes
# Write to local file
with open(file_path, 'wb') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Run the postprocess function if provided
if postprocess_fn is not None:
postprocess_fn(file_path)
return file_path
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
def print_banner():
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
banner = """
"""
print0(banner)
def is_ddp():
# TODO is there a proper way
return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
if is_ddp():
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
return True, ddp_rank, ddp_local_rank, ddp_world_size
else:
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
# The only place where global rng might be used is nn.Module initialization of the model weights.
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
"""Companion function to compute_init, to clean things up before script exit"""
if is_ddp():
dist.destroy_process_group()
class DummyWandb:
"""Useful if we wish to not use wandb but have all the same signatures"""
def __init__(self):
pass
def log(self, *args, **kwargs):
pass
def finish(self):
pass
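
A sketch of the bootstrap sequence a training script typically follows with these helpers (all names are defined above; the module path nanochat.common is assumed).

from nanochat.common import (autodetect_device_type, compute_init, compute_cleanup,
                             print0, DummyWandb)

device_type = autodetect_device_type()
ddp, rank, local_rank, world_size, device = compute_init(device_type)
wandb_run = DummyWandb()  # stand-in when wandb logging is disabled
print0(f"running on {device} with world size {world_size}")
# ... training loop ...
wandb_run.finish()
compute_cleanup()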

View File

@@ -1,56 +0,0 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32
The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())
So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()
I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""
import os
import sys
from ast import literal_eval
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
for arg in sys.argv[1:]:
if '=' not in arg:
# assume it's the name of a config file
assert not arg.startswith('--')
config_file = arg
print0(f"Overriding config with {config_file}:")
with open(config_file) as f:
print0(f.read())
exec(open(config_file).read())
else:
# assume it's a --key=value argument
assert arg.startswith('--')
key, val = arg.split('=')
key = key[2:]
if key in globals():
try:
# attempt to eval it (e.g. if it's a bool, number, etc.)
attempt = literal_eval(val)
except (SyntaxError, ValueError):
# if that goes wrong, just use the string
attempt = val
# ensure the types match ok
if globals()[key] is not None:
attempt_type = type(attempt)
default_type = type(globals()[key])
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
# cross fingers
print0(f"Overriding: {key} = {attempt}")
globals()[key] = attempt
else:
raise ValueError(f"Unknown config key: {key}")
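
A sketch of the calling convention described in the docstring above: the training script defines its defaults as module-level globals and then execs this file so a config file and/or --key=value args can override them. The variable names below are illustrative.

# in e.g. train.py
batch_size = 32                          # defaults as plain globals
learning_rate = 3e-4
exec(open('configurator.py').read())     # apply overrides from argv
print0(f"effective config: batch_size={batch_size}, learning_rate={learning_rate}")

# invoked as e.g.:
#   $ python train.py config/override_file.py --batch_size=8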

View File

@@ -1,262 +0,0 @@
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794
TODOs:
- All tasks roughly match except for SQuAD: we get 31% while the reference is 37%. Figure out why.
"""
import random
from jinja2 import Template
import torch
import torch.distributed as dist
# -----------------------------------------------------------------------------
# Prompt rendering utilities
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a multiple choice question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
return prompts
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a schema question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(context=context_option, **context)
for context_option in item['context_options']]
return prompts
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
"""
Render complete prompt for a language modeling task.
Notice that we manually trim the context in the template,
which in some datasets seems to have trailing whitespace (which we don't want).
"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
# Return two prompts: without and with the continuation
prompt_without = template.render(include_continuation=False, **context)
prompt_with = template.render(include_continuation=True, **context)
# Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
# Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
# token in prompt_with), meaning we don't get a nice and clean prefix in the token space
# to detect the final continuation. Tokenizers...
prompt_without = prompt_without.strip()
return [prompt_without, prompt_with]
def find_common_length(token_sequences, direction='left'):
"""
Find the length of the common prefix or suffix across token sequences
- direction: 'left' for prefix, 'right' for suffix
"""
min_len = min(len(seq) for seq in token_sequences)
indices = {
'left': range(min_len),
'right': range(-1, -min_len-1, -1)
}[direction]
# Find the first position where the token sequences differ
for i, idx in enumerate(indices):
token = token_sequences[0][idx]
if not all(seq[idx] == token for seq in token_sequences):
return i
return min_len
def stack_sequences(tokens, pad_token_id):
"""Stack up a list of token sequences, pad to longest on the right"""
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
for i, x in enumerate(tokens):
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
return input_ids
def batch_sequences_mc(tokenizer, prompts):
# In multiple choice, contexts are the same but the continuation is different (common prefix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each continuation
answer_start_idx = find_common_length(tokens, direction='left')
start_indices = [answer_start_idx] * len(prompts)
end_indices = [len(x) for x in tokens]
return tokens, start_indices, end_indices
def batch_sequences_schema(tokenizer, prompts):
# In schema tasks, contexts vary but continuation is the same (common suffix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each context
suffix_length = find_common_length(tokens, direction='right')
end_indices = [len(x) for x in tokens]
start_indices = [ei - suffix_length for ei in end_indices]
return tokens, start_indices, end_indices
def batch_sequences_lm(tokenizer, prompts):
# In LM tasks, we have two prompts: without and with continuation
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
tokens_without, tokens_with = tokens
start_idx, end_idx = len(tokens_without), len(tokens_with)
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
return [tokens_with], [start_idx], [end_idx]
@torch.no_grad()
def forward_model(model, input_ids):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
The last column of losses is set to nan because we don't have autoregressive targets there.
"""
batch_size, seq_len = input_ids.size()
outputs = model(input_ids)
# Roll the tensor to the left by one position to get the (autoregressive) target ids
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
return losses, predictions
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
"""Evaluate a single example, return True if correct, False otherwise"""
item = data[idx]
task_type = task_meta['task_type']
num_fewshot = task_meta['num_fewshot']
continuation_delimiter = task_meta['continuation_delimiter']
# Sample few-shot examples (excluding current item)
fewshot_examples = []
if num_fewshot > 0:
rng = random.Random(1234 + idx)
available_indices = [i for i in range(len(data)) if i != idx]
fewshot_indices = rng.sample(available_indices, num_fewshot)
fewshot_examples = [data[i] for i in fewshot_indices]
# Render prompts and batch sequences based on task type
if task_type == 'multiple_choice':
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
elif task_type == 'schema':
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
elif task_type == 'language_modeling':
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
else:
raise ValueError(f"Unsupported task type: {task_type}")
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
# In these cases, we have to truncate sequences to max length and adjust the indices
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
max_tokens = model.max_seq_len
new_tokens, new_start_idxs, new_end_idxs = [], [], []
for t, s, e in zip(tokens, start_idxs, end_idxs):
if len(t) > max_tokens:
num_to_crop = len(t) - max_tokens
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
new_end_idxs.append(e - num_to_crop)
assert s - num_to_crop >= 0, "this should never happen right?"
assert e - num_to_crop >= 0, "this should never happen right?"
else:
new_tokens.append(t) # keep unchanged
new_start_idxs.append(s)
new_end_idxs.append(e)
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
# Stack up all the sequences into a batch
pad_token_id = tokenizer.get_bos_token_id() # using BOS as the pad token is fine here
input_ids = stack_sequences(tokens, pad_token_id)
input_ids = input_ids.to(device)
# Forward the model, get the autoregressive loss and argmax prediction at each token
losses, predictions = forward_model(model, input_ids)
# See if the losses/predictions come out correctly
if task_type == 'language_modeling':
# language modeling task is currently always batch size 1
si = start_idxs[0]
ei = end_idxs[0]
# predictions[i] predict input_ids[i+1] autoregressively
predicted_tokens = predictions[0, si-1:ei-1]
actual_tokens = input_ids[0, si:ei]
is_correct = torch.all(predicted_tokens == actual_tokens).item()
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [losses[i, si-1:ei-1].mean().item()
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
pred_idx = mean_losses.index(min(mean_losses))
is_correct = pred_idx == item['gold']
else:
raise ValueError(f"Unsupported task type: {task_type}")
return is_correct
def evaluate_task(model, tokenizer, data, device, task_meta):
"""
This function is responsible for evaluating one task across many examples.
It also handles dispatch to all processes if the script is run with torchrun.
"""
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# stride the examples to each rank
for idx in range(rank, len(data), world_size):
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
correct[idx] = float(is_correct)
# sync results across all the processes if running distributed
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
# compute the mean
mean_correct = correct.mean().item()
return mean_correct
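
A hedged sketch of driving evaluate_task for one multiple-choice task: task_meta carries the keys read in evaluate_example, and each data item follows the fields the MC template above reads (query, choices, gold). Model/tokenizer loading and the module paths are assumptions; the data items are illustrative only.

from nanochat.common import autodetect_device_type, compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task  # module path assumed for this file

ddp, rank, local_rank, world_size, device = compute_init(autodetect_device_type())
model, tokenizer, meta = load_model("base", device, phase="eval")
task_meta = {"task_type": "multiple_choice", "num_fewshot": 0, "continuation_delimiter": " "}
data = [
    {"query": "The capital of France is", "choices": ["Paris", "Rome"], "gold": 0},
    {"query": "2 + 2 equals", "choices": ["3", "4"], "gold": 1},
]
print(f"accuracy: {evaluate_task(model, tokenizer, data, device, task_meta):.4f}")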

View File

@@ -1,87 +0,0 @@
from collections import deque
import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation became a bit more complex because we wish to support approximate resume training.
Instead of turning this into a Class, we opt to return the state_dict with every batch,
and then the caller can pass in a state_dict to resume training from a desired point.
Note that this resumption is atm only *approximate* for simplicity.
We won't repeat the same documents but we might skip a few.
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token
tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets
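
A usage sketch of the loader above: pull batches, keep the most recent state_dict, and later pass it back as resume_state_dict for approximate resumption. Assumes pretraining shards and a trained tokenizer are already on disk and a CUDA device is available; the module path is assumed.

from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

loader = tokenizing_distributed_data_loader_with_state(B=8, T=1024, split="train", device="cuda")
last_state = None
for step, (inputs, targets, state_dict) in zip(range(100), loader):
    last_state = state_dict            # {"pq_idx": ..., "rg_idx": ...}
    # ... forward/backward/optimizer step ...

# later: resume close to (but not exactly at) the same point
resumed = tokenizing_distributed_data_loader_with_state(
    B=8, T=1024, split="train", device="cuda", resume_state_dict=last_state)
inputs, targets, state_dict = next(resumed)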

View File

@@ -1,128 +0,0 @@
"""
The base/pretraining dataset is a set of parquet files.
This file contains utilities for:
- iterating over the parquet files and yielding documents from it
- download the files on demand if they are not on disk
For details of how the dataset was prepared, see `repackage_data_reference.py`.
"""
import os
import argparse
import time
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool
from nanochat.common import get_base_dir
# -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset
# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported
def list_parquet_files(data_dir=None):
""" Looks into a data dir and returns full paths to all parquet files. """
data_dir = DATA_DIR if data_dir is None else data_dir
parquet_files = sorted([
f for f in os.listdir(data_dir)
if f.endswith('.parquet') and not f.endswith('.tmp')
])
parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
return parquet_paths
def parquets_iter_batched(split, start=0, step=1):
"""
Iterate through the dataset, in batches of underlying row_groups for efficiency.
- split can be "train" or "val". the last parquet file will be val.
- start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
for filepath in parquet_paths:
pf = pq.ParquetFile(filepath)
for rg_idx in range(start, pf.num_row_groups, step):
rg = pf.read_row_group(rg_idx)
texts = rg.column('text').to_pylist()
yield texts
# -----------------------------------------------------------------------------
def download_single_file(index):
""" Downloads a single file index, with some backoff """
# Construct the local filepath for this file and skip if it already exists
filename = index_to_filename(index)
filepath = os.path.join(DATA_DIR, filename)
if os.path.exists(filepath):
print(f"Skipping {filepath} (already exists)")
return True
# Construct the remote URL for this file
url = f"{BASE_URL}/{filename}"
print(f"Downloading {filename}...")
# Download with retries
max_attempts = 5
for attempt in range(1, max_attempts + 1):
try:
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
with open(temp_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
if chunk:
f.write(chunk)
# Move temp file to final location
os.rename(temp_path, filepath)
print(f"Successfully downloaded {filename}")
return True
except (requests.RequestException, IOError) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files
for path in [filepath + f".tmp", filepath]:
if os.path.exists(path):
try:
os.remove(path)
except:
pass
# Try a few times with exponential backoff: 2^attempt seconds
if attempt < max_attempts:
wait_time = 2 ** attempt
print(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
else:
print(f"Failed to download {filename} after {max_attempts} attempts")
return False
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
args = parser.parse_args()
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num))
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
print(f"Target directory: {DATA_DIR}")
print()
with Pool(processes=args.num_workers) as pool:
results = pool.map(download_single_file, ids_to_download)
# Report results
successful = sum(1 for success in results if success)
print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")

View File

@@ -1,378 +0,0 @@
"""
Engine for efficient inference of our models.
Everything works around token sequences:
- The user can send token sequences to the engine
- The engine returns the next token
Notes:
- The engine knows nothing about tokenization, it's purely token id sequences.
The whole thing is made as efficient as possible.
"""
import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@contextmanager
def timeout(duration, formula):
def timeout_handler(signum, frame):
raise Exception(f"'{formula}': timed out after {duration} seconds")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(duration)
yield
signal.alarm(0)
def eval_with_timeout(formula, max_time=3):
try:
with timeout(max_time, formula):
with warnings.catch_warnings():
warnings.simplefilter("ignore", SyntaxWarning)
return eval(formula, {"__builtins__": {}}, {})
except Exception as e:
signal.alarm(0)
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
return None
def use_calculator(expr):
"""
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "")
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr)
# -----------------------------------------------------------------------------
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
def reset(self):
self.pos = 0
def get_pos(self):
return self.pos
def prefill(self, other):
"""
Prefill given another KV cache. Optionally expand along batch dim.
This is used when we do batch 1 prefill and then want to generate
multiple samples in parallel from there.
"""
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
# ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
if ix in [0, 1, 3, 5]:
# num_layers, k/v, num_heads, head_dim must match
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
elif ix == 2:
# batch_size can be expanded
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
elif ix == 4:
# seq_len: self must be longer than other
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
"""Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
assert temperature >= 0.0, "temperature must be non-negative"
if temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
if top_k is not None:
k = min(top_k, logits.size(-1))
vals, idx = torch.topk(logits, k, dim=-1)
vals = vals / temperature
probs = F.softmax(vals, dim=-1)
choice = torch.multinomial(probs, num_samples=1, generator=rng)
return idx.gather(1, choice)
else:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=rng)
# -----------------------------------------------------------------------------
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.completed = False # Whether this row has completed generation
class Engine:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer # needed for tool use
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
rng = torch.Generator(device=device)
rng.manual_seed(seed)
# Get the special tokens we need to coordinate the tool use state machine
get_special = lambda s: self.tokenizer.encode_special(s)
python_start = get_special("<|python_start|>")
python_end = get_special("<|python_end|>")
output_start = get_special("<|output_start|>")
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: we've reached max tokens
if max_tokens is not None and num_generated >= max_tokens:
break
# Stop condition: all rows are completed
if all(state.completed for state in row_states):
break
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
for i, state in enumerate(row_states):
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update the state of this row to include the next token
state.current_tokens.append(next_token)
# On <|assistant_end|> or <|bos|>, mark the row as completed
if next_token == assistant_end or next_token == bos:
state.completed = True
# Handle tool logic
if next_token == python_start:
state.in_python_block = True
state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
state.in_python_block = False
if state.python_expr_tokens:
expr = self.tokenizer.decode(state.python_expr_tokens)
result = use_calculator(expr)
if result is not None:
result_tokens = self.tokenizer.encode(str(result))
state.forced_tokens.append(output_start)
state.forced_tokens.extend(result_tokens)
state.forced_tokens.append(output_end)
state.python_expr_tokens = []
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
# Yield the token column
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""
Non-streaming batch generation that just returns the final token sequences.
Returns a list of token sequences (list of lists of ints).
Terminal tokens (assistant_end, bos) are not included in the results.
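Example (illustrative sketch; assumes `engine` is an Engine and `tokenizer` its tokenizer):
    prompt = tokenizer.encode("Hello", prepend=tokenizer.get_bos_token_id())
    sequences, masks = engine.generate_batch(prompt, num_samples=4, max_tokens=64, temperature=0.8)
    texts = [tokenizer.decode(seq) for seq in sequences]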
"""
assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
bos = self.tokenizer.get_bos_token_id()
results = [tokens.copy() for _ in range(num_samples)]
masks = [[0] * len(tokens) for _ in range(num_samples)]
completed = [False] * num_samples
for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
for i, (token, mask) in enumerate(zip(token_column, token_masks)):
if not completed[i]:
if token == assistant_end or token == bos:
completed[i] = True
else:
results[i].append(token)
masks[i].append(mask)
# Stop if all rows are completed
if all(completed):
break
return results, masks
if __name__ == "__main__":
"""
Quick inline test to make sure that the naive/slow model.generate function
is equivalent to the faster Engine.generate function here.
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id()
# common hyperparameters
kwargs = dict(max_tokens=64, temperature=0.0)
# set the starting prompt
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
# generate the reference sequence using the model.generate() function
generated_tokens = []
torch.cuda.synchronize()
t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs)
with autocast_ctx:
for token in stream:
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
t1 = time.time()
print(f"Reference time: {t1 - t0:.2f}s")
reference_ids = generated_tokens
# generate tokens with Engine
generated_tokens = []
engine = Engine(model, tokenizer)
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
t0 = time.time()
with autocast_ctx:
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
generated_tokens.append(token)
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
t1 = time.time()
print(f"Engine time: {t1 - t0:.2f}s")
# compare the two sequences
for i in range(len(reference_ids)):
if reference_ids[i] != generated_tokens[i]:
print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
break
print(f"Match: {reference_ids == generated_tokens}")

View File

@ -1,349 +0,0 @@
"""
Sandboxed execution utilities for running Python code that comes out of an LLM.
Adapted from OpenAI HumanEval code:
https://github.com/openai/human-eval/blob/master/human_eval/execution.py
What is covered:
- Each execution runs in its own process (can be killed if it hangs or crashes)
- Execution is limited by a timeout to stop infinite loops
- Memory limits are enforced by default (256MB)
- stdout and stderr are captured and returned
- Code runs in a temporary directory that is deleted afterwards
- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
What is not covered:
- Not a true security sandbox
- Network access is not blocked (e.g. sockets could be opened)
- Python's dynamic features (e.g. ctypes) could bypass restrictions
- No kernel-level isolation (no seccomp, no containers, no virtualization)
Overall this sandbox is good for evaluation of generated code and protects against
accidental destructive behavior, but it is not safe against malicious adversarial code.
"""
import contextlib
import faulthandler
import io
import multiprocessing
import os
import platform
import signal
import tempfile
from dataclasses import dataclass
from typing import Optional
# -----------------------------------------------------------------------------
@dataclass
class ExecutionResult:
"""Result of executing Python code in a sandbox."""
success: bool
stdout: str
stderr: str
error: Optional[str] = None
timeout: bool = False
memory_exceeded: bool = False
def __repr__(self):
parts = []
parts.append(f"ExecutionResult(success={self.success}")
if self.timeout:
parts.append(", timeout=True")
if self.memory_exceeded:
parts.append(", memory_exceeded=True")
if self.error:
parts.append(f", error={self.error!r}")
if self.stdout:
parts.append(f", stdout={self.stdout!r}")
if self.stderr:
parts.append(f", stderr={self.stderr!r}")
parts.append(")")
return "".join(parts)
@contextlib.contextmanager
def time_limit(seconds: float):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
@contextlib.contextmanager
def capture_io():
"""Capture stdout and stderr, and disable stdin."""
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
stdin_block = WriteOnlyStringIO()
with contextlib.redirect_stdout(stdout_capture):
with contextlib.redirect_stderr(stderr_capture):
with redirect_stdin(stdin_block):
yield stdout_capture, stderr_capture
@contextlib.contextmanager
def create_tempdir():
with tempfile.TemporaryDirectory() as dirname:
with chdir(dirname):
yield dirname
class TimeoutException(Exception):
pass
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""
def read(self, *args, **kwargs):
raise IOError
def readline(self, *args, **kwargs):
raise IOError
def readlines(self, *args, **kwargs):
raise IOError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = "stdin"
@contextlib.contextmanager
def chdir(root):
if root == ".":
yield
return
cwd = os.getcwd()
os.chdir(root)
try:
yield
finally:
os.chdir(cwd)
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
WARNING
This function is NOT a security sandbox. Untrusted code, including model-
generated code, should not be blindly executed outside of one. See the
Codex paper for more information about OpenAI's code sandbox, and proceed
with caution.
"""
if maximum_memory_bytes is not None and platform.uname().system != "Darwin":
# These resource limit calls need a concrete byte limit and seem to fail on macOS (Darwin), so skip in those cases
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
import builtins
builtins.exit = None
builtins.quit = None
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.kill = None
os.system = None
os.putenv = None
os.remove = None
os.removedirs = None
os.rmdir = None
os.fchdir = None
os.setuid = None
os.fork = None
os.forkpty = None
os.killpg = None
os.rename = None
os.renames = None
os.truncate = None
os.replace = None
os.unlink = None
os.fchmod = None
os.fchown = None
os.chmod = None
os.chown = None
os.chroot = None
os.fchdir = None
os.lchflags = None
os.lchmod = None
os.lchown = None
os.getcwd = None
os.chdir = None
import shutil
shutil.rmtree = None
shutil.move = None
shutil.chown = None
import subprocess
subprocess.Popen = None # type: ignore
__builtins__["help"] = None
import sys
sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
unlink = os.unlink
# Disable functionalities that can make destructive changes to the test.
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
# Default to failure
result_dict.update({
"success": False,
"stdout": "",
"stderr": "",
"timeout": False,
"memory_exceeded": False,
"error": None,
})
try:
exec_globals = {}
with capture_io() as (stdout_capture, stderr_capture):
with time_limit(timeout):
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network. For more
# information on how OpenAI sandboxes its code, see the accompanying paper.
# The exec call below is intentionally enabled here; make sure you have read the disclaimer above
# and taken appropriate precautions before running untrusted, model-generated code:
exec(code, exec_globals)
result_dict.update({
"success": True,
"stdout": stdout_capture.getvalue(),
"stderr": stderr_capture.getvalue(),
})
except TimeoutException:
result_dict.update({
"timeout": True,
"error": "Execution timed out",
})
except MemoryError as e:
result_dict.update({
"memory_exceeded": True,
"error": f"Memory limit exceeded: {e}",
})
except BaseException as e:
result_dict.update({
"error": f"{type(e).__name__}: {e}",
})
# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
os.unlink = unlink
def execute_code(
code: str,
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
"""
Execute Python code in a sandboxed environment.
Args:
code: Python code to execute as a string
timeout: Maximum execution time in seconds (default: 5.0)
maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
Returns:
ExecutionResult with success status, stdout/stderr, and error information
Example:
>>> result = execute_code("print('hello world')")
>>> result.success
True
>>> result.stdout
'hello world\\n'
"""
manager = multiprocessing.Manager()
result_dict = manager.dict()
p = multiprocessing.Process(
target=_unsafe_execute,
args=(code, timeout, maximum_memory_bytes, result_dict)
)
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
return ExecutionResult(
success=False,
stdout="",
stderr="",
error="Execution timed out (process killed)",
timeout=True,
memory_exceeded=False,
)
if not result_dict:
return ExecutionResult(
success=False,
stdout="",
stderr="",
error="Execution failed (no result returned)",
timeout=True,
memory_exceeded=False,
)
return ExecutionResult(
success=result_dict["success"],
stdout=result_dict["stdout"],
stderr=result_dict["stderr"],
error=result_dict["error"],
timeout=result_dict["timeout"],
memory_exceeded=result_dict["memory_exceeded"],
)
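# Illustrative usage (sketch, not part of the original module): a runaway loop is killed by the timeout guard:
#   result = execute_code("while True: pass", timeout=1.0)
#   assert result.timeout and not result.success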

View File

@ -1,307 +0,0 @@
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""
import math
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
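# e.g. with prefix_len=3 cached positions and Tq=2 new queries (so Tk=5), the mask is
#   [[1, 1, 1, 1, 0],
#    [1, 1, 1, 1, 1]]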
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, and assert-fail if we ever exceed that length.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
self.apply(self._init_weights)
# zero out classifier weights
torch.nn.init.zeros_(self.lm_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
# init the rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
# https://arxiv.org/pdf/2310.17813
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
# TODO: bump the base theta, e.g. 100K has become more common recently
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# autodetect the device from model embeddings
if device is None:
device = self.transformer.wte.weight.device
# stride the channels
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
# stride the time steps
t = torch.arange(seq_len, dtype=torch.float32, device=device)
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def get_device(self):
return self.transformer.wte.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
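# Rough worked example (added note): for the default config (n_layer=12, n_head=6, n_embd=768,
# sequence_len=1024) the non-embedding params come to ~124M, so this is approximately
# 6*124e6 + 12*12*6*128*1024 ≈ 7.4e8 + 1.1e8 ≈ 8.5e8 FLOPs per token.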
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
MuonFactory = DistMuon if ddp else Muon
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine the two optimizers into one list
optimizers = [adamw_optimizer, muon_optimizer]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
return optimizers
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)
softcap = 15
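# the tanh softcap below squashes the logits smoothly into (-softcap, +softcap) in both branches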
if targets is not None:
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
"""
Naive autoregressive streaming inference.
To make it super simple, let's assume:
- batch size is 1
- ids and the yielded tokens are simple Python lists and ints
"""
assert isinstance(tokens, list)
device = self.get_device()
rng = None
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
else:
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
ids = torch.cat((ids, next_ids), dim=1)
token = next_ids.item()
yield token
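# Illustrative usage of the naive generate() above (sketch, mirroring the inline test in the engine):
#   ids = tokenizer.encode("The chemical formula of water is", prepend=tokenizer.get_bos_token_id())
#   for tok in model.generate(ids, max_tokens=64, temperature=0.0):
#       print(tokenizer.decode([tok]), end="", flush=True)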

View File

@ -1,8 +0,0 @@
<svg width="400" height="400" xmlns="http://www.w3.org/2000/svg">
<defs>
<radialGradient id="g" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#667eea;stop-opacity:1"></stop>
<stop offset="100%" style="stop-color:#764ba2;stop-opacity:0.3"></stop>
</radialGradient>
</defs>
<path d="M315.9110991546882,231.0582854123025 L228.97777478867204,207.7645713530756 M284.8528137423857,284.8528137423857 L221.21320343559643,221.21320343559643 M231.0582854123025,315.9110991546882 L207.7645713530756,228.97777478867204 M168.94171458769753,315.9110991546882 L192.2354286469244,228.97777478867204 M115.14718625761431,284.8528137423857 L178.78679656440357,221.21320343559643 M84.08890084531181,231.05828541230252 L171.02222521132796,207.76457135307564 M84.0889008453118,168.9417145876975 L171.02222521132796,192.2354286469244 M115.14718625761425,115.14718625761435 L178.78679656440357,178.78679656440357 M168.94171458769753,84.0889008453118 L192.2354286469244,171.02222521132796 M231.05828541230244,84.08890084531178 L207.7645713530756,171.02222521132794 M284.8528137423857,115.14718625761428 L221.21320343559643,178.78679656440357 M315.91109915468815,168.94171458769742 L228.97777478867204,192.23542864692436 " stroke="url(#g)" stroke-width="6" stroke-linecap="round" fill="none"></path><path d="M200,-12 L212,0 L200,12 L188,0 Z" transform="translate(0,200)" fill="#000"></path></svg>


View File

@ -1,65 +0,0 @@
"""
A number of functions that help with evaluating a base model.
"""
import math
import torch
import torch.distributed as dist
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
"""
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
which is a tokenization vocab size-independent metric, meaning you are still comparing
apples:apples if you change the vocab size. The way this works is that instead of just
calculating the average loss as usual, you calculate the sum loss, and independently
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
the number of bytes that the target tokens represent.
The added complexity is so that:
1) All "normal" tokens are normalized by the length of the token in bytes
2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
In addition to evaluate_loss, we need the token_bytes tensor:
It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
each token id, or 0 if the token is to not be counted (e.g. special tokens).
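Worked example: if the summed loss over the counted targets is 1200 nats and those targets
decode to 800 bytes, then bpb = 1200 / (ln(2) * 800) ≈ 2.16 bits per byte.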
"""
# record the losses
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
y_safe = torch.where(valid, y, torch.zeros_like(y))
# map valid targets to their byte length; ignored targets contribute 0 bytes
num_bytes2d = torch.where(
valid,
token_bytes[y_safe],
torch.zeros_like(y, dtype=token_bytes.dtype)
)
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
else:
# fast path: no ignored targets, safe to index directly
num_bytes2d = token_bytes[y]
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
# sum reduce across all ranks
world_size = dist.get_world_size() if dist.is_initialized() else 1
if world_size > 1:
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
# move both to cpu, calculate bpb and return
total_nats = total_nats.item()
total_bytes = total_bytes.item()
if total_bytes == 0:
return float('inf')
bpb = total_nats / (math.log(2) * total_bytes)
return bpb

View File

@ -1,187 +0,0 @@
"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
"""
import torch
from torch import Tensor
import torch.distributed as dist
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
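# Quick sanity sketch (illustrative, not part of the training path):
#   G = torch.randn(256, 128)
#   X = zeropower_via_newtonschulz5(G, steps=5)
#   # X keeps (roughly) G's singular vectors but its singular values land in ~[0.5, 1.5], per the docstring above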
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
param_groups = []
for size in {p.numel() for p in params}:
group = dict(params=[p for p in params if p.numel() == size])
param_groups.append(group)
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
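# Illustrative usage (sketch): Muon is meant only for the 2D matrix params; embeddings and the lm_head
# are handled by AdamW elsewhere (see GPT.setup_optimizers):
#   matrix_params = list(model.transformer.h.parameters())
#   opt = Muon(matrix_params, lr=0.02, momentum=0.95)
#   loss.backward(); opt.step()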
class DistMuon(torch.optim.Optimizer):
"""
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via NewtonSchulz,
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
- reduce_scatter(AVG) for gradient averaging
- all_gather to replicate updated weights
Notes:
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
params like embeddings or scalars.
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
consolidate states beforehand.
Args:
params: iterable of Tensors
lr: learning rate
momentum: momentum coefficient in [0,1)
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
if rank == 0:
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank
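# e.g. with world_size=4 and 6 params in this group: in the first pass ranks 0..3 own params 0..3;
# in the second pass rank 0 owns param 4, rank 1 owns param 5, and ranks 2,3 fall off the end,
# contributing/receiving only the zero buffer padding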
# each rank stacks up its chunk of world_size params into a list
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
# reduce scatter the gradients within this group of world_size params
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
all_reduce_futures.append(work)
# Now each rank computes the update and gathers
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
p = params[owner_idx]
g = p.grad # now averaged across ranks
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)
# Wait for all work to finish
torch.futures.collect_all(all_gather_futures).wait()

View File

@ -1,408 +0,0 @@
"""
Utilities for generating training report cards. More messy code than usual, will fix.
"""
import os
import re
import shutil
import subprocess
import socket
import datetime
import platform
import psutil
import torch
def run_command(cmd):
"""Run a shell command and return output, or None if it fails."""
try:
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return result.stdout.strip()
return None
except Exception:
return None
def get_git_info():
"""Get current git commit, branch, and dirty status."""
info = {}
info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
# Check if repo is dirty (has uncommitted changes)
status = run_command("git status --porcelain")
info['dirty'] = bool(status) if status is not None else False
# Get commit message
info['message'] = run_command("git log -1 --pretty=%B") or ""
info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
return info
def get_gpu_info():
"""Get GPU information."""
if not torch.cuda.is_available():
return {"available": False}
num_devices = torch.cuda.device_count()
info = {
"available": True,
"count": num_devices,
"names": [],
"memory_gb": []
}
for i in range(num_devices):
props = torch.cuda.get_device_properties(i)
info["names"].append(props.name)
info["memory_gb"].append(props.total_memory / (1024**3))
# Get CUDA version
info["cuda_version"] = torch.version.cuda or "unknown"
return info
def get_system_info():
"""Get system information."""
info = {}
# Basic system info
info['hostname'] = socket.gethostname()
info['platform'] = platform.system()
info['python_version'] = platform.python_version()
info['torch_version'] = torch.__version__
# CPU and memory
info['cpu_count'] = psutil.cpu_count(logical=False)
info['cpu_count_logical'] = psutil.cpu_count(logical=True)
info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
# User and environment
info['user'] = os.environ.get('USER', 'unknown')
info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
info['working_dir'] = os.getcwd()
return info
def estimate_cost(gpu_info, runtime_hours=None):
"""Estimate training cost based on GPU type and runtime."""
# Rough pricing, from Lambda Cloud
default_rate = 2.0
gpu_hourly_rates = {
"H100": 3.00,
"A100": 1.79,
"V100": 0.55,
}
if not gpu_info.get("available"):
return None
# Try to identify GPU type from name
hourly_rate = None
gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
for gpu_type, rate in gpu_hourly_rates.items():
if gpu_type in gpu_name:
hourly_rate = rate * gpu_info["count"]
break
if hourly_rate is None:
hourly_rate = default_rate * gpu_info["count"] # Default estimate
return {
"hourly_rate": hourly_rate,
"gpu_type": gpu_name,
"estimated_total": hourly_rate * runtime_hours if runtime_hours else None
}
def generate_header():
"""Generate the header for a training report."""
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
git_info = get_git_info()
gpu_info = get_gpu_info()
sys_info = get_system_info()
cost_info = estimate_cost(gpu_info)
header = f"""# nanochat training report
Generated: {timestamp}
## Environment
### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}
### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""
if gpu_info.get("available"):
gpu_names = ", ".join(set(gpu_info["names"]))
total_vram = sum(gpu_info["memory_gb"])
header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
else:
header += "- GPUs: None available\n"
if cost_info and cost_info["hourly_rate"] > 0:
header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}
"""
# bloat metrics: package all of the source code and assess its weight
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml') or ""  # guard against files-to-prompt being unavailable
num_chars = len(packaged)
num_lines = len(packaged.split('\n'))
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
num_tokens = num_chars // 4 # assume approximately 4 chars per token
# count dependencies via uv.lock
uv_lock_lines = 0
if os.path.exists('uv.lock'):
with open('uv.lock', 'r', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines())
header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}
"""
return header
# -----------------------------------------------------------------------------
def slugify(text):
"""Slugify a text string."""
return text.lower().replace(" ", "-")
# the expected files and their order
EXPECTED_FILES = [
"tokenizer-training.md",
"tokenizer-evaluation.md",
"base-model-training.md",
"base-model-loss.md",
"base-model-evaluation.md",
"midtraining.md",
"chat-evaluation-mid.md",
"chat-sft.md",
"chat-evaluation-sft.md",
"chat-rl.md",
"chat-evaluation-rl.md",
]
# the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
def extract(section, keys):
"""simple def to extract a single key from a section"""
if not isinstance(keys, list):
keys = [keys] # convenience
out = {}
for line in section.split("\n"):
for key in keys:
if key in line:
out[key] = line.split(":")[1].strip()
return out
def extract_timestamp(content, prefix):
"""Extract timestamp from content with given prefix."""
for line in content.split('\n'):
if line.startswith(prefix):
time_str = line.split(":", 1)[1].strip()
try:
return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
except ValueError:
pass
return None
class Report:
"""Maintains a bunch of logs, generates a final markdown report."""
def __init__(self, report_dir):
os.makedirs(report_dir, exist_ok=True)
self.report_dir = report_dir
def log(self, section, data):
"""Log a section of data to the report."""
slug = slugify(section)
file_name = f"{slug}.md"
file_path = os.path.join(self.report_dir, file_name)
with open(file_path, "w", encoding="utf-8") as f:
f.write(f"## {section}\n")
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
for item in data:
if not item:
# skip falsy values like None or empty dict etc.
continue
if isinstance(item, str):
# directly write the string
f.write(item)
else:
# render a dict
for k, v in item.items():
if isinstance(v, float):
vstr = f"{v:.4f}"
elif isinstance(v, int) and v >= 10000:
vstr = f"{v:,.0f}"
else:
vstr = str(v)
f.write(f"- {k}: {vstr}\n")
f.write("\n")
return file_path
def generate(self):
"""Generate the final report."""
report_dir = self.report_dir
report_file = os.path.join(report_dir, "report.md")
print(f"Generating report to {report_file}")
final_metrics = {} # the most important final metrics we'll add as table at the end
start_time = None
end_time = None
with open(report_file, "w", encoding="utf-8") as out_file:
# write the header first
header_file = os.path.join(report_dir, "header.md")
if os.path.exists(header_file):
with open(header_file, "r", encoding="utf-8") as f:
header_content = f.read()
out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:")
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)
if not os.path.exists(section_file):
print(f"Warning: {section_file} does not exist, skipping")
continue
with open(section_file, "r", encoding="utf-8") as in_file:
section = in_file.read()
# Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
if "rl" not in file_name:
# Skip RL sections for end_time calculation because RL is experimental
end_time = extract_timestamp(section, "timestamp:")
# extract the most important metrics from the sections
if file_name == "base-model-evaluation.md":
final_metrics["base"] = extract(section, "CORE")
if file_name == "chat-evaluation-mid.md":
final_metrics["mid"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-sft.md":
final_metrics["sft"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-rl.md":
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
# append this section of the report
out_file.write(section)
out_file.write("\n")
# add the final metrics table
out_file.write("## Summary\n\n")
# Copy over the bloat metrics from the header
out_file.write(bloat_data)
out_file.write("\n\n")
# Collect all unique metric names
all_metrics = set()
for stage_metrics in final_metrics.values():
all_metrics.update(stage_metrics.keys())
# Custom ordering: CORE first, ChatCORE last, rest in middle
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
# Fixed column widths
stages = ["base", "mid", "sft", "rl"]
metric_width = 15
value_width = 8
# Write table header
header = f"| {'Metric'.ljust(metric_width)} |"
for stage in stages:
header += f" {stage.upper().ljust(value_width)} |"
out_file.write(header + "\n")
# Write separator
separator = f"|{'-' * (metric_width + 2)}|"
for stage in stages:
separator += f"{'-' * (value_width + 2)}|"
out_file.write(separator + "\n")
# Write table rows
for metric in all_metrics:
row = f"| {metric.ljust(metric_width)} |"
for stage in stages:
value = final_metrics.get(stage, {}).get(metric, "-")
row += f" {str(value).ljust(value_width)} |"
out_file.write(row + "\n")
out_file.write("\n")
# Calculate and write total wall clock time
if start_time and end_time:
duration = end_time - start_time
total_seconds = int(duration.total_seconds())
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60
out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
else:
out_file.write("Total wall clock time: unknown\n")
# also cp the report.md file to current directory
print(f"Copying report.md to current directory for convenience")
shutil.copy(report_file, "report.md")
return report_file
def reset(self):
"""Reset the report."""
# Remove section files
for file_name in EXPECTED_FILES:
file_path = os.path.join(self.report_dir, file_name)
if os.path.exists(file_path):
os.remove(file_path)
# Remove report.md if it exists
report_file = os.path.join(self.report_dir, "report.md")
if os.path.exists(report_file):
os.remove(report_file)
# Generate and write the header section with start timestamp
header_file = os.path.join(self.report_dir, "header.md")
header = generate_header()
start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(header_file, "w", encoding="utf-8") as f:
f.write(header)
f.write(f"Run started: {start_time}\n\n---\n\n")
print(f"Reset report and wrote header to {header_file}")
# -----------------------------------------------------------------------------
# nanochat-specific convenience functions
class DummyReport:
def log(self, *args, **kwargs):
pass
def reset(self, *args, **kwargs):
pass
def get_report():
# just for convenience, only rank 0 logs to report
from nanochat.common import get_base_dir, get_dist_info
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp_rank == 0:
report_dir = os.path.join(get_base_dir(), "report")
return Report(report_dir)
else:
return DummyReport()
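# Illustrative usage (sketch; the section name and values are made up):
#   report = get_report()
#   report.log("Base model training", [
#       {"num_iterations": 21400, "final_loss": 2.3456},
#       "free-form markdown can also be appended as a plain string\n",
#   ])
#   # `generate` later stitches header.md and the per-section files into report.md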
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
args = parser.parse_args()
if args.command == "generate":
get_report().generate()
elif args.command == "reset":
get_report().reset()

View File

@ -1,238 +0,0 @@
"""
Convert a nanochat checkpoint into a HuggingFace-style folder.
Usage (example):
python -m nanochat.to_hf --source base --output hf-export/base
Notes
- Assumes checkpoints live under ~/.cache/nanochat/<source>_checkpoints/ (same as training scripts).
- The exported model can be loaded with transformers via:
AutoModelForCausalLM.from_pretrained(<export_dir>, trust_remote_code=True)
- KV cache is not implemented in the HF wrapper; generation works but is not incremental.
"""
import argparse
import os
import shutil
from typing import Optional
import torch
import torch.nn.functional as F
try:
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
except ImportError as exc:
raise SystemExit(
"transformers is required for HF export. Run `uv sync` (with the hf extra) first."
) from exc
from nanochat.checkpoint_manager import load_model
from nanochat.gpt import GPT, GPTConfig
from nanochat.common import get_base_dir
from nanochat.tokenizer import get_tokenizer
class NanoChatHFConfig(PretrainedConfig):
model_type = "nanochat"
def __init__(
self,
sequence_len: int = 1024,
vocab_size: int = 50304,
n_layer: int = 12,
n_head: int = 6,
n_kv_head: int = 6,
n_embd: int = 768,
**kwargs,
):
# Don't tie embeddings; nanochat uses untied wte/lm_head
kwargs.setdefault("tie_word_embeddings", False)
super().__init__(**kwargs)
self.sequence_len = sequence_len
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_kv_head = n_kv_head
self.n_embd = n_embd
class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
config_class = NanoChatHFConfig
def __init__(self, config: NanoChatHFConfig):
super().__init__(config)
gpt_cfg = GPTConfig(
sequence_len=config.sequence_len,
vocab_size=config.vocab_size,
n_layer=config.n_layer,
n_head=config.n_head,
n_kv_head=config.n_kv_head,
n_embd=config.n_embd,
)
self.model = GPT(gpt_cfg)
def get_input_embeddings(self):
return self.model.transformer.wte
def set_input_embeddings(self, value):
self.model.transformer.wte = value
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
# nanochat uses untied embeddings; override to no-op
return
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # unused
labels: Optional[torch.LongTensor] = None,
past_key_values=None, # not implemented
**_: dict,
) -> CausalLMOutputWithPast:
if input_ids is None:
raise ValueError("input_ids must be provided")
logits = self.model(input_ids)
loss = None
if labels is not None:
# Plain CE over flattened logits/labels; note no causal shift is applied here (labels are used as-is)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
def copy_tokenizer_files(output_dir: str):
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
if not os.path.isdir(tokenizer_dir):
print(f"[to_hf] tokenizer directory not found at {tokenizer_dir}, skipping tokenizer export")
return
for name in os.listdir(tokenizer_dir):
src = os.path.join(tokenizer_dir, name)
dst = os.path.join(output_dir, name)
if os.path.isdir(src):
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy2(src, dst)
print(f"[to_hf] Copied tokenizer files from {tokenizer_dir} to {output_dir}")
def write_hf_code(output_dir: str):
cfg_py = r'''
from transformers import PretrainedConfig
class NanoChatHFConfig(PretrainedConfig):
model_type = "nanochat"
def __init__(self, sequence_len=1024, vocab_size=50304, n_layer=12, n_head=6, n_kv_head=6, n_embd=768, **kwargs):
kwargs.setdefault("tie_word_embeddings", False)
super().__init__(**kwargs)
self.sequence_len = sequence_len
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_kv_head = n_kv_head
self.n_embd = n_embd
'''
mdl_py = r'''
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from configuration_nanochat import NanoChatHFConfig
from nanochat.gpt import GPT, GPTConfig
class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
config_class = NanoChatHFConfig
def __init__(self, config: NanoChatHFConfig):
super().__init__(config)
gpt_cfg = GPTConfig(
sequence_len=config.sequence_len,
vocab_size=config.vocab_size,
n_layer=config.n_layer,
n_head=config.n_head,
n_kv_head=config.n_kv_head,
n_embd=config.n_embd,
)
self.model = GPT(gpt_cfg)
def get_input_embeddings(self):
return self.model.transformer.wte
def set_input_embeddings(self, value):
self.model.transformer.wte = value
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
return
def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
if input_ids is None:
raise ValueError("input_ids must be provided")
logits = self.model(input_ids)
loss = None
if labels is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
'''
with open(os.path.join(output_dir, "configuration_nanochat.py"), "w") as f:
f.write(cfg_py)
with open(os.path.join(output_dir, "modeling_nanochat.py"), "w") as f:
f.write(mdl_py)
def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: Optional[int]):
device = torch.device("cpu")
model, tokenizer, meta = load_model(source, device=device, phase="eval", model_tag=model_tag, step=step)
cfg_kwargs = meta["model_config"]
hf_config = NanoChatHFConfig(**cfg_kwargs)
hf_model = NanoChatHFForCausalLM(hf_config)
hf_model.model.load_state_dict(model.state_dict(), strict=True)
os.makedirs(output_dir, exist_ok=True)
# Tell transformers how to load custom code when trust_remote_code=True
hf_model.config.auto_map = {
"AutoConfig": "configuration_nanochat.NanoChatHFConfig",
"AutoModelForCausalLM": "modeling_nanochat.NanoChatHFForCausalLM",
}
hf_model.config.architectures = ["NanoChatHFForCausalLM"]
hf_model.save_pretrained(output_dir, safe_serialization=False)
# Best effort: drop tokenizer files alongside weights
copy_tokenizer_files(output_dir)
print(f"[to_hf] Exported {source} checkpoint to {output_dir}")
write_hf_code(output_dir)
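# Illustrative round-trip (sketch): after exporting, the folder can be loaded back as noted in the module docstring:
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("hf-export/base", trust_remote_code=True)
# Note the wrapper keeps no KV cache, so each generation step re-runs the full prefix.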
def main():
parser = argparse.ArgumentParser(description="Export nanochat checkpoint to HuggingFace format")
parser.add_argument("--source", choices=["base", "mid", "sft", "rl"], default="base", help="Which checkpoint family to export")
parser.add_argument("--model-tag", type=str, default=None, help="Model tag (e.g., d20). Defaults to largest available.")
parser.add_argument("--step", type=int, default=None, help="Checkpoint step. Defaults to latest.")
parser.add_argument("--output", type=str, default="hf-export", help="Output directory for HF files")
args = parser.parse_args()
export_to_hf(args.source, args.output, args.model_tag, args.step)
if __name__ == "__main__":
main()

View File

@ -1,398 +0,0 @@
"""
BPE Tokenizer in the style of GPT-4.
Two implementations are available:
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""
import os
import copy
from functools import lru_cache
SPECIAL_TOKENS = [
# every document begins with the Beginning of Sequence (BOS) token that delimits documents
"<|bos|>",
# tokens below are only used during finetuning to render Conversations into token ids
"<|user_start|>", # user messages
"<|user_end|>",
"<|assistant_start|>", # assistant messages
"<|assistant_end|>",
"<|python_start|>", # assistant invokes python REPL tool
"<|python_end|>",
"<|output_start|>", # python REPL outputs back to assistant
"<|output_end|>",
]
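# A finetuning conversation is rendered with these tokens roughly as (illustrative only; see the chat
# rendering code for the exact format):
#   <|bos|><|user_start|>Hi<|user_end|><|assistant_start|>Hello!<|assistant_end|>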
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
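# Illustration (added for this writeup, not part of the original module): the third-party `regex`
# package understands the \p{...} classes and possessive quantifiers used above, so the effect of
# the \p{N}{1,2} change can be checked directly:
#
#   import regex
#   regex.findall(SPLIT_PATTERN, "year 2024")   # expected: ['year', ' ', '20', '24']
#
# i.e. digit runs are chunked into groups of at most two before BPE ever sees them.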
# -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@classmethod
def from_pretrained(cls, hf_path):
# init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
tokenizer = HFTokenizer.from_pretrained(hf_path)
return cls(tokenizer)
@classmethod
def from_directory(cls, tokenizer_dir):
# init from a local directory on disk (e.g. "out/tokenizer")
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
tokenizer = HFTokenizer.from_file(tokenizer_path)
return cls(tokenizer)
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
# Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
# Normalizer: None
tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style
# the regex pattern used by GPT-4 to split text into groups before BPE
# NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
# (but I haven't validated this! TODO)
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None
tokenizer.post_processor = None
# Trainer: BPE
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=SPECIAL_TOKENS,
)
# Kick off the training
tokenizer.train_from_iterator(text_iterator, trainer)
return cls(tokenizer)
def get_vocab_size(self):
return self.tokenizer.get_vocab_size()
def get_special_tokens(self):
special_tokens_map = self.tokenizer.get_added_tokens_decoder()
special_tokens = [w.content for w in special_tokens_map.values()]
return special_tokens
def id_to_token(self, id):
return self.tokenizer.id_to_token(id)
def _encode_one(self, text, prepend=None, append=None):
# encode a single string
# prepend/append can be either a string of a special token or a token id directly.
assert isinstance(text, str)
ids = []
if prepend is not None:
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
ids.append(prepend_id)
ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
if append is not None:
append_id = append if isinstance(append, int) else self.encode_special(append)
ids.append(append_id)
return ids
def encode_special(self, text):
# encode a single special token via exact match
return self.tokenizer.token_to_id(text)
def get_bos_token_id(self):
bos = self.encode_special("<|bos|>")
return bos
def encode(self, text, *args, **kwargs):
if isinstance(text, str):
return self._encode_one(text, *args, **kwargs)
elif isinstance(text, list):
return [self._encode_one(t, *args, **kwargs) for t in text]
else:
raise ValueError(f"Invalid input type: {type(text)}")
def __call__(self, *args, **kwargs):
return self.encode(*args, **kwargs)
def decode(self, ids):
return self.tokenizer.decode(ids, skip_special_tokens=False)
def save(self, tokenizer_dir):
# save the tokenizer to disk
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
self.tokenizer.save(tokenizer_path)
print(f"Saved tokenizer to {tokenizer_path}")
# -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo
import pickle
import rustbpe
import tiktoken
class RustBPETokenizer:
"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
def __init__(self, enc, bos_token):
self.enc = enc
self.bos_token_id = self.encode_special(bos_token)
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
# 1) train using rustbpe
tokenizer = rustbpe.Tokenizer()
# the special tokens are inserted later in __init__, we don't train them here
vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
# 2) construct the associated tiktoken encoding for inference
pattern = tokenizer.get_pattern()
mergeable_ranks_list = tokenizer.get_mergeable_ranks()
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
tokens_offset = len(mergeable_ranks)
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
enc = tiktoken.Encoding(
name="rustbpe",
pat_str=pattern,
mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
)
return cls(enc, "<|bos|>")
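# Note (illustrative, added for this writeup): because the special tokens are appended after the
# merge table, their ids sit at the very top of the vocab. For example, with vocab_size=65536 and
# the 9 SPECIAL_TOKENS above, mergeable ranks occupy ids 0..65526, "<|bos|>" lands at 65527 and
# "<|assistant_end|>" at 65531.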
@classmethod
def from_directory(cls, tokenizer_dir):
pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
with open(pickle_path, "rb") as f:
enc = pickle.load(f)
return cls(enc, "<|bos|>")
@classmethod
def from_pretrained(cls, tiktoken_name):
# https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
enc = tiktoken.get_encoding(tiktoken_name)
# tiktoken calls the special document delimiter token "<|endoftext|>"
# yes this is confusing because this token is almost always PREPENDED to the beginning of the document
# it most often is used to signal the start of a new sequence to the LLM during inference etc.
# so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
return cls(enc, "<|endoftext|>")
def get_vocab_size(self):
return self.enc.n_vocab
def get_special_tokens(self):
return self.enc.special_tokens_set
def id_to_token(self, id):
return self.enc.decode([id])
@lru_cache(maxsize=32)
def encode_special(self, text):
return self.enc.encode_single_token(text)
def get_bos_token_id(self):
return self.bos_token_id
def encode(self, text, prepend=None, append=None, num_threads=8):
# text can be either a string or a list of strings
if prepend is not None:
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
if append is not None:
append_id = append if isinstance(append, int) else self.encode_special(append)
if isinstance(text, str):
ids = self.enc.encode_ordinary(text)
if prepend is not None:
ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
if append is not None:
ids.append(append_id)
elif isinstance(text, list):
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
if prepend is not None:
for ids_row in ids:
ids_row.insert(0, prepend_id) # TODO: same
if append is not None:
for ids_row in ids:
ids_row.append(append_id)
else:
raise ValueError(f"Invalid input type: {type(text)}")
return ids
def __call__(self, *args, **kwargs):
return self.encode(*args, **kwargs)
def decode(self, ids):
return self.enc.decode(ids)
def save(self, tokenizer_dir):
# save the encoding object to disk
os.makedirs(tokenizer_dir, exist_ok=True)
pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
with open(pickle_path, "wb") as f:
pickle.dump(self.enc, f)
print(f"Saved tokenizer encoding to {pickle_path}")
def render_conversation(self, conversation, max_tokens=2048):
"""
Tokenize a single Chat conversation (which we call a "doc" or "document" here).
Returns:
- ids: list[int] is a list of token ids of this rendered conversation
- mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
"""
# ids, masks that we will return and a helper function to help build them up.
ids, mask = [], []
def add_tokens(token_ids, mask_val):
if isinstance(token_ids, int):
token_ids = [token_ids]
ids.extend(token_ids)
mask.extend([mask_val] * len(token_ids))
# sometimes the first message is a system message...
# => just merge it with the second (user) message
if conversation["messages"][0]["role"] == "system":
# some conversation surgery is necessary here for now...
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert messages[1]["role"] == "user", "System message must be followed by a user message"
messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
messages = messages[1:]
else:
messages = conversation["messages"]
assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"
# fetch all the special tokens we need
bos = self.get_bos_token_id()
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
# now we can tokenize the conversation
add_tokens(bos, 0)
for i, message in enumerate(messages):
# some sanity checking here around assumptions, to prevent footguns
must_be_from = "user" if i % 2 == 0 else "assistant"
assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
# content can be either a simple string or a list of parts (e.g. containing tool calls)
content = message["content"]
if message["role"] == "user":
assert isinstance(content, str), "User messages are simply expected to be strings"
value_ids = self.encode(content)
add_tokens(user_start, 0)
add_tokens(value_ids, 0)
add_tokens(user_end, 0)
elif message["role"] == "assistant":
add_tokens(assistant_start, 0)
if isinstance(content, str):
# simple string => simply add the tokens
value_ids = self.encode(content)
add_tokens(value_ids, 1)
elif isinstance(content, list):
for part in content:
value_ids = self.encode(part["text"])
if part["type"] == "text":
# string part => simply add the tokens
add_tokens(value_ids, 1)
elif part["type"] == "python":
# python tool call => add the tokens inside <|python_start|> and <|python_end|>
add_tokens(python_start, 1)
add_tokens(value_ids, 1)
add_tokens(python_end, 1)
elif part["type"] == "python_output":
# python output => add the tokens inside <|output_start|> and <|output_end|>
# none of these tokens are supervised because the tokens come from Python at test time
add_tokens(output_start, 0)
add_tokens(value_ids, 0)
add_tokens(output_end, 0)
else:
raise ValueError(f"Unknown part type: {part['type']}")
else:
raise ValueError(f"Unknown content type: {type(content)}")
add_tokens(assistant_end, 1)
# truncate to max_tokens tokens MAX (helps prevent OOMs)
ids = ids[:max_tokens]
mask = mask[:max_tokens]
return ids, mask
def visualize_tokenization(self, ids, mask, with_token_id=False):
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0m'
GRAY = '\033[90m'
tokens = []
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
token_str = self.decode([token_id])
color = GREEN if mask_val == 1 else RED
tokens.append(f"{color}{token_str}{RESET}")
if with_token_id:
tokens.append(f"{GRAY}({token_id}){RESET}")
return '|'.join(tokens)
def render_for_completion(self, conversation):
"""
Used during Reinforcement Learning. In that setting, we want to
render the conversation priming the Assistant for a completion.
Unlike the Chat SFT case, we don't need to return the mask.
"""
# We have some surgery to do: we need to pop the last message (of the Assistant)
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
messages.pop() # remove the last message (of the Assistant) inplace
# Now tokenize the conversation
ids, mask = self.render_conversation(conversation)
# Finally, to prime the Assistant for a completion, append the Assistant start token
assistant_start = self.encode_special("<|assistant_start|>")
ids.append(assistant_start)
return ids
# -----------------------------------------------------------------------------
# nanochat-specific convenience functions
def get_tokenizer():
from nanochat.common import get_base_dir
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
# return HuggingFaceTokenizer.from_directory(tokenizer_dir)
return RustBPETokenizer.from_directory(tokenizer_dir)
def get_token_bytes(device="cpu"):
import torch
from nanochat.common import get_base_dir
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
with open(token_bytes_path, "rb") as f:
token_bytes = torch.load(f, map_location=device)
return token_bytes
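# Usage sketch (illustrative, added for this writeup; not part of the original module):
#
#   tok = get_tokenizer()
#   conv = {"messages": [
#       {"role": "user", "content": "What is 2 + 2?"},
#       {"role": "assistant", "content": "4"},
#   ]}
#   ids, mask = tok.render_conversation(conv)
#   print(tok.visualize_tokenization(ids, mask))   # green = supervised assistant tokens, red = context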

View File

@ -1,562 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
<title>NanoChat</title>
<link rel="icon" type="image/svg+xml" href="/logo.svg">
<style>
:root {
color-scheme: light;
}
* {
box-sizing: border-box;
}
body {
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
background-color: #ffffff;
color: #111827;
min-height: 100dvh;
margin: 0;
display: flex;
flex-direction: column;
}
.header {
background-color: #ffffff;
padding: 1.25rem 1.5rem;
}
.header-left {
display: flex;
align-items: center;
gap: 0.75rem;
}
.header-logo {
height: 32px;
width: auto;
}
.header h1 {
font-size: 1.25rem;
font-weight: 600;
margin: 0;
color: #111827;
}
.new-conversation-btn {
width: 32px;
height: 32px;
padding: 0;
border: 1px solid #e5e7eb;
border-radius: 0.5rem;
background-color: #ffffff;
color: #6b7280;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
.new-conversation-btn:hover {
background-color: #f3f4f6;
border-color: #d1d5db;
color: #374151;
}
.chat-container {
flex: 1;
overflow-y: auto;
background-color: #ffffff;
}
.chat-wrapper {
max-width: 48rem;
margin: 0 auto;
padding: 2rem 1.5rem 3rem;
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.message {
display: flex;
justify-content: flex-start;
margin-bottom: 0.5rem;
color: #0d0d0d;
}
.message.assistant {
justify-content: flex-start;
}
.message.user {
justify-content: flex-end;
}
.message-content {
white-space: pre-wrap;
line-height: 1.6;
max-width: 100%;
}
.message.assistant .message-content {
background: transparent;
border: none;
padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;
margin-left: -0.5rem;
transition: background-color 0.2s ease;
}
.message.assistant .message-content:hover {
background-color: #f9fafb;
}
.message.user .message-content {
background-color: #f3f4f6;
border-radius: 1.25rem;
padding: 0.8rem 1rem;
max-width: 65%;
cursor: pointer;
transition: background-color 0.2s ease;
}
.message.user .message-content:hover {
background-color: #e5e7eb;
}
.message.console .message-content {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
font-size: 0.875rem;
background-color: #fafafa;
padding: 0.75rem 1rem;
color: #374151;
max-width: 80%;
}
.input-container {
background-color: #ffffff;
padding: 1rem;
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
}
.input-wrapper {
max-width: 48rem;
margin: 0 auto;
display: flex;
gap: 0.75rem;
align-items: flex-end;
}
.chat-input {
flex: 1;
padding: 0.8rem 1rem;
border: 1px solid #d1d5db;
border-radius: 0.75rem;
background-color: #ffffff;
color: #111827;
font-size: 1rem;
line-height: 1.5;
resize: none;
outline: none;
min-height: 54px;
max-height: 200px;
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.chat-input::placeholder {
color: #9ca3af;
}
.chat-input:focus {
border-color: #2563eb;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
}
.send-button {
flex-shrink: 0;
padding: 0;
width: 54px;
height: 54px;
border: 1px solid #111827;
border-radius: 0.75rem;
background-color: #111827;
color: #ffffff;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
}
.send-button:hover:not(:disabled) {
background-color: #2563eb;
border-color: #2563eb;
}
.send-button:disabled {
cursor: not-allowed;
border-color: #d1d5db;
background-color: #e5e7eb;
color: #9ca3af;
}
.typing-indicator {
display: inline-block;
color: #6b7280;
letter-spacing: 0.15em;
}
.typing-indicator::after {
content: '···';
animation: typing 1.4s infinite;
}
@keyframes typing {
0%, 60%, 100% { opacity: 0.2; }
30% { opacity: 1; }
}
.error-message {
background-color: #fee2e2;
border: 1px solid #fecaca;
color: #b91c1c;
padding: 0.75rem 1rem;
border-radius: 0.75rem;
margin-top: 0.5rem;
}
</style>
</head>
<body>
<div class="header">
<div class="header-left">
<button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M12 5v14"></path>
<path d="M5 12h14"></path>
</svg>
</button>
<h1>nanochat</h1>
</div>
</div>
<div class="chat-container" id="chatContainer">
<div class="chat-wrapper" id="chatWrapper">
<!-- Messages will be added here -->
</div>
</div>
<div class="input-container">
<div class="input-wrapper">
<textarea
id="chatInput"
class="chat-input"
placeholder="Ask anything"
rows="1"
onkeydown="handleKeyDown(event)"
></textarea>
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M22 2L11 13"></path>
<path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
</svg>
</button>
</div>
</div>
<script>
const API_URL = '';
const chatContainer = document.getElementById('chatContainer');
const chatWrapper = document.getElementById('chatWrapper');
const chatInput = document.getElementById('chatInput');
const sendButton = document.getElementById('sendButton');
let messages = [];
let isGenerating = false;
let currentTemperature = 0.8;
let currentTopK = 50;
chatInput.addEventListener('input', function() {
this.style.height = 'auto';
this.style.height = Math.min(this.scrollHeight, 200) + 'px';
sendButton.disabled = !this.value.trim() || isGenerating;
});
function handleKeyDown(event) {
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
sendMessage();
}
}
document.addEventListener('keydown', function(event) {
// Ctrl+Shift+N for new conversation
if (event.ctrlKey && event.shiftKey && event.key === 'N') {
event.preventDefault();
if (!isGenerating) {
newConversation();
}
}
});
function newConversation() {
messages = [];
chatWrapper.innerHTML = '';
chatInput.value = '';
chatInput.style.height = 'auto';
sendButton.disabled = false;
isGenerating = false;
chatInput.focus();
}
function addMessage(role, content, messageIndex = null) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
const contentDiv = document.createElement('div');
contentDiv.className = 'message-content';
contentDiv.textContent = content;
// Add click handler for user messages to enable editing
if (role === 'user' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to edit and restart from here');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
editMessage(messageIndex);
}
});
}
// Add click handler for assistant messages to enable regeneration
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
});
}
messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv);
chatContainer.scrollTop = chatContainer.scrollHeight;
return contentDiv;
}
function editMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToEdit = messages[messageIndex];
if (messageToEdit.role !== 'user') return;
// Copy message content to input
chatInput.value = messageToEdit.content;
chatInput.style.height = 'auto';
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Enable send button and focus input
sendButton.disabled = false;
chatInput.focus();
}
async function generateAssistantResponse() {
isGenerating = true;
sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
try {
const response = await fetch(`${API_URL}/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages: messages,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
assistantContent.textContent = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
}
}
}
}
const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
});
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();
}
}
async function regenerateMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToRegenerate = messages[messageIndex];
if (messageToRegenerate.role !== 'assistant') return;
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Regenerate the assistant response
await generateAssistantResponse();
}
function handleSlashCommand(command) {
const parts = command.trim().split(/\s+/);
const cmd = parts[0].toLowerCase();
const arg = parts[1];
if (cmd === '/temperature') {
if (arg === undefined) {
addMessage('console', `Current temperature: ${currentTemperature}`);
} else {
const temp = parseFloat(arg);
if (isNaN(temp) || temp < 0 || temp > 2) {
addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
} else {
currentTemperature = temp;
addMessage('console', `Temperature set to ${currentTemperature}`);
}
}
return true;
} else if (cmd === '/topk') {
if (arg === undefined) {
addMessage('console', `Current top-k: ${currentTopK}`);
} else {
const topk = parseInt(arg);
if (isNaN(topk) || topk < 1 || topk > 200) {
addMessage('console', 'Invalid top-k. Must be between 1 and 200');
} else {
currentTopK = topk;
addMessage('console', `Top-k set to ${currentTopK}`);
}
}
return true;
} else if (cmd === '/clear') {
newConversation();
return true;
} else if (cmd === '/help') {
addMessage('console',
'Available commands:\n' +
'/temperature - Show current temperature\n' +
'/temperature <value> - Set temperature (0.0-2.0)\n' +
'/topk - Show current top-k\n' +
'/topk <value> - Set top-k (1-200)\n' +
'/clear - Clear conversation\n' +
'/help - Show this help message'
);
return true;
}
return false;
}
async function sendMessage() {
const message = chatInput.value.trim();
if (!message || isGenerating) return;
// Handle slash commands
if (message.startsWith('/')) {
chatInput.value = '';
chatInput.style.height = 'auto';
handleSlashCommand(message);
return;
}
chatInput.value = '';
chatInput.style.height = 'auto';
const userMessageIndex = messages.length;
messages.push({ role: 'user', content: message });
addMessage('user', message, userMessageIndex);
await generateAssistantResponse();
}
sendButton.disabled = false;
// Autofocus the chat input on page load
chatInput.focus();
fetch(`${API_URL}/health`)
.then(response => response.json())
.then(data => {
console.log('Engine status:', data);
})
.catch(error => {
console.error('Engine not available:', error);
chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
});
</script>
</body>
</html>

View File

@ -1,144 +0,0 @@
# hf-export/sft/tokenization_nanochat.py
import os
import pickle
from typing import List, Optional
from transformers import PreTrainedTokenizer
class NanoChatTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask"]
def __init__(self, tokenizer_file: str = "tokenizer.pkl", **kwargs):
# 1) Load enc first so the vocab size is ready (this must happen before calling super())
if tokenizer_file is None:
tokenizer_file = "tokenizer.pkl"
# Key: prefer the model directory passed in by HF over __file__ when locating the file.
# Common fields: pretrained_model_name_or_path / name_or_path
base_dir = (
kwargs.pop("pretrained_model_name_or_path", None)
or kwargs.get("name_or_path", None)
)
if base_dir and os.path.isdir(base_dir):
path = os.path.join(base_dir, tokenizer_file)
else:
# Fallback: the directory of this file (only correct when not loaded as a dynamic module)
path = os.path.join(os.path.dirname(__file__), tokenizer_file)
with open(path, "rb") as f:
self.enc = pickle.load(f)
self._vocab_size = int(self.enc.n_vocab)
# Avoid letting HF create new token ids for specials; we'll bind to existing ids from the pickle.
bos = "<|bos|>"
eos = "<|assistant_end|>"
# Drop potential duplicates coming from tokenizer_config.json to avoid double kwargs.
kwargs.pop("bos_token", None)
kwargs.pop("eos_token", None)
kwargs.pop("pad_token", None)
# Call parent without special tokens to prevent added_tokens growth.
super().__init__(bos_token=None, eos_token=None, pad_token=None, **kwargs)
# Reset any auto-added tokens (should be empty, but be safe)
self.added_tokens_encoder.clear()
self.added_tokens_decoder.clear()
# Bind specials to existing ids from the exported encoder
self._bos_token = bos
self._eos_token = eos
self._pad_token = bos
self._bos_token_id = self.enc._special_tokens[bos]
self._eos_token_id = self.enc._special_tokens[eos]
self._pad_token_id = self.enc._special_tokens[bos]
# HF commonly calls len(tokenizer)
def __len__(self):
return self._vocab_size
# Some code paths access tokenizer.vocab_size
@property
def vocab_size(self) -> int:
return self._vocab_size
# Override special token accessors to bind to existing ids (no new tokens added).
@property
def bos_token(self) -> str:
return "<|bos|>"
@property
def eos_token(self) -> str:
return "<|assistant_end|>"
@property
def pad_token(self) -> str:
return "<|bos|>"
@property
def bos_token_id(self) -> int:
return self.enc._special_tokens["<|bos|>"]
@property
def eos_token_id(self) -> int:
return self.enc._special_tokens["<|assistant_end|>"]
@property
def pad_token_id(self) -> int:
return self.enc._special_tokens["<|bos|>"]
def get_vocab(self):
# Note: don't use self.vocab_size here (it can trip a base-class __getattr__ ordering pitfall)
return {str(i): i for i in range(self._vocab_size)}
def _tokenize(self, text: str) -> List[str]:
# Allow special tokens like <|assistant_start|> to pass through with their ids.
ids = self.enc.encode(text, allowed_special=self.enc.special_tokens_set)
return [str(i) for i in ids]
def _convert_token_to_id(self, token: str) -> int:
if isinstance(token, str) and token in self.enc._special_tokens:
return self.enc._special_tokens[token]
return int(token)
def _convert_id_to_token(self, index: int) -> str:
# Translate special token ids back to their string form.
specials = {v: k for k, v in self.enc._special_tokens.items()}
if index in specials:
return specials[index]
return str(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
# Preserve token order instead of front-loading specials.
specials = {v: k for k, v in self.enc._special_tokens.items()}
parts: List[str] = []
pending_ids: List[int] = []
def flush_pending():
nonlocal pending_ids
if pending_ids:
parts.append(self.enc.decode(pending_ids))
pending_ids = []
for t in tokens:
# pass through known special token strings
if isinstance(t, str) and t in self.enc._special_tokens:
flush_pending()
parts.append(t)
continue
# or map special ids back to strings
try:
tid = int(t)
except (ValueError, TypeError):
continue
if tid in specials:
flush_pending()
parts.append(specials[tid])
else:
pending_ids.append(tid)
flush_pending()
return "".join(parts)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
):
return token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
return ()

View File

@ -1,6 +0,0 @@
{
"chat_template": "<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}<|user_start|>{{ message['content'] }}<|user_end|>{% elif message['role'] == 'assistant' %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}<|assistant_start|>",
"bos_token": "<|bos|>",
"eos_token": "<|assistant_end|>",
"pad_token": "<|bos|>"
}

View File

@ -8,15 +8,21 @@ source .venv/bin/activate
```
## 2) Export a trained checkpoint to HF format
- `nanochat/to_hf.py` loads the latest checkpoint from `~/.cache/nanochat/<source>_checkpoints` and writes an HF folder.
- Choose source: `base` | `mid` | `chatsft` | `chatrl`.
- `nanochat/to_hf.py` (dense) loads the latest checkpoint from `~/.cache/nanochat/<source>_checkpoints` and tokenizer from `~/.cache/nanochat/tokenizer/`, then writes an HF folder.
- `nanochat_moe/to_hf.py` (MoE) loads the latest checkpoint from `~/.cache/nanochat/<source>_checkpoints` and, by default, exports with the `gpt2` tiktoken tokenizer. Use `--tokenizer cache` if you want the cached rustbpe tokenizer from `~/.cache/nanochat/tokenizer/`.
- Choose source: `base` | `mid` | `chatsft` | `chatrl` (same for MoE; `n_layer/n_embd` etc. come from checkpoint metadata).
- A checkpoint directory looks like: `~/.cache/nanochat/<source>_checkpoints/<model_tag>/model_XXXXXX.pt` + `meta_XXXXXX.json` (optimizer shards optional, ignored for export). The exporter auto-picks the largest `model_tag` and latest step if you don't pass `--model-tag/--step`.
```bash
# export latest base checkpoint to hf-export/base
# export latest dense base checkpoint to hf-export/base
uv run python -m nanochat.to_hf --source base --output hf-export/base
# export latest SFT checkpoint (chat model)
# export latest MoE base checkpoint to hf-export/moe_gpt2 (gpt2 tokenizer)
uv run python -m nanochat_moe.to_hf --source base --model-tag d20 --step 1000 --output hf-export/moe_gpt2 --tokenizer gpt2
# export latest SFT checkpoint (chat model, dense)
uv run python -m nanochat.to_hf --source sft --output hf-export/sft
```
- An exported folder should contain (minimum): `config.json`, `pytorch_model.bin`, `tokenizer.pkl`, `tokenizer_config.json`, and the custom code files `configuration_nanochat.py`/`configuration_nanochat_moe.py`, `modeling_nanochat.py`/`modeling_nanochat_moe.py`, `tokenization_nanochat.py`, `gpt.py` (written for `trust_remote_code=True`).
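As a quick smoke test, an exported folder loads like any other custom-code HF model. A minimal sketch (paths follow the commands above; the SFT export is assumed so that the chat template is present):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

export_dir = "hf-export/sft"  # chat (SFT) export; the gpt2-tokenizer MoE export has no chat template
tok = AutoTokenizer.from_pretrained(export_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(export_dir, trust_remote_code=True)

# The chat template in tokenizer_config.json wraps messages in the nanochat special tokens
# and already ends with <|assistant_start|>, so no add_generation_prompt handling is needed.
messages = [{"role": "user", "content": "What is the capital of France?"}]
input_ids = tok.apply_chat_template(messages, return_tensors="pt")
out = model.generate(input_ids, max_new_tokens=64, do_sample=False)
print(tok.decode(out[0][input_ids.shape[1]:].tolist()))
```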
## 3) Run lm-eval benchmarks on the exported model
Use the HF backend (`--model hf`). Pick tasks; nanochat's built-in evals cover these, so they're good starters in lm-eval too:
@ -33,27 +39,62 @@ uv run lm-eval run --model hf \
--tasks mmlu \
--batch_size 1
# A small suite similar to nanochat chat_eval coverage (vanilla HF backend)
# HumanEval requires both flags below to allow executing generated code.
# arc_easy,arc_challenge,mmlu
HF_ALLOW_CODE_EVAL=1 uv run lm-eval run --confirm_run_unsafe_code --model hf \
--model_args pretrained=hf-export/sft,trust_remote_code=True \
--model_args pretrained=hf-export/pretrain,trust_remote_code=True \
--tasks arc_easy,arc_challenge,mmlu \
--batch_size 1 > log.log 2>&1
# Nanochat-aligned tool-use backend (matches nanochat eval formatting)
HF_ALLOW_CODE_EVAL=1 uv run lm-eval run \
# Nanochat-aligned tool-use backend
# gsm8k, humaneval
uv pip install -e tools/lm-eval
PYTHONPATH=tools/lm-eval HF_ALLOW_CODE_EVAL=1 uv run lm-eval run \
--include_path tools/lm-eval/lm_eval/tasks \
--confirm_run_unsafe_code \
--model hf-nanochat-tool \
--model_args pretrained=hf-export/sft,trust_remote_code=True,tokenizer=hf-export/sft \
--model hf-nanochat-no-tool \
--model_args pretrained=hf-export/std,trust_remote_code=True,tokenizer=hf-export/std \
--tasks gsm8k_nanochat,humaneval_nanochat \
--batch_size 1 \
--log_samples \
--output_path lm_eval_sample_nanochat > log.log 2>&1
--output_path lm_eval_sample_nanochat_notool > notool_std_gsm8k_humaneval.log 2>&1
PYTHONPATH=tools/lm-eval HF_ALLOW_CODE_EVAL=1 uv run lm-eval run \
--include_path tools/lm-eval/lm_eval/tasks \
--confirm_run_unsafe_code \
--model hf \
--model_args pretrained=hf-export/pretrain,trust_remote_code=True,tokenizer=hf-export/pretrain \
--tasks gsm8k,humaneval \
--batch_size 1 \
--log_samples \
--output_path lm_eval_sample_nanochat_test > test_pretrain_gsm8k_humaneval.log 2>&1
PYTHONPATH=tools/lm-eval HF_ALLOW_CODE_EVAL=1 uv run lm-eval run \
--include_path tools/lm-eval/lm_eval/tasks \
--confirm_run_unsafe_code \
--model hf-nanochat-no-tool \
--model_args pretrained=hf-export/std,trust_remote_code=True,tokenizer=hf-export/std \
--tasks gsm8k_nanochat,humaneval_nanochat \
--batch_size 1 \
--log_samples \
--limit 3 \
--output_path lm_eval_sample_nanochat_notool > notool_std_gsm8k_humaneval.log 2>&1
# for nanomoe-based models - exported with tokenizer = `gpt2`
PYTHONPATH=tools/lm-eval HF_ALLOW_CODE_EVAL=1 uv run lm-eval run \
--include_path tools/lm-eval/lm_eval/tasks \
--confirm_run_unsafe_code \
--model hf \
--model_args pretrained=hf-export/moe_gpt2,trust_remote_code=True,tokenizer=hf-export/moe_gpt2,max_length=1024 \
--gen_kwargs max_gen_toks=128 \
--tasks arc_easy,arc_challenge,mmlu,gsm8k,humaneval \
--batch_size 1 \
--log_samples \
--output_path lm_eval_sample_nanomoe > moe_gpt2.log 2>&1
```
Notes:
- If you exported to a different folder, change `pretrained=...` accordingly. You can also point to a remote HF repo name.
- If you must stay offline, add `HF_DATASETS_OFFLINE=1 HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1`, **but** ensure the datasets are already cached locally (e.g., `allenai/ai2_arc`, `openai_humaneval`, `gsm8k`, `cais/mmlu`). Otherwise, leave them unset so the harness can download once.
- `--batch_size auto` can help find the largest batch that fits GPU RAM. On CPU, keep it small.
- No KV cache is implemented in the HF wrapper; generation is standard `AutoModelForCausalLM` style. The `hf-nanochat-tool` wrapper runs a nanochat-style tool loop (greedy, batch=1) and does not need `--apply_chat_template` because the prompts already contain special tokens.
- No KV cache is implemented in the HF wrapper; generation is standard `AutoModelForCausalLM` style. The `hf-nanochat-tool` wrapper runs a nanochat-style tool loop (greedy, batch=1) and does not need `--apply_chat_template` because the prompts already contain special tokens. The `hf-nanochat-no-tool` wrapper uses the same greedy loop but does not execute tool-use blocks.
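For intuition, the greedy no-KV-cache decoding these wrappers rely on amounts to re-running the full prefix every step. A rough sketch (not the actual backend code):
```python
import torch

@torch.no_grad()
def greedy_complete(model, prompt_ids, max_new_tokens=256, stop_id=None):
    # No KV cache: every step re-feeds the whole sequence through the model.
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        x = torch.tensor([ids], dtype=torch.long, device=model.device)
        next_id = int(model(x).logits[0, -1].argmax())
        if stop_id is not None and next_id == stop_id:
            break
        ids.append(next_id)
    return ids[len(prompt_ids):]
```
For chat exports, `stop_id` would typically be the `<|assistant_end|>` id.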

View File

@ -12,7 +12,9 @@ Notes
"""
import argparse
import os
import json
import shutil
from pathlib import Path
from typing import Optional
import torch
@ -32,6 +34,14 @@ from nanochat.gpt import GPT, GPTConfig
from nanochat.common import get_base_dir
from nanochat.tokenizer import get_tokenizer
CHAT_TEMPLATE = (
"<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}"
"<|user_start|>{{ message['content'] }}<|user_end|>"
"{% elif message['role'] == 'assistant' %}"
"<|assistant_start|>{{ message['content'] }}<|assistant_end|>"
"{% endif %}{% endfor %}<|assistant_start|>"
)
class NanoChatHFConfig(PretrainedConfig):
model_type = "nanochat"
@ -137,335 +147,20 @@ def copy_tokenizer_files(output_dir: str):
print(f"[to_hf] Copied tokenizer files from {tokenizer_dir} to {output_dir}")
def write_hf_code(output_dir: str):
cfg_py = r'''
from transformers import PretrainedConfig
class NanoChatHFConfig(PretrainedConfig):
model_type = "nanochat"
def __init__(self, sequence_len=1024, vocab_size=50304, n_layer=12, n_head=6, n_kv_head=6, n_embd=768, **kwargs):
kwargs.setdefault("tie_word_embeddings", False)
super().__init__(**kwargs)
self.sequence_len = sequence_len
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_kv_head = n_kv_head
self.n_embd = n_embd
# HF compatibility aliases
self.num_hidden_layers = n_layer
self.hidden_size = n_embd
self.num_attention_heads = n_head
self.num_key_value_heads = n_kv_head
self.max_position_embeddings = sequence_len
'''
mdl_py = r'''
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_nanochat import NanoChatHFConfig
from .gpt import GPT, GPTConfig
class NanoChatHFForCausalLM(PreTrainedModel, GenerationMixin):
config_class = NanoChatHFConfig
def __init__(self, config: NanoChatHFConfig):
super().__init__(config)
gpt_cfg = GPTConfig(
sequence_len=config.sequence_len,
vocab_size=config.vocab_size,
n_layer=config.n_layer,
n_head=config.n_head,
n_kv_head=config.n_kv_head,
n_embd=config.n_embd,
)
self.model = GPT(gpt_cfg)
def get_input_embeddings(self):
return self.model.transformer.wte
def set_input_embeddings(self, value):
self.model.transformer.wte = value
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
return
def forward(self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, **_):
if input_ids is None:
raise ValueError("input_ids must be provided")
logits = self.model(input_ids)
loss = None
if labels is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
'''
gpt_py = r'''
"""
Minimal GPT implementation for HF export (inference-only utilities).
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, _ = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0:
attn_mask[:, :prefix_len] = True
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Precompute rotary embeddings (small overhead, avoids realloc each forward)
self.rotary_seq_len = config.sequence_len * 10
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
def get_device(self):
return self.transformer.wte.weight.device
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
if device is None:
device = self.transformer.wte.weight.device
# Avoid meta buffers when HF initializes under init_empty_weights; fall back to CPU.
if getattr(device, "type", None) == "meta":
device = torch.device("cpu")
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
t = torch.arange(seq_len, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16()
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
return cos, sin
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
B, T = idx.size()
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = norm(x)
softcap = 15
if targets is not None:
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap)
logits = logits.float()
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap)
return logits
'''
tok_py = r'''
import os
import pickle
from typing import List, Optional
from transformers import PreTrainedTokenizer
class NanoChatTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask"]
def __init__(self, tokenizer_file: str = "tokenizer.pkl", **kwargs):
if tokenizer_file is None:
tokenizer_file = "tokenizer.pkl"
base_dir = (
kwargs.pop("pretrained_model_name_or_path", None)
or kwargs.get("name_or_path", None)
)
if base_dir and os.path.isdir(base_dir):
path = os.path.join(base_dir, tokenizer_file)
else:
path = os.path.join(os.path.dirname(__file__), tokenizer_file)
with open(path, "rb") as f:
self.enc = pickle.load(f)
self._vocab_size = int(self.enc.n_vocab)
kwargs.setdefault("bos_token", "<|bos|>")
# Fallback: reuse bos as eos/pad to satisfy HF APIs; underlying vocab is index-only.
kwargs.setdefault("eos_token", kwargs["bos_token"])
kwargs.setdefault("pad_token", kwargs["bos_token"])
super().__init__(**kwargs)
def __len__(self):
return self._vocab_size
@property
def vocab_size(self) -> int:
return self._vocab_size
def get_vocab(self):
return {str(i): i for i in range(self._vocab_size)}
def _tokenize(self, text: str) -> List[str]:
ids = self.enc.encode_ordinary(text)
return [str(i) for i in ids]
def _convert_token_to_id(self, token: str) -> int:
if isinstance(token, str) and token.startswith("<|") and token.endswith("|>"):
return 0
return int(token)
def _convert_id_to_token(self, index: int) -> str:
return str(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
ids = []
for t in tokens:
try:
ids.append(int(t))
except (ValueError, TypeError):
# Skip special tokens like <|pad|> that are not part of the numeric vocab.
continue
return self.enc.decode(ids)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
):
return token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
return ()
'''
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}<|user_start|>{{ message['content'] }}<|user_end|>{% elif message['role'] == 'assistant' %}<|assistant_start|>{{ message['content'] }}<|assistant_end|>{% endif %}{% endfor %}<|assistant_start|>"
with open(os.path.join(output_dir, "configuration_nanochat.py"), "w") as f:
f.write(cfg_py)
with open(os.path.join(output_dir, "modeling_nanochat.py"), "w") as f:
f.write(mdl_py)
with open(os.path.join(output_dir, "gpt.py"), "w") as f:
f.write(gpt_py)
with open(os.path.join(output_dir, "tokenization_nanochat.py"), "w") as f:
f.write(tok_py)
# Minimal tokenizer_config for transformers (adds chat template and tokens).
import json
tok_cfg = {
"chat_template": chat_template
}
with open(os.path.join(output_dir, "tokenizer_config.json"), "w") as f:
json.dump(tok_cfg, f, indent=2)
"""
Copy the HF-compatible python modules into the export dir.
Prefer the checked-in template under hf-export/std; fall back to writing nothing
if not present (export will still work if trust_remote_code imports from local FS).
"""
template_dir = Path(__file__).resolve().parent.parent / "hf-export" / "std"
targets = ["configuration_nanochat.py", "modeling_nanochat.py", "tokenization_nanochat.py", "gpt.py"]
if template_dir.is_dir():
for name in targets:
src = template_dir / name
if src.exists():
shutil.copy2(src, Path(output_dir) / name)
else:
print(f"[to_hf] Template dir not found at {template_dir}, skipping python module copy")
def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: Optional[int]):
@ -477,7 +172,7 @@ def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: O
hf_model.model.load_state_dict(model.state_dict(), strict=True)
os.makedirs(output_dir, exist_ok=True)
# Tell transformers how to load custom code when trust_remote_code=True
hf_model.config.auto_map = {
"AutoConfig": "configuration_nanochat.NanoChatHFConfig",
"AutoModelForCausalLM": "modeling_nanochat.NanoChatHFForCausalLM",
@ -485,10 +180,25 @@ def export_to_hf(source: str, output_dir: str, model_tag: Optional[str], step: O
}
hf_model.config.tokenizer_class = "NanoChatTokenizer"
hf_model.config.architectures = ["NanoChatHFForCausalLM"]
# Bind special token ids/strings from tokenizer
bos_id = tokenizer.encode_special("<|bos|>")
eos_id = tokenizer.encode_special("<|assistant_end|>")
hf_model.config.bos_token_id = bos_id
hf_model.config.eos_token_id = eos_id
hf_model.config.pad_token_id = bos_id
hf_model.save_pretrained(output_dir, safe_serialization=False)
# Best effort: drop tokenizer files alongside weights
copy_tokenizer_files(output_dir)
# Write tokenizer_config with chat template and special tokens
tok_cfg = {
"chat_template": CHAT_TEMPLATE,
"bos_token": "<|bos|>",
"eos_token": "<|assistant_end|>",
"pad_token": "<|bos|>",
}
with open(os.path.join(output_dir, "tokenizer_config.json"), "w") as f:
json.dump(tok_cfg, f, indent=2)
print(f"[to_hf] Exported {source} checkpoint to {output_dir}")
write_hf_code(output_dir)

View File

@ -151,11 +151,12 @@ class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.gelu(x)
x = self.c_proj(x)
return x
@ -273,17 +274,18 @@ class Router(nn.Module):
class MLPExperts(nn.Module):
"""Multiple MLP experts - adapted for nanochat (relu^2, no bias)"""
"""Multiple MLP experts - adapted for nanochat (GELU, no bias)"""
def __init__(self, config):
super().__init__()
# no bias, using relu^2 activation like nanochat
# no bias, using GELU activation to match standard GPT
self.c_fc = nn.Parameter(torch.empty(config.n_exp, config.n_embd, 4 * config.n_embd))
self.c_proj = nn.Parameter(torch.empty(config.n_exp, 4 * config.n_embd, config.n_embd))
self.gelu = nn.GELU()
def forward(self, x):
# x: [n_exp, exp_capacity, n_embd]
x = torch.bmm(x, self.c_fc) # [n_exp, exp_capacity, 4*n_embd]
x = F.relu(x).square() # relu^2 activation (nanochat style)
x = self.gelu(x)
x = torch.bmm(x, self.c_proj) # [n_exp, exp_capacity, n_embd]
return x

382
nanochat_moe/to_hf.py Normal file
View File

@ -0,0 +1,382 @@
"""
Convert a nanochat-MoE checkpoint into a HuggingFace-style folder.
Usage (example):
python -m nanochat_moe.to_hf --source base --output hf-export/moe
Notes
- Assumes checkpoints live under ~/.cache/nanochat/<source>_checkpoints/ (same as training scripts).
- The exported model can be loaded with transformers via:
AutoModelForCausalLM.from_pretrained(<export_dir>, trust_remote_code=True)
- KV cache is not implemented in the HF wrapper; generation works but is not incremental.
"""
import argparse
import os
import json
import shutil
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
import sys
try:
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
except ImportError as exc:
raise SystemExit(
"transformers is required for HF export. Run `uv sync` (with the hf extra) first."
) from exc
from nanochat_moe.checkpoint_manager import find_last_step, find_largest_model
# standard.py expects `manager` to be importable; alias the package module to satisfy that import.
import nanochat_moe.manager as _moe_manager
sys.modules.setdefault("manager", _moe_manager)
from nanochat_moe.standard import GPT, GPTConfig
from nanochat_moe.common import get_base_dir
from nanochat_moe.tokenizer import get_tokenizer, RustBPETokenizer
CHAT_TEMPLATE = (
"<|bos|>{% for message in messages %}{% if message['role'] == 'user' %}"
"<|user_start|>{{ message['content'] }}<|user_end|>"
"{% elif message['role'] == 'assistant' %}"
"<|assistant_start|>{{ message['content'] }}<|assistant_end|>"
"{% endif %}{% endfor %}<|assistant_start|>"
)
CHAT_SPECIAL_TOKENS = {
"<|user_start|>",
"<|user_end|>",
"<|assistant_start|>",
"<|assistant_end|>",
}
def load_export_tokenizer(tokenizer_mode: str):
if tokenizer_mode == "gpt2":
return RustBPETokenizer.from_pretrained("gpt2")
if tokenizer_mode == "cache":
return get_tokenizer()
raise ValueError(f"Unknown tokenizer mode: {tokenizer_mode}")
def resolve_special_tokens(tokenizer):
specials = set(tokenizer.get_special_tokens())
if "<|bos|>" in specials:
bos = "<|bos|>"
elif "<|endoftext|>" in specials:
bos = "<|endoftext|>"
else:
raise ValueError(f"Tokenizer specials missing BOS token: {sorted(specials)}")
if "<|assistant_end|>" in specials:
eos = "<|assistant_end|>"
else:
eos = bos
pad = bos
chat_template = (
CHAT_TEMPLATE
if "<|bos|>" in specials and CHAT_SPECIAL_TOKENS.issubset(specials)
else None
)
return bos, eos, pad, chat_template
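# Worked example (illustrative comment, added for this writeup): with --tokenizer gpt2 the tiktoken
# encoding only defines "<|endoftext|>", so this returns
# ("<|endoftext|>", "<|endoftext|>", "<|endoftext|>", None) and no chat template is exported.
# With --tokenizer cache (the rustbpe tokenizer), the nanochat specials are present, so it returns
# ("<|bos|>", "<|assistant_end|>", "<|bos|>", CHAT_TEMPLATE).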
class NanoChatMoEHFConfig(PretrainedConfig):
model_type = "nanochat-moe"
def __init__(
self,
block_size: int = 1024,
vocab_size: int = 50304,
n_layer: int = 6,
n_head: int = 6,
n_embd: int = 384,
dropout: float = 0.0,
bias: bool = False,
n_exp: int = 8,
top_k: int = 2,
use_aux_loss: bool = True,
use_router_z_loss: bool = True,
use_noisy_top_k: bool = False,
aux_loss_weight: float = 0.01,
router_z_loss_weight: float = 0.001,
train_capacity: float = 1.25,
eval_capacity: float = 2.0,
min_capacity: int = 4,
stride: int = 2,
use_switch_tfm_init: bool = True,
switch_tfm_init_scale: float = 1.0,
router_use_full_prec: bool = True,
**kwargs,
):
# Don't tie embeddings; nanochat uses untied wte/lm_head
kwargs.setdefault("tie_word_embeddings", False)
super().__init__(**kwargs)
self.block_size = block_size
self.vocab_size = vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.dropout = dropout
self.bias = bias
self.n_exp = n_exp
self.top_k = top_k
self.use_aux_loss = use_aux_loss
self.use_router_z_loss = use_router_z_loss
self.use_noisy_top_k = use_noisy_top_k
self.aux_loss_weight = aux_loss_weight
self.router_z_loss_weight = router_z_loss_weight
self.train_capacity = train_capacity
self.eval_capacity = eval_capacity
self.min_capacity = min_capacity
self.stride = stride
self.use_switch_tfm_init = use_switch_tfm_init
self.switch_tfm_init_scale = switch_tfm_init_scale
self.router_use_full_prec = router_use_full_prec
# HF compatibility aliases
self.n_positions = block_size
self.n_ctx = block_size
self.n_inner = 4 * n_embd
self.max_position_embeddings = block_size
class NanoChatMoEHFForCausalLM(PreTrainedModel, GenerationMixin):
config_class = NanoChatMoEHFConfig
def __init__(self, config: NanoChatMoEHFConfig):
super().__init__(config)
gpt_cfg = GPTConfig(
block_size=config.block_size,
vocab_size=config.vocab_size,
n_layer=config.n_layer,
n_head=config.n_head,
n_embd=config.n_embd,
dropout=config.dropout,
bias=config.bias,
n_exp=config.n_exp,
top_k=config.top_k,
use_aux_loss=config.use_aux_loss,
use_router_z_loss=config.use_router_z_loss,
use_noisy_top_k=config.use_noisy_top_k,
aux_loss_weight=config.aux_loss_weight,
router_z_loss_weight=config.router_z_loss_weight,
train_capacity=config.train_capacity,
eval_capacity=config.eval_capacity,
min_capacity=config.min_capacity,
stride=config.stride,
use_switch_tfm_init=config.use_switch_tfm_init,
switch_tfm_init_scale=config.switch_tfm_init_scale,
router_use_full_prec=config.router_use_full_prec,
)
self.model = GPT(gpt_cfg)
def get_input_embeddings(self):
return self.model.transformer.wte
def set_input_embeddings(self, value):
self.model.transformer.wte = value
def get_output_embeddings(self):
return self.model.lm_head
def tie_weights(self):
# nanochat uses untied embeddings; override to no-op
return
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # unused
labels: Optional[torch.LongTensor] = None,
past_key_values=None, # not implemented
**_: dict,
) -> CausalLMOutputWithPast:
if input_ids is None:
raise ValueError("input_ids must be provided")
logits, _ = self.model(input_ids)  # GPT.forward returns (logits, loss); loss is recomputed below if labels are given
loss = None
if labels is not None:
# No label shifting here: labels are expected to be pre-aligned with logits (ignore_index=-1)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
def export_tokenizer_files(output_dir: str, tokenizer, tokenizer_mode: str):
if tokenizer_mode == "cache":
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
if os.path.isdir(tokenizer_dir):
for name in os.listdir(tokenizer_dir):
src = os.path.join(tokenizer_dir, name)
dst = os.path.join(output_dir, name)
if os.path.isdir(src):
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy2(src, dst)
print(f"[to_hf] Copied tokenizer files from {tokenizer_dir} to {output_dir}")
return
print(f"[to_hf] tokenizer directory not found at {tokenizer_dir}, exporting tokenizer.pkl only")
tokenizer.save(output_dir)
def load_moe_checkpoint(
source: str,
model_tag: Optional[str],
step: Optional[int],
device: torch.device,
tokenizer,
):
"""
Load a nanochat-MoE checkpoint (model_<step>.pt + meta_<step>.json) using the standard GPT architecture.
"""
base_dir = get_base_dir()
checkpoint_root = os.path.join(base_dir, f"{source}_checkpoints")
if model_tag is None:
model_tag = find_largest_model(checkpoint_root)
checkpoint_dir = os.path.join(checkpoint_root, model_tag)
if step is None:
step = find_last_step(checkpoint_dir)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
model_data = torch.load(model_path, map_location=device)
if device.type in {"cpu", "mps"}:
model_data = {k: (v.float() if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16 else v) for k, v in model_data.items()}
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
with open(meta_path, "r", encoding="utf-8") as f:
meta = json.load(f)
cfg_kwargs = meta["model_config"]
# If tokenizer size differs, pad embeddings and update config
tok_vocab = tokenizer.get_vocab_size()
cfg_vocab = cfg_kwargs.get("vocab_size")
if tok_vocab > cfg_vocab:
cfg_kwargs["vocab_size"] = tok_vocab
# pad wte and lm_head to match tokenizer size
for key in ["transformer.wte.weight", "lm_head.weight"]:
if key in model_data:
tensor = model_data[key]
if tensor.size(0) < tok_vocab:
pad_rows = tok_vocab - tensor.size(0)
pad = torch.zeros((pad_rows, tensor.size(1)), device=tensor.device, dtype=tensor.dtype)
model_data[key] = torch.cat([tensor, pad], dim=0)
meta["model_config"] = cfg_kwargs
elif tok_vocab < cfg_vocab:
print(f"[to_hf] tokenizer vocab ({tok_vocab}) < model vocab ({cfg_vocab}); keeping model vocab")
gpt_cfg = GPTConfig(**cfg_kwargs)
model = GPT(gpt_cfg)
model.load_state_dict(model_data, strict=True)
model.eval()
return model, meta
def write_hf_code(output_dir: str):
"""
Copy the HF-compatible python modules into the export dir.
Prefer the checked-in template under hf-export/moe; fall back to writing nothing
if it is missing (the export still works when trust_remote_code imports from the local FS).
"""
template_dir = Path(__file__).resolve().parent.parent / "hf-export" / "moe"
targets = [
"configuration_nanochat_moe.py",
"modeling_nanochat_moe.py",
"tokenization_nanochat.py",
"gpt.py",
]
if template_dir.is_dir():
for name in targets:
src = template_dir / name
if src.exists():
shutil.copy2(src, Path(output_dir) / name)
else:
print(f"[to_hf] Template dir not found at {template_dir}, skipping python module copy")
def export_to_hf(
source: str,
output_dir: str,
model_tag: Optional[str],
step: Optional[int],
tokenizer_mode: str,
):
device = torch.device("cpu")
tokenizer = load_export_tokenizer(tokenizer_mode)
model, meta = load_moe_checkpoint(source, model_tag, step, device, tokenizer)
cfg_kwargs = meta["model_config"]
hf_config = NanoChatMoEHFConfig(**cfg_kwargs)
hf_model = NanoChatMoEHFForCausalLM(hf_config)
hf_model.model.load_state_dict(model.state_dict(), strict=True)
os.makedirs(output_dir, exist_ok=True)
# Tell transformers how to load custom code when trust_remote_code=True
hf_model.config.auto_map = {
"AutoConfig": "configuration_nanochat_moe.NanoChatMoEHFConfig",
"AutoModelForCausalLM": "modeling_nanochat_moe.NanoChatMoEHFForCausalLM",
"AutoTokenizer": ["tokenization_nanochat.NanoChatTokenizer", None],
}
hf_model.config.tokenizer_class = "NanoChatTokenizer"
hf_model.config.architectures = ["NanoChatMoEHFForCausalLM"]
# Bind special token ids/strings from tokenizer
bos_token, eos_token, pad_token, chat_template = resolve_special_tokens(tokenizer)
bos_id = tokenizer.encode_special(bos_token)
eos_id = tokenizer.encode_special(eos_token)
pad_id = tokenizer.encode_special(pad_token)
hf_model.config.bos_token_id = bos_id
hf_model.config.eos_token_id = eos_id
hf_model.config.pad_token_id = pad_id
# generation config safety: clear sampling-only fields (do_sample defaults to False)
if hasattr(hf_model, "generation_config"):
hf_model.generation_config.top_k = None
hf_model.save_pretrained(output_dir, safe_serialization=False)
# Best effort: drop tokenizer files alongside weights
export_tokenizer_files(output_dir, tokenizer, tokenizer_mode)
# Write tokenizer_config with chat template and special tokens
tok_cfg = {
"bos_token": bos_token,
"eos_token": eos_token,
"pad_token": pad_token,
}
if chat_template is not None:
tok_cfg["chat_template"] = chat_template
with open(os.path.join(output_dir, "tokenizer_config.json"), "w") as f:
json.dump(tok_cfg, f, indent=2)
print(f"[to_hf] Exported {source} checkpoint to {output_dir}")
write_hf_code(output_dir)
def main():
parser = argparse.ArgumentParser(description="Export nanochat-MoE checkpoint to HuggingFace format")
parser.add_argument("--source", choices=["base", "mid", "sft", "rl"], default="base", help="Which checkpoint family to export")
parser.add_argument("--model-tag", type=str, default=None, help="Model tag (e.g., d20). Defaults to largest available.")
parser.add_argument("--step", type=int, default=None, help="Checkpoint step. Defaults to latest.")
parser.add_argument("--output", type=str, default="hf-export/moe", help="Output directory for HF files")
parser.add_argument(
"--tokenizer",
choices=["gpt2", "cache"],
default="gpt2",
help="Tokenizer source for export: gpt2 uses tiktoken; cache uses ~/.cache/nanochat/tokenizer",
)
args = parser.parse_args()
export_to_hf(args.source, args.output, args.model_tag, args.step, args.tokenizer)
if __name__ == "__main__":
main()

235
scripts_moe/quick_infer.py Normal file
View File

@ -0,0 +1,235 @@
"""
Quick local inference for nanochat-MoE checkpoints (pretrain-style, plain text).
Example:
uv run python scripts_moe/quick_infer.py --model-tag d20 --prompt "what's 1+1 equal to?"
uv run python scripts_moe/quick_infer.py --hf-path hf-export/moe_gpt2 --prompt "what's 1+1 equal to?"
"""
import argparse
import json
import os
import sys
import torch
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if REPO_ROOT not in sys.path:
sys.path.insert(0, REPO_ROOT)
import nanochat_moe.manager as _moe_manager
sys.modules.setdefault("manager", _moe_manager)
from nanochat_moe.checkpoint_manager import find_last_step, find_largest_model
from nanochat_moe.common import get_base_dir
from nanochat_moe.standard import GPT, GPTConfig
from nanochat_moe.tokenizer import RustBPETokenizer, get_tokenizer
def pad_vocab_weights(state_dict, target_vocab):
for key in ("transformer.wte.weight", "lm_head.weight"):
if key not in state_dict:
continue
tensor = state_dict[key]
if tensor.size(0) >= target_vocab:
continue
pad_rows = target_vocab - tensor.size(0)
pad = tensor.new_zeros((pad_rows, tensor.size(1)))
state_dict[key] = torch.cat([tensor, pad], dim=0)
def load_tokenizer(mode):
if mode == "gpt2":
return RustBPETokenizer.from_pretrained("gpt2")
if mode == "cache":
return get_tokenizer()
raise ValueError(f"Unknown tokenizer mode: {mode}")
def load_hf_model(hf_path, device):
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)
model.to(device)
model.eval()
vocab_limit = None
if tokenizer.vocab_size < model.config.vocab_size:
vocab_limit = tokenizer.vocab_size
print(
f"Warning: tokenizer vocab {tokenizer.vocab_size} < model vocab {model.config.vocab_size}; "
"clamping logits to tokenizer vocab."
)
return model, tokenizer, vocab_limit
def load_standard_model(source, device, model_tag, step, tokenizer):
base_dir = get_base_dir()
checkpoint_root = os.path.join(base_dir, f"{source}_checkpoints")
if model_tag is None:
model_tag = find_largest_model(checkpoint_root)
checkpoint_dir = os.path.join(checkpoint_root, model_tag)
if step is None:
step = find_last_step(checkpoint_dir)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
model_data = torch.load(model_path, map_location=device)
if device.type in {"cpu", "mps"}:
model_data = {
k: (v.float() if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16 else v)
for k, v in model_data.items()
}
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
with open(meta_path, "r", encoding="utf-8") as f:
meta = json.load(f)
cfg_kwargs = meta["model_config"]
tok_vocab = tokenizer.get_vocab_size()
cfg_vocab = cfg_kwargs.get("vocab_size")
vocab_limit = None
if tok_vocab > cfg_vocab:
print(
f"Warning: tokenizer vocab {tok_vocab} > model vocab {cfg_vocab}; "
"padding embeddings to match."
)
cfg_kwargs = dict(cfg_kwargs)
cfg_kwargs["vocab_size"] = tok_vocab
pad_vocab_weights(model_data, tok_vocab)
elif tok_vocab < cfg_vocab:
vocab_limit = tok_vocab
print(
f"Warning: tokenizer vocab {tok_vocab} < model vocab {cfg_vocab}; "
"clamping logits to tokenizer vocab."
)
model = GPT(GPTConfig(**cfg_kwargs))
model.load_state_dict(model_data, strict=True)
model.to(device)
model.eval()
return model, tokenizer, meta, model_tag, step, vocab_limit
def build_prompt_tokens(tokenizer, question: str):
prompt = f"Question: {question}\nAnswer:"
bos = tokenizer.get_bos_token_id()
return tokenizer.encode(prompt, prepend=bos), prompt
def build_prompt_tokens_hf(tokenizer, question: str):
prompt = f"Question: {question}\nAnswer:"
ids = tokenizer.encode(prompt, add_special_tokens=False)
bos_id = getattr(tokenizer, "bos_token_id", None)
if bos_id is not None:
ids = [bos_id] + ids
return ids, prompt
def sample_next_token(logits, temperature, top_k):
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits = torch.where(
logits < v[:, [-1]],
torch.full_like(logits, -float("inf")),
logits,
)
if temperature <= 0:
return torch.argmax(logits, dim=-1, keepdim=True)
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1)
@torch.inference_mode()
def generate_tokens(model, input_ids, max_new_tokens, temperature, top_k, vocab_limit):
for _ in range(max_new_tokens):
idx_cond = (
input_ids
if input_ids.size(1) <= model.config.block_size
else input_ids[:, -model.config.block_size :]
)
logits, _ = model(idx_cond)
logits = logits[:, -1, :]
if vocab_limit is not None and vocab_limit < logits.size(-1):
logits = logits[:, :vocab_limit]
next_token = sample_next_token(logits, temperature, top_k)
input_ids = torch.cat((input_ids, next_token), dim=1)
return input_ids
@torch.inference_mode()
def generate_tokens_hf(model, input_ids, max_new_tokens, temperature, top_k, vocab_limit):
for _ in range(max_new_tokens):
logits = model(input_ids).logits
logits = logits[:, -1, :]
if vocab_limit is not None and vocab_limit < logits.size(-1):
logits = logits[:, :vocab_limit]
next_token = sample_next_token(logits, temperature, top_k)
input_ids = torch.cat((input_ids, next_token), dim=1)
return input_ids
@torch.inference_mode()
def main():
parser = argparse.ArgumentParser(description="Run a quick nanochat-MoE inference")
parser.add_argument("--source", type=str, default="base", choices=["base", "mid", "sft", "rl"])
parser.add_argument("--model-tag", type=str, default="d20")
parser.add_argument("--step", type=int, default=None)
parser.add_argument("--hf-path", type=str, default=None)
parser.add_argument("--prompt", type=str, default="what's 1+1 equal to?")
parser.add_argument("--max-tokens", type=int, default=64)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top-k", type=int, default=None)
parser.add_argument("--tokenizer", type=str, default="gpt2", choices=["gpt2", "cache"])
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.hf_path:
model, tokenizer, vocab_limit = load_hf_model(args.hf_path, device)
prompt_tokens, prompt_text = build_prompt_tokens_hf(tokenizer, args.prompt)
input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
output_ids = generate_tokens_hf(
model,
input_ids,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
vocab_limit=vocab_limit,
)
generated_tokens = output_ids[0].tolist()[len(prompt_tokens) :]
output_text = tokenizer.decode(generated_tokens)
print(f"Loaded HF model: {args.hf_path}")
else:
tokenizer = load_tokenizer(args.tokenizer)
model, tokenizer, _, resolved_tag, resolved_step, vocab_limit = load_standard_model(
args.source, device, args.model_tag, args.step, tokenizer
)
prompt_tokens, prompt_text = build_prompt_tokens(tokenizer, args.prompt)
if len(prompt_tokens) > model.config.block_size:
print(
f"Warning: prompt length {len(prompt_tokens)} exceeds block_size "
f"{model.config.block_size}; context will be truncated."
)
input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
output_ids = generate_tokens(
model,
input_ids,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
vocab_limit=vocab_limit,
)
generated_tokens = output_ids[0].tolist()[len(prompt_tokens) :]
output_text = tokenizer.decode(generated_tokens)
print(f"Loaded: {args.source}/{resolved_tag} step {resolved_step}")
# print("Prompt:", prompt_text)
# print("Output:", output_text)
print("===============Prompt===============")
print(prompt_text)
print("===============Output===============")
print(output_text)
if __name__ == "__main__":
main()

@ -1 +1 @@
Subproject commit 32c4b74696a41586712a8a8b7906591833ba1a78
Subproject commit 367008686f55588c325b9dc0ae64b449fa943158