nanochat/scripts/export_onnx.py
#!/usr/bin/env python3
"""
Export the nanochat/samosaChaat model to ONNX for WebGPU inference in the browser.
Creates two ONNX models:
1. prefill.onnx - Processes the full prompt, returns logits + KV cache
2. decode.onnx - Single-token generation with KV cache, returns logits + updated cache
Usage:
python -m scripts.export_onnx
python -m scripts.export_onnx --quantize  # also produce dynamically quantized (INT8) versions
See the _sketch_generate helper below for how the two models chain at inference time.
"""
import argparse
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import COMPUTE_DTYPE
from nanochat.checkpoint_manager import load_model
from nanochat.gpt import apply_rotary_emb
def norm(x):
"""RMS norm implemented with basic ops for ONNX compatibility."""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
parser = argparse.ArgumentParser(description='Export samosaChaat to ONNX')
parser.add_argument('--output-dir', type=str, default='onnx_export', help='Output directory')
parser.add_argument('--quantize', action='store_true', help='Also produce dynamically quantized (INT8) models')
parser.add_argument('-i', '--source', type=str, default='sft', help='Model source: sft|rl')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag')
parser.add_argument('-s', '--step', type=int, default=None, help='Step')
args = parser.parse_args()
class OnnxAttention(nn.Module):
"""Attention layer rewritten for ONNX export (no flash attention, explicit KV cache I/O)."""
def __init__(self, attn, layer_idx, n_layer):
super().__init__()
self.layer_idx = layer_idx
self.n_head = attn.n_head
self.n_kv_head = attn.n_kv_head
self.head_dim = attn.head_dim
self.n_groups = self.n_head // self.n_kv_head
# Copy weights
self.c_q = attn.c_q
self.c_k = attn.c_k
self.c_v = attn.c_v
self.c_proj = attn.c_proj
self.ve_gate = attn.ve_gate
def forward(self, x, ve, cos, sin, past_key, past_value):
"""
Args:
x: (B, T, C)
ve: (B, T, kv_dim) or None
cos, sin: (1, T, 1, head_dim//2) - already offset for position
past_key: (B, past_len, n_kv_head, head_dim)
past_value: (B, past_len, n_kv_head, head_dim)
Returns:
output: (B, T, C)
present_key: (B, past_len+T, n_kv_head, head_dim)
present_value: (B, past_len+T, n_kv_head, head_dim)
"""
B, T, C = x.size()
# Project Q, K, V
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Value embeddings (ResFormer)
if ve is not None and self.ve_gate is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
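# per-KV-head gate, computed cheaply from just the first 12 channels of x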
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :12]))
v = v + gate.unsqueeze(-1) * ve
# Rotary embeddings
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
# QK norm + scaling
q = norm(q) * 1.2
k = norm(k) * 1.2
# Concatenate past KV cache
# past_key/past_value: (B, past_len, n_kv_head, head_dim)
present_key = torch.cat([past_key, k], dim=1)
present_value = torch.cat([past_value, v], dim=1)
# Standard attention (no flash attention for ONNX compatibility)
# Transpose to (B, H, T, D) for matmul
q_t = q.permute(0, 2, 1, 3) # (B, n_head, T, head_dim)
# Expand KV heads for GQA: (B, n_kv_head, S, D) -> (B, n_head, S, D)
k_t = present_key.permute(0, 2, 1, 3) # (B, n_kv_head, S, head_dim)
v_t = present_value.permute(0, 2, 1, 3)
if self.n_groups > 1:
k_t = k_t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1).reshape(B, self.n_head, -1, self.head_dim)
v_t = v_t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1).reshape(B, self.n_head, -1, self.head_dim)
S = present_key.size(1) # total sequence length (past + current)
# Scaled dot-product attention with causal mask
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale # (B, H, T, S)
# Causal mask: each query position can only attend to positions <= its own
# Query positions are [S-T, S-T+1, ..., S-1], key positions are [0, 1, ..., S-1]
query_pos = torch.arange(S - T, S, device=x.device).unsqueeze(1) # (T, 1)
key_pos = torch.arange(S, device=x.device).unsqueeze(0) # (1, S)
causal_mask = key_pos <= query_pos # (T, S)
attn_weights = attn_weights.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn_weights = F.softmax(attn_weights, dim=-1)
y = torch.matmul(attn_weights, v_t) # (B, H, T, D)
# Transpose back and project
y = y.permute(0, 2, 1, 3).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y, present_key, present_value
class OnnxMLP(nn.Module):
"""MLP rewritten for ONNX (avoids importing norm from gpt.py)."""
def __init__(self, mlp):
super().__init__()
self.c_fc = mlp.c_fc
self.c_proj = mlp.c_proj
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class OnnxBlock(nn.Module):
"""Transformer block rewritten for ONNX (explicit KV cache)."""
def __init__(self, block, layer_idx, n_layer):
super().__init__()
self.attn = OnnxAttention(block.attn, layer_idx, n_layer)
self.mlp = OnnxMLP(block.mlp)
def forward(self, x, ve, cos, sin, past_key, past_value):
attn_out, present_key, present_value = self.attn(norm(x), ve, cos, sin, past_key, past_value)
x = x + attn_out
x = x + self.mlp(norm(x))
return x, present_key, present_value
class OnnxGPT(nn.Module):
"""
ONNX-exportable wrapper around the GPT model.
Inputs:
input_ids: (B, T) int64
cos_slice, sin_slice: (1, T, 1, head_dim//2) - rotary tables pre-sliced for the current positions
prev_embedding: (B, 1, n_embd) - previous token's embedding for smear
past_keys: (n_layers, B, past_len, n_kv_head, head_dim)
past_values: (n_layers, B, past_len, n_kv_head, head_dim)
Outputs:
logits: (B, T, vocab_size)
new_prev_embedding: (B, 1, n_embd)
present_keys: (n_layers, B, past_len+T, n_kv_head, head_dim)
present_values: (n_layers, B, past_len+T, n_kv_head, head_dim)
"""
def __init__(self, model):
super().__init__()
config = model.config
self.config = config
self.n_layer = config.n_layer
self.n_kv_head = config.n_kv_head
self.head_dim = config.n_embd // config.n_head
self.vocab_size = config.vocab_size
# Copy model components
self.wte = model.transformer.wte
self.lm_head = model.lm_head
self.resid_lambdas = model.resid_lambdas
self.x0_lambdas = model.x0_lambdas
self.smear_gate = model.smear_gate
self.smear_lambda = model.smear_lambda
self.backout_lambda = model.backout_lambda
self.value_embeds = model.value_embeds
# Rotary embeddings (precomputed)
self.register_buffer("cos", model.cos, persistent=False)
self.register_buffer("sin", model.sin, persistent=False)
# Window sizes (baked in as constants)
self.window_sizes = model.window_sizes
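# (kept for reference only; OnnxAttention below applies a plain causal mask)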
# Rebuild blocks with ONNX-compatible attention
self.blocks = nn.ModuleList([
OnnxBlock(model.transformer.h[i], i, config.n_layer)
for i in range(config.n_layer)
])
def forward(self, input_ids, cos_slice, sin_slice, prev_embedding, past_keys, past_values):
"""
Args:
input_ids: (B, T)
cos_slice: (1, T, 1, head_dim//2) - pre-sliced rotary cos for current positions
sin_slice: (1, T, 1, head_dim//2) - pre-sliced rotary sin for current positions
prev_embedding: (B, 1, n_embd) - previous token's embedding for smear
past_keys: (n_layers, B, past_len, n_kv_head, head_dim)
past_values: (n_layers, B, past_len, n_kv_head, head_dim)
"""
B, T = input_ids.size()
# Token embedding + norm
x = self.wte(input_ids)
x = x.to(self.cos.dtype)
x = norm(x)
# Smear: each token mixes in a gated copy of the previous token's embedding.
# During single-token decode the previous embedding lives outside the graph,
# so it arrives as the prev_embedding input and the updated value is returned.
new_prev_embedding = x[:, -1:, :].clone()
if T > 1:
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
else:
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
x = x + gate * prev_embedding
# Transformer blocks with explicit KV cache
x0 = x
backout_layer = self.n_layer // 2
x_backout = torch.zeros_like(x)
present_keys_list = []
present_values_list = []
for i, block in enumerate(self.blocks):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
# Value embeddings (alternating layers)
ve = self.value_embeds[str(i)](input_ids).to(x.dtype) if str(i) in self.value_embeds else None
# Get past KV for this layer
layer_past_k = past_keys[i]
layer_past_v = past_values[i]
x, present_k, present_v = block(x, ve, cos_slice, sin_slice, layer_past_k, layer_past_v)
present_keys_list.append(present_k.unsqueeze(0))
present_values_list.append(present_v.unsqueeze(0))
if i == backout_layer:
x_backout = x
# Backout: subtract mid-layer residual
x = x - self.backout_lambda.to(x.dtype) * x_backout
x = norm(x)
# Logits with tanh softcap: squashes values smoothly into (-15, 15)
logits = self.lm_head(x)
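# drop padding columns, in case lm_head was padded beyond the true vocab size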
logits = logits[..., :self.vocab_size]
logits = logits.float()
logits = 15.0 * torch.tanh(logits / 15.0)
# Stack KV caches
present_keys = torch.cat(present_keys_list, dim=0)
present_values = torch.cat(present_values_list, dim=0)
return logits, new_prev_embedding, present_keys, present_values
def export_model(model, tokenizer, output_dir):
"""Export model to ONNX format."""
os.makedirs(output_dir, exist_ok=True)
config = model.config
print("Building ONNX-compatible model wrapper...")
onnx_model = OnnxGPT(model)
onnx_model.eval()
n_kv_head = config.n_kv_head
head_dim = config.n_embd // config.n_head
n_layer = config.n_layer
# Pre-compute rotary embeddings on CPU
cos_full = onnx_model.cos # (1, max_len, 1, head_dim//2)
sin_full = onnx_model.sin
# --- Export Prefill Model ---
print("\nExporting prefill model...")
batch_size = 1
seq_len = 4 # dummy prompt length for tracing
dummy_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
dummy_cos = cos_full[:, :seq_len].contiguous()
dummy_sin = sin_full[:, :seq_len].contiguous()
dummy_prev_emb = torch.zeros(batch_size, 1, config.n_embd, dtype=COMPUTE_DTYPE)
dummy_past_keys = torch.zeros(n_layer, batch_size, 0, n_kv_head, head_dim, dtype=COMPUTE_DTYPE)
dummy_past_values = torch.zeros(n_layer, batch_size, 0, n_kv_head, head_dim, dtype=COMPUTE_DTYPE)
# Test forward pass first
print(" Testing forward pass...")
with torch.no_grad():
logits, new_prev, pk, pv = onnx_model(dummy_ids, dummy_cos, dummy_sin, dummy_prev_emb, dummy_past_keys, dummy_past_values)
print(f" Prefill output shapes: logits={logits.shape}, present_keys={pk.shape}")
# Export with legacy exporter (dynamo=False) for better compatibility
prefill_path = os.path.join(output_dir, "prefill.onnx")
print(f" Exporting to {prefill_path}...")
torch.onnx.export(
onnx_model,
(dummy_ids, dummy_cos, dummy_sin, dummy_prev_emb, dummy_past_keys, dummy_past_values),
prefill_path,
dynamo=False,
input_names=["input_ids", "cos", "sin", "prev_embedding", "past_keys", "past_values"],
output_names=["logits", "new_prev_embedding", "present_keys", "present_values"],
dynamic_axes={
"input_ids": {1: "seq_len"},
"cos": {1: "seq_len"},
"sin": {1: "seq_len"},
"past_keys": {2: "past_len"},
"past_values": {2: "past_len"},
"logits": {1: "seq_len"},
"present_keys": {2: "total_len"},
"present_values": {2: "total_len"},
},
opset_version=17,
do_constant_folding=True,
)
print(f" Prefill model exported: {os.path.getsize(prefill_path) / 1e6:.1f} MB")
# --- Export Decode Model (single token) ---
print("\nExporting decode model...")
dummy_decode_ids = torch.randint(0, config.vocab_size, (batch_size, 1))
dummy_decode_cos = cos_full[:, seq_len:seq_len+1].contiguous()
dummy_decode_sin = sin_full[:, seq_len:seq_len+1].contiguous()
dummy_decode_prev = new_prev.detach()
dummy_decode_past_k = pk.detach()
dummy_decode_past_v = pv.detach()
decode_path = os.path.join(output_dir, "decode.onnx")
print(f" Exporting to {decode_path}...")
torch.onnx.export(
onnx_model,
(dummy_decode_ids, dummy_decode_cos, dummy_decode_sin, dummy_decode_prev, dummy_decode_past_k, dummy_decode_past_v),
decode_path,
dynamo=False,
input_names=["input_ids", "cos", "sin", "prev_embedding", "past_keys", "past_values"],
output_names=["logits", "new_prev_embedding", "present_keys", "present_values"],
dynamic_axes={
"past_keys": {2: "past_len"},
"past_values": {2: "past_len"},
"present_keys": {2: "total_len"},
"present_values": {2: "total_len"},
},
opset_version=17,
do_constant_folding=True,
)
print(f" Decode model exported: {os.path.getsize(decode_path) / 1e6:.1f} MB")
# --- Save config for JS runtime ---
config_dict = {
"n_layer": config.n_layer,
"n_head": config.n_head,
"n_kv_head": config.n_kv_head,
"n_embd": config.n_embd,
"head_dim": head_dim,
"vocab_size": config.vocab_size,
"sequence_len": config.sequence_len,
"window_pattern": config.window_pattern,
}
config_path = os.path.join(output_dir, "config.json")
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=2)
print(f"\nConfig saved to {config_path}")
# --- Validate with ONNX Runtime ---
print("\nValidating with ONNX Runtime...")
try:
import onnxruntime as ort
# Test prefill
sess = ort.InferenceSession(prefill_path, providers=["CPUExecutionProvider"])
feeds = {
"input_ids": dummy_ids.numpy(),
"cos": dummy_cos.numpy().astype("float32"),
"sin": dummy_sin.numpy().astype("float32"),
"prev_embedding": dummy_prev_emb.numpy().astype("float32"),
"past_keys": dummy_past_keys.numpy().astype("float32"),
"past_values": dummy_past_values.numpy().astype("float32"),
}
ort_logits, ort_prev, ort_pk, ort_pv = sess.run(None, feeds)
print(f" Prefill ONNX Runtime OK: logits={ort_logits.shape}")
# Compare with PyTorch output
max_diff = abs(logits.numpy() - ort_logits).max()
print(f" Max logit difference (PyTorch vs ONNX Runtime): {max_diff:.6f}")
# Test decode
sess_decode = ort.InferenceSession(decode_path, providers=["CPUExecutionProvider"])
feeds_decode = {
"input_ids": dummy_decode_ids.numpy(),
"cos": dummy_decode_cos.numpy().astype("float32"),
"sin": dummy_decode_sin.numpy().astype("float32"),
"prev_embedding": ort_prev.astype("float32"),
"past_keys": ort_pk.astype("float32"),
"past_values": ort_pv.astype("float32"),
}
ort_dec_logits, _, _, _ = sess_decode.run(None, feeds_decode)
print(f" Decode ONNX Runtime OK: logits={ort_dec_logits.shape}")
print("\nONNX export validated successfully!")
except Exception as e:
print(f" ONNX Runtime validation failed: {e}")
print(" The ONNX files were still exported - fix the validation issue before deployment.")
return prefill_path, decode_path
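# --- Generation-loop sketch (illustration only) ---
# The two exported graphs are meant to be chained: prefill.onnx runs once over
# the full prompt, then decode.onnx runs one token at a time, with each call's
# present_keys/present_values fed back in as the next call's past_keys/past_values.
# Below is a minimal greedy-decoding sketch on CPU, assuming onnxruntime is
# installed. _sketch_generate is a hypothetical helper, not called during export;
# cos_full and sin_full are the model's rotary tables as float32 numpy arrays,
# which this script does not itself write to disk, so a real runtime must ship
# them separately.
def _sketch_generate(output_dir, prompt_ids, cos_full, sin_full, max_new_tokens=32):
    """Greedy generation by chaining prefill.onnx and decode.onnx (sketch)."""
    import numpy as np
    import onnxruntime as ort
    with open(os.path.join(output_dir, "config.json")) as f:
        cfg = json.load(f)
    prefill = ort.InferenceSession(os.path.join(output_dir, "prefill.onnx"), providers=["CPUExecutionProvider"])
    decode = ort.InferenceSession(os.path.join(output_dir, "decode.onnx"), providers=["CPUExecutionProvider"])
    T = prompt_ids.shape[1]
    # Empty KV cache: past_len == 0 on the first (prefill) call
    empty = np.zeros((cfg["n_layer"], 1, 0, cfg["n_kv_head"], cfg["head_dim"]), dtype="float32")
    logits, prev_emb, keys, values = prefill.run(None, {
        "input_ids": prompt_ids.astype("int64"),
        "cos": cos_full[:, :T],
        "sin": sin_full[:, :T],
        "prev_embedding": np.zeros((1, 1, cfg["n_embd"]), dtype="float32"),
        "past_keys": empty,
        "past_values": empty,
    })
    out = [int(logits[0, -1].argmax())]
    for pos in range(T, T + max_new_tokens - 1):
        # Feed the last sampled token at absolute position `pos`
        logits, prev_emb, keys, values = decode.run(None, {
            "input_ids": np.array([[out[-1]]], dtype="int64"),
            "cos": cos_full[:, pos:pos + 1],
            "sin": sin_full[:, pos:pos + 1],
            "prev_embedding": prev_emb,
            "past_keys": keys,
            "past_values": values,
        })
        out.append(int(logits[0, -1].argmax()))
    return out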
def quantize_models(output_dir):
"""Quantize ONNX models to INT4 for smaller download size."""
try:
from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
print("onnxruntime quantization not available. Install: pip install onnxruntime")
return
for name in ["prefill", "decode"]:
input_path = os.path.join(output_dir, f"{name}.onnx")
output_path = os.path.join(output_dir, f"{name}_q8.onnx")
print(f"\nQuantizing {name} (dynamic INT8)...")
quantize_dynamic(
input_path,
output_path,
weight_type=QuantType.QUInt8,  # quantize_dynamic supports 8-bit weights; true INT4 needs a separate tool
)
original_size = os.path.getsize(input_path) / 1e6
quant_size = os.path.getsize(output_path) / 1e6
print(f" {original_size:.1f} MB -> {quant_size:.1f} MB ({quant_size/original_size*100:.0f}%)")
if __name__ == "__main__":
print("=" * 60)
print("samosaChaat ONNX Export")
print("=" * 60)
# Load model
device = torch.device("cpu")
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
model.eval()
print(f"\nModel: {sum(p.numel() for p in model.parameters()):,} params")
print(f"Config: n_layer={model.config.n_layer}, n_head={model.config.n_head}, "
f"n_kv_head={model.config.n_kv_head}, n_embd={model.config.n_embd}")
# Export
prefill_path, decode_path = export_model(model, tokenizer, args.output_dir)
# Quantize if requested
if args.quantize:
quantize_models(args.output_dir)
print("\n" + "=" * 60)
print("Export complete!")
print(f"Files in: {args.output_dir}/")
print("=" * 60)