mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
#!/usr/bin/env python3
"""
Export nanochat models to TorchScript and ONNX formats.

This script exports trained nanochat models to formats that can be used
for inference in C++, C#, Java, and other languages.

Supported formats:
- TorchScript (.pt): For use with LibTorch (C++ PyTorch API)
- ONNX (.onnx): For use with ONNX Runtime (cross-platform)

Usage examples:

    # Export SFT model to TorchScript
    python -m scripts.export_model --source sft --format torchscript --output model.pt

    # Export to ONNX
    python -m scripts.export_model --source sft --format onnx --output model.onnx

    # Export with specific model tag and step
    python -m scripts.export_model --source mid --model-tag d20 --step 10000 --format both

    # Export with KV cache support (experimental)
    python -m scripts.export_model --source sft --format torchscript --with-cache --output model_cache.pt
"""

import argparse
import os
import torch
import torch.nn as nn

from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.export_wrapper import ExportableGPT, ExportableGPTWithCache


def export_to_torchscript(
    model,
    output_path: str,
    with_cache: bool = False,
    max_seq_len: int = 4096,
    example_seq_len: int = 32,
    device: torch.device = None
):
    """
    Export model to TorchScript format.

    Args:
        model: The GPT model to export
        output_path: Path to save the exported model
        with_cache: Whether to include KV cache support
        max_seq_len: Maximum sequence length for rotary embeddings
        example_seq_len: Sequence length for tracing
        device: Device to use for export
    """
    print(f"Exporting to TorchScript (with_cache={with_cache})...")

    # Create wrapper
    if with_cache:
        wrapper = ExportableGPTWithCache(model, max_seq_len=max_seq_len)
    else:
        wrapper = ExportableGPT(model, max_seq_len=max_seq_len)

    wrapper.eval()

    # Create example inputs for tracing
    batch_size = 1
    example_input_ids = torch.randint(
        0, model.config.vocab_size,
        (batch_size, example_seq_len),
        dtype=torch.long,
        device=device
    )

    if with_cache:
        # Example with cache
        n_layers = model.config.n_layer
        n_kv_head = model.config.n_kv_head
        head_dim = model.config.n_embd // model.config.n_head

        example_cache_k = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_cache_v = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_position = torch.tensor(0, dtype=torch.long, device=device)

        example_inputs = (example_input_ids, example_cache_k, example_cache_v, example_position)
    else:
        example_inputs = (example_input_ids,)
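
    # torch.jit.trace records only the operations executed with these example inputs:
    # Python control flow (e.g. the with_cache branch above) is frozen into the graph,
    # and shape generalization beyond example_seq_len is not guaranteed by tracing, so
    # validate the traced module at the sequence lengths you actually plan to serve.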

    # Trace the model
    print("Tracing model with example inputs...")
    try:
        traced_model = torch.jit.trace(wrapper, example_inputs)

        # Save the traced model
        print(f"Saving TorchScript model to {output_path}...")
        traced_model.save(output_path)
        print(f"✓ Successfully exported to TorchScript: {output_path}")

        # Verify the export
        print("Verifying export...")
        with torch.no_grad():
            original_output = wrapper(*example_inputs)
            loaded_model = torch.jit.load(output_path)
            traced_output = loaded_model(*example_inputs)

        if with_cache:
            # Compare logits only (first output)
            max_diff = torch.max(torch.abs(original_output[0] - traced_output[0])).item()
        else:
            max_diff = torch.max(torch.abs(original_output - traced_output)).item()

        print(f" Max difference between original and traced: {max_diff:.6e}")
        if max_diff < 1e-4:
            print(" ✓ Verification passed!")
        else:
            print(" ⚠ Warning: Difference is larger than expected")

        return True
    except Exception as e:
        print(f"✗ Failed to export to TorchScript: {e}")
        import traceback
        traceback.print_exc()
        return False


def export_to_onnx(
    model,
    output_path: str,
    with_cache: bool = False,
    max_seq_len: int = 4096,
    example_seq_len: int = 32,
    device: torch.device = None,
    opset_version: int = 17
):
    """
    Export model to ONNX format.

    Args:
        model: The GPT model to export
        output_path: Path to save the exported model
        with_cache: Whether to include KV cache support
        max_seq_len: Maximum sequence length for rotary embeddings
        example_seq_len: Sequence length for export
        device: Device to use for export
        opset_version: ONNX opset version
    """
    print(f"Exporting to ONNX (with_cache={with_cache}, opset={opset_version})...")

    # Create wrapper
    if with_cache:
        wrapper = ExportableGPTWithCache(model, max_seq_len=max_seq_len)
    else:
        wrapper = ExportableGPT(model, max_seq_len=max_seq_len)

    wrapper.eval()

    # Create example inputs
    batch_size = 1
    example_input_ids = torch.randint(
        0, model.config.vocab_size,
        (batch_size, example_seq_len),
        dtype=torch.long,
        device=device
    )

    if with_cache:
        n_layers = model.config.n_layer
        n_kv_head = model.config.n_kv_head
        head_dim = model.config.n_embd // model.config.n_head

        example_cache_k = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_cache_v = torch.zeros(
            n_layers, batch_size, n_kv_head, max_seq_len, head_dim,
            dtype=torch.bfloat16, device=device
        )
        example_position = torch.tensor(0, dtype=torch.long, device=device)

        example_inputs = (example_input_ids, example_cache_k, example_cache_v, example_position)
        input_names = ["input_ids", "cache_k", "cache_v", "position"]
        output_names = ["logits", "cache_k_out", "cache_v_out"]

        # Dynamic axes for variable sequence length and batch size
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        }
    else:
        example_inputs = (example_input_ids,)
        input_names = ["input_ids"]
        output_names = ["logits"]

        # Dynamic axes for variable sequence length and batch size
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        }
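
    # Only input_ids and logits are marked dynamic. In the with_cache path the cache
    # tensors keep the fixed (n_layers, batch, n_kv_head, max_seq_len, head_dim) shape
    # used for the example inputs, so consumers should pre-allocate caches at max_seq_len.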

    # Export to ONNX
    print("Exporting model to ONNX format...")
    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapper,
                example_inputs,
                output_path,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=opset_version,
                do_constant_folding=True,
                export_params=True,
            )

        print(f"✓ Successfully exported to ONNX: {output_path}")

        # Verify with ONNX
        try:
            import onnx
            print("Verifying ONNX model...")
            onnx_model = onnx.load(output_path)
            onnx.checker.check_model(onnx_model)
            print(" ✓ ONNX model is valid!")

            # Try to verify with ONNX Runtime if available
            try:
                import onnxruntime as ort
                print("Verifying with ONNX Runtime...")

                # Create inference session
                ort_session = ort.InferenceSession(
                    output_path,
                    providers=['CPUExecutionProvider']
                )

                # Prepare inputs
                if with_cache:
                    # Caveat: numpy has no bfloat16 dtype, so .numpy() on the bfloat16 cache
                    # tensors raises an error here; the with_cache ONNX path is experimental.
                    ort_inputs = {
                        "input_ids": example_input_ids.cpu().numpy(),
                        "cache_k": example_cache_k.cpu().numpy(),
                        "cache_v": example_cache_v.cpu().numpy(),
                        "position": example_position.cpu().numpy(),
                    }
                else:
                    ort_inputs = {
                        "input_ids": example_input_ids.cpu().numpy(),
                    }

                # Run inference
                ort_outputs = ort_session.run(None, ort_inputs)

                # Compare with PyTorch
                with torch.no_grad():
                    torch_outputs = wrapper(*example_inputs)
                    if with_cache:
                        torch_logits = torch_outputs[0].cpu().numpy()
                    else:
                        torch_logits = torch_outputs.cpu().numpy()

                ort_logits = ort_outputs[0]
                max_diff = abs(torch_logits - ort_logits).max()

                print(f" Max difference between PyTorch and ONNX Runtime: {max_diff:.6e}")
                if max_diff < 1e-3:
                    print(" ✓ ONNX Runtime verification passed!")
                else:
                    print(" ⚠ Warning: Difference is larger than expected")

            except ImportError:
                print(" ⓘ ONNX Runtime not available, skipping runtime verification")

        except ImportError:
            print(" ⓘ ONNX package not available, skipping validation")

        return True

    except Exception as e:
        print(f"✗ Failed to export to ONNX: {e}")
        import traceback
        traceback.print_exc()
        return False
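

# Sketch of how a consumer might drive the with-cache ONNX export step by step with
# onnxruntime. The input/output names come from export_to_onnx above, but the exact
# cache/position semantics are defined by ExportableGPTWithCache, so treat this as an
# illustrative, unverified outline rather than a definitive recipe:
#
#     import numpy as np, onnxruntime as ort
#     sess = ort.InferenceSession("model_cache.onnx")
#     cache_k = np.zeros((n_layers, 1, n_kv_head, max_seq_len, head_dim), dtype=cache_dtype)
#     cache_v = np.zeros_like(cache_k)
#     position = np.array(0, dtype=np.int64)
#     for token in prompt_tokens:
#         ids = np.array([[token]], dtype=np.int64)
#         logits, cache_k, cache_v = sess.run(
#             None, {"input_ids": ids, "cache_k": cache_k, "cache_v": cache_v, "position": position})
#         position = position + 1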


def main():
    parser = argparse.ArgumentParser(
        description="Export nanochat models to TorchScript and ONNX",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    # Model selection
    parser.add_argument(
        "--source", "-s",
        type=str,
        default="sft",
        choices=["base", "mid", "sft", "rl"],
        help="Model source to export (default: sft)"
    )
    parser.add_argument(
        "--model-tag", "-g",
        type=str,
        default=None,
        help="Model tag to load (e.g., d20, d26)"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=None,
        help="Specific checkpoint step to load"
    )

    # Export options
    parser.add_argument(
        "--format", "-f",
        type=str,
        default="torchscript",
        choices=["torchscript", "onnx", "both"],
        help="Export format (default: torchscript)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="Output file path (default: model.pt or model.onnx)"
    )
    parser.add_argument(
        "--with-cache",
        action="store_true",
        help="Export with KV cache support (experimental, may not work with ONNX)"
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=4096,
        help="Maximum sequence length for rotary embeddings (default: 4096)"
    )
    parser.add_argument(
        "--example-seq-len",
        type=int,
        default=32,
        help="Sequence length for tracing/export (default: 32)"
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=17,
        help="ONNX opset version (default: 17)"
    )

    # Device options
    parser.add_argument(
        "--device-type",
        type=str,
        default="",
        choices=["cuda", "cpu", "mps", ""],
        help="Device type: cuda|cpu|mps (empty = autodetect)"
    )

    args = parser.parse_args()

    # Initialize device
    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

    print("="*60)
    print("nanochat Model Export")
    print("="*60)
    print(f"Source: {args.source}")
    print(f"Model tag: {args.model_tag or 'default'}")
    print(f"Step: {args.step or 'latest'}")
    print(f"Format: {args.format}")
    print(f"Device: {device}")
    print(f"Max sequence length: {args.max_seq_len}")
    print(f"With KV cache: {args.with_cache}")
    print("="*60)

    # Load the model
    print("\nLoading model...")
    model, tokenizer, meta = load_model(
        args.source,
        device,
        phase="eval",
        model_tag=args.model_tag,
        step=args.step
    )
    model.eval()

    print("✓ Model loaded successfully")
    print(f" Config: {model.config}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Determine output paths
    if args.output:
        base_path = args.output.rsplit(".", 1)[0]
        ext = args.output.rsplit(".", 1)[1] if "." in args.output else ""
    else:
        base_path = f"nanochat_{args.source}"
        if args.model_tag:
            base_path += f"_{args.model_tag}"
        if args.step:
            base_path += f"_step{args.step}"
        if args.with_cache:
            base_path += "_cache"
        ext = ""

    # Export to requested formats
    success = True

    if args.format in ["torchscript", "both"]:
        # Honor an explicit non-standard output path only for a single-format export; with
        # --format both, derive per-format paths from base_path so the exports don't collide.
        if args.format == "torchscript" and ext and ext != "pt":
            output_path = args.output
        else:
            output_path = f"{base_path}.pt"
        success &= export_to_torchscript(
            model,
            output_path,
            with_cache=args.with_cache,
            max_seq_len=args.max_seq_len,
            example_seq_len=args.example_seq_len,
            device=device
        )

    if args.format in ["onnx", "both"]:
        if args.format == "onnx" and ext and ext != "onnx":
            output_path = args.output
        else:
            output_path = f"{base_path}.onnx"
        success &= export_to_onnx(
            model,
            output_path,
            with_cache=args.with_cache,
            max_seq_len=args.max_seq_len,
            example_seq_len=args.example_seq_len,
            device=device,
            opset_version=args.opset_version
        )

    print("\n" + "="*60)
    if success:
        print("✓ Export completed successfully!")
        print("\nNext steps:")
        print(" - For TorchScript: Use LibTorch C++ API to load and run inference")
        print(" - For ONNX: Use ONNX Runtime in C++, C#, Java, or other languages")
        print("\nSee examples/cpp_inference/ for C++ usage examples")
    else:
        print("✗ Export failed. See errors above.")
    print("="*60)


if __name__ == "__main__":
    main()