# TorchScript/ONNX Export Implementation for nanochat
## Summary
This document describes the implementation of TorchScript and ONNX export functionality for nanochat models, addressing GitHub Issue #73.
## Problem Statement
The original issue requested the ability to export nanochat models for inference in different languages (C/C++, etc.) using TorchScript, TorchTrace, or ONNX formats. The main challenges were:
- Rotary Embeddings: Pre-computed embeddings stored as buffers
- KV Cache: Dynamic state management during autoregressive generation
- Tool Use: Python-based calculator and special token handling in the Engine class
## Solution Overview
The implementation provides:
- Export-friendly model wrappers that encapsulate rotary embeddings and simplify the forward pass
- Export script supporting both TorchScript and ONNX formats
- C++ inference examples for LibTorch and ONNX Runtime
- Comprehensive documentation for users
## Implementation Details
### 1. Export Wrapper (nanochat/export_wrapper.py)
Two wrapper classes were created:
#### ExportableGPT
- Simplified forward pass without KV cache
- Self-contained rotary embeddings
- Suitable for both TorchScript and ONNX export
- Best for simple use cases and maximum compatibility
```python
wrapper = ExportableGPT(model, max_seq_len=4096)
logits = wrapper(input_ids)
```
#### ExportableGPTWithCache
- Explicit KV cache management as inputs/outputs
- Enables stateful inference for better performance
- More complex but suitable for production deployments
- May have limited ONNX support due to dynamic shapes
```python
wrapper = ExportableGPTWithCache(model, max_seq_len=4096)
logits, cache_k, cache_v = wrapper(input_ids, cache_k, cache_v, position)
```
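Continuing from the call above, the caller drives decoding by threading the cache tensors back in at each step. The sketch below is illustrative only: the cache layout, the integer position argument, and greedy sampling are assumptions rather than the wrapper's confirmed interface.

```python
import torch

# Placeholder shapes for illustration; real values come from the model config.
n_layers, n_heads, head_dim, max_seq_len = 12, 6, 64, 4096

prompt_ids = torch.tensor([[1, 464, 11742]])                        # (batch=1, prompt_len)
cache_k = torch.zeros(n_layers, 1, n_heads, max_seq_len, head_dim)  # assumed cache layout
cache_v = torch.zeros_like(cache_k)

# Prefill: run the full prompt once, starting at position 0.
logits, cache_k, cache_v = wrapper(prompt_ids, cache_k, cache_v, 0)
next_id = logits[:, -1].argmax(dim=-1, keepdim=True)

# Decode: feed one token at a time, advancing the position offset each step.
generated, pos = [next_id.item()], prompt_ids.size(1)
for _ in range(32):
    logits, cache_k, cache_v = wrapper(next_id, cache_k, cache_v, pos)
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
    generated.append(next_id.item())
    pos += 1
```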
Key Design Decisions:
- Rotary embeddings are pre-computed and stored as persistent buffers
- Simplified attention mechanism without Engine complexity
- No tool use or special token handling (Python-only features)
- Support for position offsets to enable KV cache usage
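As a rough sketch of the first and last of these decisions (assumed names, not the actual export_wrapper.py code): precomputing the rotary cos/sin tables and registering them as persistent buffers keeps them inside the traced graph, and a position offset lets a KV-cached caller resume mid-sequence.

```python
import torch
import torch.nn as nn

class RotaryBuffers(nn.Module):
    """Illustrative only: bake the rotary tables into buffers so tracing captures them."""
    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
        self.register_buffer("cos", freqs.cos(), persistent=True)
        self.register_buffer("sin", freqs.sin(), persistent=True)

    def forward(self, seq_len: int, position: int = 0):
        # Slice at a position offset so a cached caller can continue mid-sequence.
        return (self.cos[position:position + seq_len],
                self.sin[position:position + seq_len])
```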
### 2. Export Script (scripts/export_model.py)
A comprehensive CLI tool for exporting models:
```bash
# Export to TorchScript
python -m scripts.export_model --source sft --format torchscript --output model.pt

# Export to ONNX
python -m scripts.export_model --source sft --format onnx --output model.onnx

# Export both formats
python -m scripts.export_model --source sft --format both
```
Features:
- Supports all model sources (base, mid, sft, rl)
- Automatic model loading and validation
- Output verification (compares exported vs original)
- ONNX validation with onnxruntime (if available)
- Configurable sequence lengths and opset versions
- Support for both cached and non-cached variants
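Under the hood, the script reduces to the standard PyTorch export entry points. The sketch below shows the pattern only; `wrapper`, the dummy vocabulary size, the input/output names, and the opset version are assumptions rather than the script's exact code.

```python
import torch

example_ids = torch.randint(0, 50304, (1, 64))  # dummy vocab size and prompt length

# TorchScript: trace the simplified forward pass and serialize it.
traced = torch.jit.trace(wrapper, example_ids)
traced.save("model.pt")

# ONNX: export with dynamic batch/sequence axes so other shapes work at runtime.
torch.onnx.export(
    wrapper,
    example_ids,
    "model.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq_len"},
                  "logits": {0: "batch", 1: "seq_len"}},
    opset_version=17,
)
```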
### 3. C++ Examples (examples/cpp_inference/)
#### LibTorch Example (libtorch_inference.cpp)
Demonstrates inference using PyTorch's C++ API:
- Model loading from TorchScript files
- Single forward pass
- Autoregressive generation with sampling
- Temperature and top-k sampling support
```cpp
NanoChatInference model(model_path, device);
auto logits = model.forward(input_ids);
auto generated = model.generate(prompt_ids, max_tokens, temperature, top_k);
```
#### ONNX Runtime Example (onnx_inference.cpp)
Cross-platform inference with ONNX Runtime:
- ONNX model loading
- CPU and CUDA execution providers
- Efficient inference with ONNX Runtime optimizations
- Compatible with multiple languages (C++, C#, Java, Python)
```cpp
NanoChatONNXInference model(model_path, use_cuda);
auto logits = model.forward(input_ids, batch_size, seq_len, vocab_size);
auto generated = model.generate(prompt_ids, max_tokens, temperature, top_k);
```
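The exported ONNX file can also be exercised from Python with onnxruntime before setting up the C++ build. This is a quick sanity check; the input name `input_ids` mirrors the export sketch earlier and is an assumption.

```python
import numpy as np
import onnxruntime as ort

# Use whatever providers this onnxruntime build offers (CUDA first on GPU builds).
session = ort.InferenceSession("model.onnx", providers=ort.get_available_providers())

input_ids = np.array([[1, 464, 11742, 15150]], dtype=np.int64)
(logits,) = session.run(None, {"input_ids": input_ids})
print(logits.shape)  # expected (batch, seq_len, vocab_size)
```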
#### Build System (CMakeLists.txt)
- Supports both LibTorch and ONNX Runtime
- Optional builds (can build either or both)
- Cross-platform (Linux, macOS, Windows)
- Clear error messages for missing dependencies
### 4. Documentation
#### Main README Updates
Added a new "Model Export" section covering:
- Export commands and options
- C++ inference quick start
- Limitations of exported models
- Links to detailed C++ documentation
#### C++ Examples README
Comprehensive guide including:
- Prerequisites and dependencies
- Build instructions for all platforms
- Export workflow
- Running examples
- Tokenization strategies
- Performance tips
- Troubleshooting
## Testing
A test script (test_export.py) was created to verify the implementation:
```bash
python3 test_export.py
```
Test Coverage:
- ✓ ExportableGPT forward pass
- ✓ Position offset handling
- ✓ TorchScript tracing
- ✓ Output verification (original vs traced)
- ✓ ExportableGPTWithCache forward pass
- ✓ Cache shape validation
All tests pass successfully with zero numerical differences between original and traced models.
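Conceptually, the core check is a trace-and-compare. A minimal sketch of that idea (not the actual test_export.py) looks like this:

```python
import torch

def check_traced_matches_original(wrapper, seq_len=32, vocab_size=1000):
    """Trace the export wrapper and confirm its outputs match the eager module."""
    ids = torch.randint(0, vocab_size, (1, seq_len))
    traced = torch.jit.trace(wrapper, ids)
    with torch.no_grad():
        ref, out = wrapper(ids), traced(ids)
    max_diff = (ref - out).abs().max().item()
    assert torch.allclose(ref, out), f"traced output diverges (max abs diff {max_diff})"
    return max_diff
```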
## Limitations
The exported models have intentional limitations:
- No Tool Use: Calculator and Python execution features are not included
  - These require a Python runtime and are not suitable for export
  - Users can implement similar features in their target language if needed
- No Special Token Handling: Special tokens like `<|python_start|>` are not automatically processed
  - The exported model only performs the core transformer forward pass
  - Special token logic must be implemented in the inference code
- Tokenization: Token encoding/decoding is not included
  - Users must implement BPE tokenization in C++ or use Python for preprocessing (see the example after this list)
  - The nanochat tokenizer is tiktoken-compatible
- KV Cache Complexity: The cached variant is more complex
  - Starting with the simple non-cached version is recommended
  - Cache management must be handled by the caller
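For example, prompt token IDs for the C++ programs can be produced on the Python side. The snippet below uses a stock tiktoken encoding purely as an illustration; it is not the nanochat vocabulary, so the resulting IDs will differ from those of a real nanochat tokenizer.

```python
# Illustration only: substitute the actual nanochat tokenizer for real use.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
prompt_ids = enc.encode("The chemical formula of water is")
print(prompt_ids)              # pass these int64 IDs to the exported model
print(enc.decode(prompt_ids))  # decode generated IDs the same way
```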
## Usage Examples
### Python Export
```bash
# Export a trained SFT model
python -m scripts.export_model \
    --source sft \
    --format torchscript \
    --output nanochat_sft.pt \
    --max-seq-len 4096
```
### C++ Inference (LibTorch)
#include "libtorch_inference.cpp"
int main() {
NanoChatInference model("model.pt", torch::kCUDA);
std::vector<int64_t> prompt = {1, 464, 11742, 15150, 315, 3090, 374};
auto generated = model.generate(prompt, 100, 0.8, 50);
// generated now contains the full sequence including prompt
return 0;
}
### C++ Inference (ONNX Runtime)
#include "onnx_inference.cpp"
int main() {
NanoChatONNXInference model("model.onnx", true);
std::vector<int64_t> prompt = {1, 464, 11742, 15150, 315, 3090, 374};
auto generated = model.generate(prompt, 100, 0.8, 50);
return 0;
}
## Files Created/Modified
### New Files
- `nanochat/export_wrapper.py` - Export-friendly model wrappers
- `scripts/export_model.py` - Export script for TorchScript/ONNX
- `examples/cpp_inference/libtorch_inference.cpp` - LibTorch example
- `examples/cpp_inference/onnx_inference.cpp` - ONNX Runtime example
- `examples/cpp_inference/CMakeLists.txt` - Build configuration
- `examples/cpp_inference/README.md` - C++ documentation
- `test_export.py` - Test script for export functionality
- `EXPORT_IMPLEMENTATION.md` - This document
### Modified Files
- `README.md` - Added export documentation and updated file structure
## Future Enhancements
Potential improvements for future work:
- Quantization Support: Add INT8/FP16 quantization for faster inference
- Batch Processing: Optimize for batch inference in C++
- Tokenizer Port: Implement BPE tokenizer in C++ for end-to-end inference
- Mobile Deployment: Add support for mobile platforms (iOS/Android)
- WebAssembly: Export to WASM for browser-based inference
- Streaming Generation: Implement streaming token generation in C++
- Model Optimization: Add ONNX graph optimizations and operator fusion
## Performance Considerations
- Use GPU: CUDA inference is significantly faster than CPU
- KV Cache: Implement KV caching for production deployments
- Batch Size: Process multiple sequences in parallel when possible
- Quantization: Consider quantizing models for deployment (see the sketch after this list)
- Operator Fusion: ONNX Runtime automatically fuses operators for better performance
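As one hedged example of the quantization point, ONNX Runtime supports post-training dynamic quantization. The file names below are placeholders, and the accuracy/latency trade-off should be measured per model and hardware target.

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

# Quantize weights to INT8; activations remain floating point at runtime.
quantize_dynamic(
    model_input="model.onnx",
    model_output="model.int8.onnx",
    weight_type=QuantType.QInt8,
)
```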
## Conclusion
This implementation successfully addresses GitHub Issue #73 by providing:
- ✅ TorchScript export support
- ✅ ONNX export support
- ✅ C++ inference examples (LibTorch and ONNX Runtime)
- ✅ Comprehensive documentation
- ✅ Tested and verified implementation
Users can now export trained nanochat models and run inference in C++ or other languages, enabling production deployments without Python dependencies. The implementation maintains the simplicity and hackability that nanochat is known for, while providing the flexibility needed for diverse deployment scenarios.