/** * LibTorch (TorchScript) Inference Example for nanochat * * This example demonstrates how to load and run inference with a nanochat * model exported to TorchScript format using LibTorch C++ API. * * Build: * mkdir build && cd build * cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. * cmake --build . --config Release * * Run: * ./libtorch_inference ../model.pt */ #include #include #include #include #include #include class NanoChatInference { public: NanoChatInference(const std::string& model_path, torch::Device device = torch::kCPU) : device_(device) { try { // Load the TorchScript model std::cout << "Loading model from: " << model_path << std::endl; module_ = torch::jit::load(model_path); module_.to(device_); module_.eval(); std::cout << "āœ“ Model loaded successfully" << std::endl; } catch (const c10::Error& e) { std::cerr << "Error loading model: " << e.what() << std::endl; throw; } } /** * Run inference on a sequence of token IDs. * * @param input_ids Vector of token IDs (shape: [seq_len]) * @return Logits tensor of shape [1, seq_len, vocab_size] */ torch::Tensor forward(const std::vector& input_ids) { // Convert input to tensor auto options = torch::TensorOptions() .dtype(torch::kLong) .device(device_); torch::Tensor input_tensor = torch::from_blob( const_cast(input_ids.data()), {1, static_cast(input_ids.size())}, torch::kLong ).to(device_); // Run inference std::vector inputs; inputs.push_back(input_tensor); torch::NoGradGuard no_grad; auto output = module_.forward(inputs).toTensor(); return output; } /** * Sample next token from logits using greedy decoding. * * @param logits Logits tensor of shape [1, seq_len, vocab_size] * @return Next token ID */ int64_t sample_greedy(const torch::Tensor& logits) { // Get logits for last position auto last_logits = logits.index({0, -1, torch::indexing::Slice()}); // Greedy sampling: argmax auto next_token = last_logits.argmax().item(); return next_token; } /** * Sample next token with temperature and top-k sampling. * * @param logits Logits tensor of shape [1, seq_len, vocab_size] * @param temperature Temperature for sampling (0.0 = greedy) * @param top_k Top-k filtering (0 = no filtering) * @return Next token ID */ int64_t sample(const torch::Tensor& logits, float temperature = 1.0, int top_k = 0) { // Get logits for last position auto last_logits = logits.index({0, -1, torch::indexing::Slice()}).clone(); // Greedy decoding if temperature is 0 if (temperature <= 0.0f) { return last_logits.argmax().item(); } // Apply temperature last_logits = last_logits / temperature; // Apply top-k filtering if (top_k > 0) { auto vocab_size = last_logits.size(0); auto k = std::min(top_k, static_cast(vocab_size)); auto topk_result = torch::topk(last_logits, k); auto topk_values = std::get<0>(topk_result); auto topk_indices = std::get<1>(topk_result); // Set all non-top-k values to -inf auto threshold = topk_values[-1].item(); last_logits = torch::where( last_logits < threshold, torch::full_like(last_logits, -std::numeric_limits::infinity()), last_logits ); } // Apply softmax to get probabilities auto probs = torch::softmax(last_logits, /*dim=*/0); // Sample from the distribution auto next_token = torch::multinomial(probs, /*num_samples=*/1).item(); return next_token; } /** * Generate tokens autoregressively. * * @param prompt_ids Initial prompt token IDs * @param max_tokens Maximum number of tokens to generate * @param temperature Temperature for sampling * @param top_k Top-k filtering * @return Generated token IDs (including prompt) */ std::vector generate( const std::vector& prompt_ids, int max_tokens = 100, float temperature = 1.0, int top_k = 50 ) { std::vector generated_ids = prompt_ids; std::cout << "Generating " << max_tokens << " tokens..." << std::endl; for (int i = 0; i < max_tokens; ++i) { // Forward pass auto logits = forward(generated_ids); // Sample next token auto next_token = sample(logits, temperature, top_k); // Append to sequence generated_ids.push_back(next_token); // Print progress if ((i + 1) % 10 == 0) { std::cout << " Generated " << (i + 1) << "/" << max_tokens << " tokens" << std::endl; } } return generated_ids; } private: torch::jit::script::Module module_; torch::Device device_; }; int main(int argc, char* argv[]) { if (argc < 2) { std::cerr << "Usage: " << argv[0] << " [use_cuda]" << std::endl; std::cerr << "Example: " << argv[0] << " model.pt 1" << std::endl; return 1; } std::string model_path = argv[1]; bool use_cuda = (argc > 2 && std::string(argv[2]) == "1"); // Setup device torch::Device device = torch::kCPU; if (use_cuda && torch::cuda::is_available()) { device = torch::kCUDA; std::cout << "Using CUDA device" << std::endl; } else { std::cout << "Using CPU device" << std::endl; } try { // Load model NanoChatInference model(model_path, device); // Example prompt (you would normally get these from a tokenizer) // These are just example token IDs - replace with actual tokenized text std::vector prompt_ids = {1, 464, 11742, 15150, 315, 3090, 374}; // Corresponds roughly to: "The chemical formula of water is" std::cout << "\nPrompt token IDs: "; for (auto id : prompt_ids) { std::cout << id << " "; } std::cout << std::endl; // Run single forward pass std::cout << "\n--- Single Forward Pass ---" << std::endl; auto logits = model.forward(prompt_ids); std::cout << "Output shape: [" << logits.size(0) << ", " << logits.size(1) << ", " << logits.size(2) << "]" << std::endl; // Sample next token auto next_token = model.sample_greedy(logits); std::cout << "Next token (greedy): " << next_token << std::endl; // Generate sequence std::cout << "\n--- Autoregressive Generation ---" << std::endl; auto generated_ids = model.generate( prompt_ids, /*max_tokens=*/20, /*temperature=*/0.8, /*top_k=*/50 ); std::cout << "\nGenerated token IDs: "; for (auto id : generated_ids) { std::cout << id << " "; } std::cout << std::endl; std::cout << "\nāœ“ Inference completed successfully!" << std::endl; std::cout << "\nNote: To decode tokens to text, you need to implement" << std::endl; std::cout << " a tokenizer in C++ or use the Python tokenizer." << std::endl; } catch (const std::exception& e) { std::cerr << "Error: " << e.what() << std::endl; return 1; } return 0; }