mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-08 16:59:59 +00:00
Landing page with desi street-food aesthetic: lemon-mirchi toran with pendulum animation, dual-script hero (Devanagari + English cursive), samosa illustration with floating animation, brass chai kettle with steam wisps, ambient chilli/lemon doodles. Chat page carries the warm samosa-chaat palette with cream/gold user bubbles, steam-wisp typing indicator, and WebGPU integration hooks (window.samosaChaat API for local inference mode switching). Added scripts/export_onnx.py for ONNX model export with KV cache support, targeting WebGPU browser inference. Credit to Andrej Karpathy's nanochat in footer. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
470 lines
18 KiB
Python
470 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Export the nanochat/samosaChaat model to ONNX for WebGPU inference in the browser.
|
|
|
|
Creates two ONNX models:
|
|
1. prefill.onnx - Processes the full prompt, returns logits + KV cache
|
|
2. decode.onnx - Single-token generation with KV cache, returns logits + updated cache
|
|
|
|
Usage:
|
|
python -m scripts.export_onnx
|
|
python -m scripts.export_onnx --quantize # also produce INT4 quantized versions
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import json
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from nanochat.common import COMPUTE_DTYPE
|
|
from nanochat.checkpoint_manager import load_model
|
|
from nanochat.gpt import GPTConfig, apply_rotary_emb, has_ve
|
|
|
|
|
|
def norm(x):
    """Scale-free RMS normalization built from primitive ops so ONNX tracing works.

    Equivalent to RMSNorm without a learned gain: divides by the root of the
    mean squared value over the last dimension (eps=1e-6 for stability).
    """
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_sq + 1e-6)
|
|
|
|
# CLI arguments are parsed at import time so that `args` is available as a
# module-level global; the `__main__` block at the bottom of the file reads it.
parser = argparse.ArgumentParser(description='Export samosaChaat to ONNX')
parser.add_argument('--output-dir', type=str, default='onnx_export', help='Output directory')
parser.add_argument('--quantize', action='store_true', help='Also produce INT4 quantized models')
# -i/--source selects which checkpoint family to load (passed to load_model).
parser.add_argument('-i', '--source', type=str, default='sft', help='Model source: sft|rl')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag')
parser.add_argument('-s', '--step', type=int, default=None, help='Step')
args = parser.parse_args()
|
|
|
|
|
|
class OnnxAttention(nn.Module):
    """Attention layer rewritten for ONNX export (no flash attention, explicit KV cache I/O)."""

    def __init__(self, attn, layer_idx, n_layer):
        """Wrap a trained attention module, sharing (not copying) its weights.

        Args:
            attn: source attention module exposing n_head / n_kv_head / head_dim,
                the c_q/c_k/c_v/c_proj projections and the ve_gate linear.
            layer_idx: index of this layer within the transformer stack.
            n_layer: total number of layers (accepted for API parity; unused here).
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = attn.n_head
        self.n_kv_head = attn.n_kv_head
        self.head_dim = attn.head_dim
        self.n_groups = self.n_head // self.n_kv_head  # GQA: query heads per KV head
        # Reference the trained projection modules directly so the exported
        # graph serializes the loaded checkpoint's weights.
        self.c_q = attn.c_q
        self.c_k = attn.c_k
        self.c_v = attn.c_v
        self.c_proj = attn.c_proj
        self.ve_gate = attn.ve_gate

    def forward(self, x, ve, cos, sin, past_key, past_value):
        """
        Args:
            x: (B, T, C)
            ve: (B, T, kv_dim) value embeddings, or None for layers without them
            cos, sin: (1, T, 1, head_dim//2) - already offset for position
            past_key: (B, past_len, n_kv_head, head_dim)
            past_value: (B, past_len, n_kv_head, head_dim)
        Returns:
            output: (B, T, C)
            present_key: (B, past_len+T, n_kv_head, head_dim)
            present_value: (B, past_len+T, n_kv_head, head_dim)
        """
        B, T, C = x.size()

        # Project Q, K, V
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value embeddings (ResFormer): gated additive injection into V.
        # NOTE(review): the gate reads only the first 12 channels of x and is
        # scaled by 3 - assumed to mirror the training-time model; confirm
        # against nanochat.gpt before changing.
        if ve is not None and self.ve_gate is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :12]))
            v = v + gate.unsqueeze(-1) * ve

        # Rotary embeddings
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # QK norm + scaling (factor 1.2 matches the training-time model)
        q = norm(q) * 1.2
        k = norm(k) * 1.2

        # Append this step's K/V to the cache along the sequence dimension.
        # Fixed: a stray line previously concatenated past_key into the value
        # cache (flagged # BUG) before being overwritten - removed.
        present_key = torch.cat([past_key, k], dim=1)
        present_value = torch.cat([past_value, v], dim=1)

        # Standard attention (no flash attention for ONNX compatibility).
        # Transpose to (B, H, T, D) for matmul.
        q_t = q.permute(0, 2, 1, 3)  # (B, n_head, T, head_dim)

        # Expand KV heads for GQA: (B, n_kv_head, S, D) -> (B, n_head, S, D)
        k_t = present_key.permute(0, 2, 1, 3)  # (B, n_kv_head, S, head_dim)
        v_t = present_value.permute(0, 2, 1, 3)
        if self.n_groups > 1:
            k_t = k_t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1).reshape(B, self.n_head, -1, self.head_dim)
            v_t = v_t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1).reshape(B, self.n_head, -1, self.head_dim)

        S = present_key.size(1)  # total sequence length (past + current)

        # Scaled dot-product attention with causal mask
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale  # (B, H, T, S)

        # Causal mask: each query position can only attend to positions <= its own.
        # Query positions are [S-T, ..., S-1]; key positions are [0, ..., S-1].
        query_pos = torch.arange(S - T, S, device=x.device).unsqueeze(1)  # (T, 1)
        key_pos = torch.arange(S, device=x.device).unsqueeze(0)  # (1, S)
        causal_mask = key_pos <= query_pos  # (T, S)
        attn_weights = attn_weights.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        y = torch.matmul(attn_weights, v_t)  # (B, H, T, D)

        # Transpose back and project
        y = y.permute(0, 2, 1, 3).contiguous().view(B, T, -1)
        y = self.c_proj(y)

        return y, present_key, present_value
|
|
|
|
|
|
class OnnxMLP(nn.Module):
    """MLP rewritten for ONNX (avoids importing norm from gpt.py)."""

    def __init__(self, mlp):
        """Borrow the trained up/down projection layers from an existing MLP module."""
        super().__init__()
        self.c_fc = mlp.c_fc
        self.c_proj = mlp.c_proj

    def forward(self, x):
        # Expand, apply the squared-ReLU activation, then project back down.
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)
|
|
|
|
|
|
class OnnxBlock(nn.Module):
    """Transformer block rewritten for ONNX (explicit KV cache)."""

    def __init__(self, block, layer_idx, n_layer):
        """Rebuild a trained block from its attention and MLP sub-modules."""
        super().__init__()
        self.attn = OnnxAttention(block.attn, layer_idx, n_layer)
        self.mlp = OnnxMLP(block.mlp)

    def forward(self, x, ve, cos, sin, past_key, past_value):
        # Pre-norm attention with a residual connection; the updated KV cache
        # tensors pass straight through to the caller.
        attn_out, new_key, new_value = self.attn(norm(x), ve, cos, sin, past_key, past_value)
        h = x + attn_out
        # Pre-norm MLP with a residual connection.
        h = h + self.mlp(norm(h))
        return h, new_key, new_value
|
|
|
|
|
|
class OnnxGPT(nn.Module):
    """
    ONNX-exportable wrapper around the GPT model.

    Inputs:
        input_ids: (B, T) int64
        cos_slice, sin_slice: (1, T, 1, head_dim//2) - rotary tables pre-sliced
            for the current positions (the caller applies the position offset)
        prev_embedding: (B, 1, n_embd) - previous token's embedding for smear
        past_keys: (n_layers, B, past_len, n_kv_head, head_dim)
        past_values: (n_layers, B, past_len, n_kv_head, head_dim)

    Outputs:
        logits: (B, T, vocab_size)
        new_prev_embedding: (B, 1, n_embd)
        present_keys: (n_layers, B, past_len+T, n_kv_head, head_dim)
        present_values: (n_layers, B, past_len+T, n_kv_head, head_dim)
    """

    def __init__(self, model):
        """Alias everything needed for inference from a trained GPT `model`."""
        super().__init__()
        config = model.config
        self.config = config
        self.n_layer = config.n_layer
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.n_embd // config.n_head
        self.vocab_size = config.vocab_size

        # Copy model components (shared references to the trained modules)
        self.wte = model.transformer.wte
        self.lm_head = model.lm_head
        self.resid_lambdas = model.resid_lambdas
        self.x0_lambdas = model.x0_lambdas
        self.smear_gate = model.smear_gate
        self.smear_lambda = model.smear_lambda
        self.backout_lambda = model.backout_lambda
        self.value_embeds = model.value_embeds

        # Rotary embeddings (precomputed); non-persistent buffers so they are
        # not saved as extra state, but still move with the module.
        self.register_buffer("cos", model.cos, persistent=False)
        self.register_buffer("sin", model.sin, persistent=False)

        # Window sizes (baked in as constants)
        self.window_sizes = model.window_sizes

        # Rebuild blocks with ONNX-compatible attention
        self.blocks = nn.ModuleList([
            OnnxBlock(model.transformer.h[i], i, config.n_layer)
            for i in range(config.n_layer)
        ])

    def forward(self, input_ids, cos_slice, sin_slice, prev_embedding, past_keys, past_values):
        """
        Args:
            input_ids: (B, T)
            cos_slice: (1, T, 1, head_dim//2) - pre-sliced rotary cos for current positions
            sin_slice: (1, T, 1, head_dim//2) - pre-sliced rotary sin for current positions
            prev_embedding: (B, 1, n_embd) - previous token's embedding for smear
            past_keys: (n_layers, B, past_len, n_kv_head, head_dim)
            past_values: (n_layers, B, past_len, n_kv_head, head_dim)
        """
        B, T = input_ids.size()
        cos = cos_slice
        sin = sin_slice

        # Token embedding + norm
        x = self.wte(input_ids)
        x = x.to(self.cos.dtype)
        x = norm(x)

        # Smear: mix a gated fraction of the previous token's normed embedding
        # into each position. The pre-smear last-token embedding is saved so the
        # decode step can continue the smear across calls.
        # NOTE(review): the gate reads only the first 24 channels of x - assumed
        # to mirror the training-time smear_gate input width; confirm in nanochat.gpt.
        new_prev_embedding = x[:, -1:, :].clone()
        if T > 1:
            # Prefill path: shift-and-mix within the sequence itself.
            gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
            x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
        else:
            # Decode path (T == 1): mix in the embedding carried from the previous call.
            gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
            x = x + gate * prev_embedding

        # Transformer blocks with explicit KV cache
        x0 = x  # saved for per-layer x0_lambdas re-injection below
        backout_layer = self.n_layer // 2
        x_backout = torch.zeros_like(x)
        present_keys_list = []
        present_values_list = []

        for i, block in enumerate(self.blocks):
            # Per-layer residual scaling plus re-injection of the embedding stream.
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0

            # Value embeddings (alternating layers)
            ve = self.value_embeds[str(i)](input_ids).to(x.dtype) if str(i) in self.value_embeds else None

            # Get past KV for this layer
            layer_past_k = past_keys[i]
            layer_past_v = past_values[i]

            x, present_k, present_v = block(x, ve, cos, sin, layer_past_k, layer_past_v)
            present_keys_list.append(present_k.unsqueeze(0))
            present_values_list.append(present_v.unsqueeze(0))

            if i == backout_layer:
                x_backout = x  # snapshot the mid-layer residual stream

        # Backout: subtract mid-layer residual (scaled by backout_lambda)
        x = x - self.backout_lambda.to(x.dtype) * x_backout
        x = norm(x)

        # Logits with softcap (tanh keeps every logit within +/-15)
        logits = self.lm_head(x)
        logits = logits[..., :self.vocab_size]  # drop any padded vocab columns
        logits = logits.float()
        logits = 15.0 * torch.tanh(logits / 15.0)

        # Stack per-layer KV caches into (n_layers, B, S, n_kv_head, head_dim)
        present_keys = torch.cat(present_keys_list, dim=0)
        present_values = torch.cat(present_values_list, dim=0)

        return logits, new_prev_embedding, present_keys, present_values
|
|
|
|
|
|
def export_model(model, tokenizer, output_dir):
    """Export model to ONNX format.

    Produces prefill.onnx (full-prompt pass starting from an empty KV cache)
    and decode.onnx (single-token pass with a populated cache), writes
    config.json for the JS runtime, then best-effort validates both graphs
    with ONNX Runtime if it is installed.

    Args:
        model: trained GPT model; wrapped in OnnxGPT for tracing.
        tokenizer: unused here; accepted for API symmetry with callers.
        output_dir: directory to create and write artifacts into.

    Returns:
        (prefill_path, decode_path): paths of the exported ONNX files.
    """
    os.makedirs(output_dir, exist_ok=True)
    config = model.config

    print("Building ONNX-compatible model wrapper...")
    onnx_model = OnnxGPT(model)
    onnx_model.eval()

    n_kv_head = config.n_kv_head
    head_dim = config.n_embd // config.n_head
    n_layer = config.n_layer

    # Pre-compute rotary embeddings on CPU
    cos_full = onnx_model.cos  # (1, max_len, 1, head_dim//2)
    sin_full = onnx_model.sin

    # --- Export Prefill Model ---
    print("\nExporting prefill model...")
    batch_size = 1
    seq_len = 4  # dummy prompt length for tracing
    dummy_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    dummy_cos = cos_full[:, :seq_len].contiguous()
    dummy_sin = sin_full[:, :seq_len].contiguous()
    dummy_prev_emb = torch.zeros(batch_size, 1, config.n_embd, dtype=COMPUTE_DTYPE)
    # past_len == 0: prefill starts from an empty (zero-size) KV cache.
    dummy_past_keys = torch.zeros(n_layer, batch_size, 0, n_kv_head, head_dim, dtype=COMPUTE_DTYPE)
    dummy_past_values = torch.zeros(n_layer, batch_size, 0, n_kv_head, head_dim, dtype=COMPUTE_DTYPE)

    # Test forward pass first (fail fast before the slower ONNX export)
    print(" Testing forward pass...")
    with torch.no_grad():
        logits, new_prev, pk, pv = onnx_model(dummy_ids, dummy_cos, dummy_sin, dummy_prev_emb, dummy_past_keys, dummy_past_values)
    print(f" Prefill output shapes: logits={logits.shape}, present_keys={pk.shape}")

    # Export with legacy exporter (dynamo=False) for better compatibility
    prefill_path = os.path.join(output_dir, "prefill.onnx")
    print(f" Exporting to {prefill_path}...")
    torch.onnx.export(
        onnx_model,
        (dummy_ids, dummy_cos, dummy_sin, dummy_prev_emb, dummy_past_keys, dummy_past_values),
        prefill_path,
        dynamo=False,
        input_names=["input_ids", "cos", "sin", "prev_embedding", "past_keys", "past_values"],
        output_names=["logits", "new_prev_embedding", "present_keys", "present_values"],
        # seq_len / past_len are dynamic so one graph serves any prompt length.
        dynamic_axes={
            "input_ids": {1: "seq_len"},
            "cos": {1: "seq_len"},
            "sin": {1: "seq_len"},
            "past_keys": {2: "past_len"},
            "past_values": {2: "past_len"},
            "logits": {1: "seq_len"},
            "present_keys": {2: "total_len"},
            "present_values": {2: "total_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f" Prefill model exported: {os.path.getsize(prefill_path) / 1e6:.1f} MB")

    # --- Export Decode Model (single token) ---
    print("\nExporting decode model...")
    dummy_decode_ids = torch.randint(0, config.vocab_size, (batch_size, 1))
    dummy_decode_cos = cos_full[:, seq_len:seq_len+1].contiguous()
    dummy_decode_sin = sin_full[:, seq_len:seq_len+1].contiguous()
    # Seed decode tracing with the prefill outputs so past_len > 0 during trace.
    dummy_decode_prev = new_prev.detach()
    dummy_decode_past_k = pk.detach()
    dummy_decode_past_v = pv.detach()

    decode_path = os.path.join(output_dir, "decode.onnx")
    print(f" Exporting to {decode_path}...")
    torch.onnx.export(
        onnx_model,
        (dummy_decode_ids, dummy_decode_cos, dummy_decode_sin, dummy_decode_prev, dummy_decode_past_k, dummy_decode_past_v),
        decode_path,
        dynamo=False,
        input_names=["input_ids", "cos", "sin", "prev_embedding", "past_keys", "past_values"],
        output_names=["logits", "new_prev_embedding", "present_keys", "present_values"],
        # T is fixed at 1 for decode; only the cache lengths stay dynamic.
        dynamic_axes={
            "past_keys": {2: "past_len"},
            "past_values": {2: "past_len"},
            "present_keys": {2: "total_len"},
            "present_values": {2: "total_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f" Decode model exported: {os.path.getsize(decode_path) / 1e6:.1f} MB")

    # --- Save config for JS runtime ---
    config_dict = {
        "n_layer": config.n_layer,
        "n_head": config.n_head,
        "n_kv_head": config.n_kv_head,
        "n_embd": config.n_embd,
        "head_dim": head_dim,
        "vocab_size": config.vocab_size,
        "sequence_len": config.sequence_len,
        "window_pattern": config.window_pattern,
    }
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(config_dict, f, indent=2)
    print(f"\nConfig saved to {config_path}")

    # --- Validate with ONNX Runtime (best-effort; skipped if ort missing) ---
    print("\nValidating with ONNX Runtime...")
    try:
        import onnxruntime as ort
        # Test prefill with the same dummy inputs used for tracing
        sess = ort.InferenceSession(prefill_path, providers=["CPUExecutionProvider"])
        feeds = {
            "input_ids": dummy_ids.numpy(),
            "cos": dummy_cos.numpy().astype("float32"),
            "sin": dummy_sin.numpy().astype("float32"),
            "prev_embedding": dummy_prev_emb.numpy().astype("float32"),
            "past_keys": dummy_past_keys.numpy().astype("float32"),
            "past_values": dummy_past_values.numpy().astype("float32"),
        }
        ort_logits, ort_prev, ort_pk, ort_pv = sess.run(None, feeds)
        print(f" Prefill ONNX Runtime OK: logits={ort_logits.shape}")

        # Compare with PyTorch output
        max_diff = abs(logits.numpy() - ort_logits).max()
        print(f" Max logit difference (PyTorch vs ONNX Runtime): {max_diff:.6f}")

        # Test decode, chaining the prefill outputs as its cache inputs
        sess_decode = ort.InferenceSession(decode_path, providers=["CPUExecutionProvider"])
        feeds_decode = {
            "input_ids": dummy_decode_ids.numpy(),
            "cos": dummy_decode_cos.numpy().astype("float32"),
            "sin": dummy_decode_sin.numpy().astype("float32"),
            "prev_embedding": ort_prev.astype("float32"),
            "past_keys": ort_pk.astype("float32"),
            "past_values": ort_pv.astype("float32"),
        }
        ort_dec_logits, _, _, _ = sess_decode.run(None, feeds_decode)
        print(f" Decode ONNX Runtime OK: logits={ort_dec_logits.shape}")
        print("\nONNX export validated successfully!")

    except Exception as e:
        # Deliberate best-effort: the exported artifacts are kept even when
        # validation fails (e.g. onnxruntime not installed).
        print(f" ONNX Runtime validation failed: {e}")
        print(" The ONNX files were still exported - fix the validation issue before deployment.")

    return prefill_path, decode_path
|
|
|
|
|
|
def quantize_models(output_dir):
    """Dynamically quantize the exported ONNX models to shrink download size.

    Note: despite the historical `_q4` output suffix, this applies 8-bit
    dynamic quantization (QUInt8) - true INT4 weight quantization is not
    always available in onnxruntime, so INT8 is the portable fallback. The
    `<name>_q4.onnx` filenames are kept so the web runtime that loads them
    keeps working.

    Args:
        output_dir: directory containing prefill.onnx and decode.onnx; the
            quantized copies are written alongside them.
    """
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
    except ImportError:
        print("onnxruntime quantization not available. Install: pip install onnxruntime")
        return

    for name in ["prefill", "decode"]:
        input_path = os.path.join(output_dir, f"{name}.onnx")
        output_path = os.path.join(output_dir, f"{name}_q4.onnx")
        # Message fixed: this performs UInt8 dynamic quantization, not INT4.
        print(f"\nQuantizing {name} (dynamic UInt8)...")
        quantize_dynamic(
            input_path,
            output_path,
            weight_type=QuantType.QUInt8,  # INT4 not always available, use INT8 as fallback
        )
        original_size = os.path.getsize(input_path) / 1e6
        quant_size = os.path.getsize(output_path) / 1e6
        print(f" {original_size:.1f} MB -> {quant_size:.1f} MB ({quant_size/original_size*100:.0f}%)")
|
|
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("samosaChaat ONNX Export")
    print("=" * 60)

    # Load model on CPU: tracing/export does not need a GPU, and the ONNX
    # initializers must be host tensors anyway.
    device = torch.device("cpu")
    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    model.eval()

    print(f"\nModel: {sum(p.numel() for p in model.parameters()):,} params")
    print(f"Config: n_layer={model.config.n_layer}, n_head={model.config.n_head}, "
          f"n_kv_head={model.config.n_kv_head}, n_embd={model.config.n_embd}")

    # Export prefill.onnx / decode.onnx / config.json
    prefill_path, decode_path = export_model(model, tokenizer, args.output_dir)

    # Quantize if requested (--quantize)
    if args.quantize:
        quantize_models(args.output_dir)

    print("\n" + "=" * 60)
    print("Export complete!")
    print(f"Files in: {args.output_dir}/")
    print("=" * 60)
|