nanochat/scripts/export_model.py

#!/usr/bin/env python3
"""
Export nanochat models to TorchScript and ONNX formats.

This script exports trained nanochat models to formats that can be used
for inference in C++, C#, Java, and other languages.

Supported formats:
- TorchScript (.pt): For use with LibTorch (C++ PyTorch API)
- ONNX (.onnx): For use with ONNX Runtime (cross-platform)

Usage examples:

# Export SFT model to TorchScript
python -m scripts.export_model --source sft --format torchscript --output model.pt

# Export to ONNX
python -m scripts.export_model --source sft --format onnx --output model.onnx

# Export with specific model tag and step
python -m scripts.export_model --source mid --model-tag d20 --step 10000 --format both

# Export with KV cache support (experimental)
python -m scripts.export_model --source sft --format torchscript --with-cache --output model_cache.pt
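
Loading the exported artifacts (illustrative sketch; the file names and the
input_ids tensor below are placeholders, not fixed by this script):

# TorchScript -- quick Python check; the primary consumer is LibTorch in C++
ts_model = torch.jit.load("model.pt")
logits = ts_model(input_ids)  # input_ids: int64 tensor of shape (batch, seq_len)

# ONNX -- via ONNX Runtime (pip install onnxruntime); input/output names match
# the export below ("input_ids" in, "logits" first output)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
logits = session.run(None, {"input_ids": input_ids.numpy()})[0]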
"""
import argparse
import os
import torch
import torch.nn as nn
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.export_wrapper import ExportableGPT, ExportableGPTWithCache


def export_to_torchscript(
    model,
    output_path: str,
    with_cache: bool = False,
    max_seq_len: int = 4096,
    example_seq_len: int = 32,
    device: torch.device = None
):
    """
    Export model to TorchScript format.

    Args:
        model: The GPT model to export
        output_path: Path to save the exported model
        with_cache: Whether to include KV cache support
        max_seq_len: Maximum sequence length for rotary embeddings
        example_seq_len: Sequence length for tracing
        device: Device to use for export
    """
print(f"Exporting to TorchScript (with_cache={with_cache})...")
# Create wrapper
if with_cache:
wrapper = ExportableGPTWithCache(model, max_seq_len=max_seq_len)
else:
wrapper = ExportableGPT(model, max_seq_len=max_seq_len)
wrapper.eval()
# Create example inputs for tracing
batch_size = 1
example_input_ids = torch.randint(
0, model.config.vocab_size,
(batch_size, example_seq_len),
dtype=torch.long,
device=device
)
if with_cache:
# Example with cache
n_layers = model.config.n_layer
n_kv_head = model.config.n_kv_head
head_dim = model.config.n_embd // model.config.n_head
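        # Zero-filled KV cache buffers used as example inputs for tracing:
        # shape (n_layers, batch, n_kv_head, max_seq_len, head_dim), in bfloat16.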
        example_cache_k = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_cache_v = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_position = torch.tensor(0, dtype=torch.long, device=device)
        example_inputs = (example_input_ids, example_cache_k, example_cache_v, example_position)
    else:
        example_inputs = (example_input_ids,)

    # Trace the model
    print("Tracing model with example inputs...")
    try:
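        # torch.jit.trace records the ops executed for these example inputs; any
        # data-dependent Python control flow in the wrapper is frozen to this path.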
        traced_model = torch.jit.trace(wrapper, example_inputs)

        # Save the traced model
        print(f"Saving TorchScript model to {output_path}...")
        traced_model.save(output_path)
        print(f"✓ Successfully exported to TorchScript: {output_path}")

        # Verify the export
        print("Verifying export...")
        with torch.no_grad():
            original_output = wrapper(*example_inputs)
            loaded_model = torch.jit.load(output_path)
            traced_output = loaded_model(*example_inputs)
            if with_cache:
                # Compare logits only (first output)
                max_diff = torch.max(torch.abs(original_output[0] - traced_output[0])).item()
            else:
                max_diff = torch.max(torch.abs(original_output - traced_output)).item()
            print(f" Max difference between original and traced: {max_diff:.6e}")
            if max_diff < 1e-4:
                print(" ✓ Verification passed!")
            else:
                print(" ⚠ Warning: Difference is larger than expected")
        return True
    except Exception as e:
        print(f"✗ Failed to export to TorchScript: {e}")
        import traceback
        traceback.print_exc()
        return False


def export_to_onnx(
    model,
    output_path: str,
    with_cache: bool = False,
    max_seq_len: int = 4096,
    example_seq_len: int = 32,
    device: torch.device = None,
    opset_version: int = 17
):
    """
    Export model to ONNX format.

    Args:
        model: The GPT model to export
        output_path: Path to save the exported model
        with_cache: Whether to include KV cache support
        max_seq_len: Maximum sequence length for rotary embeddings
        example_seq_len: Sequence length for export
        device: Device to use for export
        opset_version: ONNX opset version
    """
print(f"Exporting to ONNX (with_cache={with_cache}, opset={opset_version})...")
# Create wrapper
if with_cache:
wrapper = ExportableGPTWithCache(model, max_seq_len=max_seq_len)
else:
wrapper = ExportableGPT(model, max_seq_len=max_seq_len)
wrapper.eval()
# Create example inputs
batch_size = 1
example_input_ids = torch.randint(
0, model.config.vocab_size,
(batch_size, example_seq_len),
dtype=torch.long,
device=device
)
if with_cache:
n_layers = model.config.n_layer
n_kv_head = model.config.n_kv_head
head_dim = model.config.n_embd // model.config.n_head
example_cache_k = torch.zeros(
n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
dtype=torch.bfloat16, device=device
)
example_cache_v = torch.zeros(
n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
dtype=torch.bfloat16, device=device
)
example_position = torch.tensor(0, dtype=torch.long, device=device)
example_inputs = (example_input_ids, example_cache_k, example_cache_v, example_position)
input_names = ["input_ids", "cache_k", "cache_v", "position"]
output_names = ["logits", "cache_k_out", "cache_v_out"]
# Dynamic axes for variable sequence length and batch size
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"logits": {0: "batch_size", 1: "seq_len"},
}
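        # Note: only input_ids/logits get dynamic axes here; the cache tensors keep
        # the fixed max_seq_len shapes they were exported with.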
    else:
        example_inputs = (example_input_ids,)
        input_names = ["input_ids"]
        output_names = ["logits"]
        # Dynamic axes for variable sequence length and batch size
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        }

    # Export to ONNX
    print("Exporting model to ONNX format...")
    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapper,
                example_inputs,
                output_path,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=opset_version,
                do_constant_folding=True,
                export_params=True,
            )
        print(f"✓ Successfully exported to ONNX: {output_path}")

        # Verify with ONNX
        try:
            import onnx
            print("Verifying ONNX model...")
            onnx_model = onnx.load(output_path)
            onnx.checker.check_model(onnx_model)
            print(" ✓ ONNX model is valid!")

            # Try to verify with ONNX Runtime if available
            try:
                import onnxruntime as ort
                print("Verifying with ONNX Runtime...")
                # Create inference session
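                # (CPUExecutionProvider keeps this check independent of the export
                # device; other providers, e.g. CUDAExecutionProvider, could be used
                # if installed)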
                ort_session = ort.InferenceSession(
                    output_path,
                    providers=['CPUExecutionProvider']
                )
                # Prepare inputs
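                # (note: torch.Tensor.numpy() does not support bfloat16, so the
                # experimental with_cache inputs may fail to convert at this step)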
                if with_cache:
                    ort_inputs = {
                        "input_ids": example_input_ids.cpu().numpy(),
                        "cache_k": example_cache_k.cpu().numpy(),
                        "cache_v": example_cache_v.cpu().numpy(),
                        "position": example_position.cpu().numpy(),
                    }
                else:
                    ort_inputs = {
                        "input_ids": example_input_ids.cpu().numpy(),
                    }
                # Run inference
                ort_outputs = ort_session.run(None, ort_inputs)
                # Compare with PyTorch
                with torch.no_grad():
                    torch_outputs = wrapper(*example_inputs)
                if with_cache:
                    torch_logits = torch_outputs[0].cpu().numpy()
                else:
                    torch_logits = torch_outputs.cpu().numpy()
                ort_logits = ort_outputs[0]
                max_diff = abs(torch_logits - ort_logits).max()
                print(f" Max difference between PyTorch and ONNX Runtime: {max_diff:.6e}")
                if max_diff < 1e-3:
                    print(" ✓ ONNX Runtime verification passed!")
                else:
                    print(" ⚠ Warning: Difference is larger than expected")
            except ImportError:
                print(" ⓘ ONNX Runtime not available, skipping runtime verification")
        except ImportError:
            print(" ⓘ ONNX package not available, skipping validation")
        return True
    except Exception as e:
        print(f"✗ Failed to export to ONNX: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    parser = argparse.ArgumentParser(
        description="Export nanochat models to TorchScript and ONNX",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    # Model selection
    parser.add_argument(
        "--source", "-s",
        type=str,
        default="sft",
        choices=["base", "mid", "sft", "rl"],
        help="Model source to export (default: sft)"
    )
    parser.add_argument(
        "--model-tag", "-g",
        type=str,
        default=None,
        help="Model tag to load (e.g., d20, d26)"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=None,
        help="Specific checkpoint step to load"
    )
    # Export options
    parser.add_argument(
        "--format", "-f",
        type=str,
        default="torchscript",
        choices=["torchscript", "onnx", "both"],
        help="Export format (default: torchscript)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="Output file path (default: model.pt or model.onnx)"
    )
    parser.add_argument(
        "--with-cache",
        action="store_true",
        help="Export with KV cache support (experimental, may not work with ONNX)"
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=4096,
        help="Maximum sequence length for rotary embeddings (default: 4096)"
    )
    parser.add_argument(
        "--example-seq-len",
        type=int,
        default=32,
        help="Sequence length for tracing/export (default: 32)"
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=17,
        help="ONNX opset version (default: 17)"
    )
    # Device options
    parser.add_argument(
        "--device-type",
        type=str,
        default="",
        choices=["cuda", "cpu", "mps", ""],
        help="Device type: cuda|cpu|mps (empty = autodetect)"
    )
    args = parser.parse_args()

    # Initialize device
    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

    print("="*60)
    print("nanochat Model Export")
    print("="*60)
    print(f"Source: {args.source}")
    print(f"Model tag: {args.model_tag or 'default'}")
    print(f"Step: {args.step or 'latest'}")
    print(f"Format: {args.format}")
    print(f"Device: {device}")
    print(f"Max sequence length: {args.max_seq_len}")
    print(f"With KV cache: {args.with_cache}")
    print("="*60)

    # Load the model
    print("\nLoading model...")
    model, tokenizer, meta = load_model(
        args.source,
        device,
        phase="eval",
        model_tag=args.model_tag,
        step=args.step
    )
    model.eval()
    print("✓ Model loaded successfully")
    print(f" Config: {model.config}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Determine output paths
    if args.output:
        base_path = args.output.rsplit(".", 1)[0]
        ext = args.output.rsplit(".", 1)[1] if "." in args.output else ""
    else:
        base_path = f"nanochat_{args.source}"
        if args.model_tag:
            base_path += f"_{args.model_tag}"
        if args.step:
            base_path += f"_step{args.step}"
        if args.with_cache:
            base_path += "_cache"
        ext = ""
    # Export to requested formats
    success = True

    if args.format in ["torchscript", "both"]:
        output_path = f"{base_path}.pt" if not ext or ext == "pt" else args.output
        success &= export_to_torchscript(
            model,
            output_path,
            with_cache=args.with_cache,
            max_seq_len=args.max_seq_len,
            example_seq_len=args.example_seq_len,
            device=device
        )

    if args.format in ["onnx", "both"]:
        output_path = f"{base_path}.onnx" if not ext or ext == "onnx" else args.output
        success &= export_to_onnx(
            model,
            output_path,
            with_cache=args.with_cache,
            max_seq_len=args.max_seq_len,
            example_seq_len=args.example_seq_len,
            device=device,
            opset_version=args.opset_version
        )

    print("\n" + "="*60)
    if success:
        print("✓ Export completed successfully!")
        print("\nNext steps:")
        print(" - For TorchScript: Use LibTorch C++ API to load and run inference")
        print(" - For ONNX: Use ONNX Runtime in C++, C#, Java, or other languages")
        print("\nSee examples/cpp_inference/ for C++ usage examples")
    else:
        print("✗ Export failed. See errors above.")
    print("="*60)


if __name__ == "__main__":
    main()