/**
 * LibTorch (TorchScript) Inference Example for nanochat
 *
 * This example demonstrates how to load and run inference with a nanochat
 * model exported to TorchScript format using the LibTorch C++ API.
 *
 * Build:
 *   mkdir build && cd build
 *   cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
 *   cmake --build . --config Release
 *
 * Run:
 *   ./libtorch_inference ../model.pt
 */
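
// NOTE: this example assumes the nanochat model has already been exported to
// TorchScript on the Python side (for instance via torch.jit.script(model) or
// torch.jit.trace(...), followed by .save("model.pt")); the exact export recipe
// depends on the checkpoint and is not covered here.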

#include <torch/script.h>
#include <torch/torch.h>

#include <algorithm>  // std::min
#include <iostream>
#include <limits>     // std::numeric_limits
#include <memory>
#include <string>
#include <vector>

class NanoChatInference {
public:
    NanoChatInference(const std::string& model_path, torch::Device device = torch::kCPU)
        : device_(device) {
        try {
            // Load the TorchScript model
            std::cout << "Loading model from: " << model_path << std::endl;
            module_ = torch::jit::load(model_path);
            module_.to(device_);
            module_.eval();
            std::cout << "✓ Model loaded successfully" << std::endl;
        } catch (const c10::Error& e) {
            std::cerr << "Error loading model: " << e.what() << std::endl;
            throw;
        }
    }

    /**
     * Run inference on a sequence of token IDs.
     *
     * @param input_ids Vector of token IDs (shape: [seq_len])
     * @return Logits tensor of shape [1, seq_len, vocab_size]
     */
    torch::Tensor forward(const std::vector<int64_t>& input_ids) {
        // Wrap the token IDs in a [1, seq_len] LongTensor. from_blob does not copy,
        // so on CPU the tensor aliases input_ids' memory (safe here because the
        // vector outlives this call); .to(device_) copies when targeting CUDA.
        torch::Tensor input_tensor = torch::from_blob(
            const_cast<int64_t*>(input_ids.data()),
            {1, static_cast<int64_t>(input_ids.size())},
            torch::kLong
        ).to(device_);

        // Run inference without tracking gradients
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(input_tensor);

        torch::NoGradGuard no_grad;
        auto output = module_.forward(inputs).toTensor();

        return output;
    }

    /**
     * Sample next token from logits using greedy decoding.
     *
     * @param logits Logits tensor of shape [1, seq_len, vocab_size]
     * @return Next token ID
     */
    int64_t sample_greedy(const torch::Tensor& logits) {
        // Get logits for last position
        auto last_logits = logits.index({0, -1, torch::indexing::Slice()});

        // Greedy sampling: argmax
        auto next_token = last_logits.argmax().item<int64_t>();

        return next_token;
    }

    /**
     * Sample next token with temperature and top-k sampling.
     *
     * @param logits Logits tensor of shape [1, seq_len, vocab_size]
     * @param temperature Temperature for sampling (<= 0 falls back to greedy)
     * @param top_k Top-k filtering (0 = no filtering)
     * @return Next token ID
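     *
     * For example, temperature = 0.8 with top_k = 50 restricts sampling to the 50
     * highest-scoring tokens and slightly sharpens their relative probabilities.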
     */
    int64_t sample(const torch::Tensor& logits, float temperature = 1.0, int top_k = 0) {
        // Get logits for last position (clone so they can be modified in place)
        auto last_logits = logits.index({0, -1, torch::indexing::Slice()}).clone();

        // Greedy decoding if temperature is 0 (or negative)
        if (temperature <= 0.0f) {
            return last_logits.argmax().item<int64_t>();
        }

        // Apply temperature
        last_logits = last_logits / temperature;

        // Apply top-k filtering
        if (top_k > 0) {
            auto vocab_size = last_logits.size(0);
            auto k = std::min<int64_t>(top_k, vocab_size);

            // The smallest of the top-k values is the cutoff threshold
            auto topk_values = std::get<0>(torch::topk(last_logits, k));
            auto threshold = topk_values[-1].item<float>();

            // Set all non-top-k values to -inf
            last_logits = torch::where(
                last_logits < threshold,
                torch::full_like(last_logits, -std::numeric_limits<float>::infinity()),
                last_logits
            );
        }

        // Apply softmax to get probabilities
        auto probs = torch::softmax(last_logits, /*dim=*/0);

        // Sample from the distribution
        auto next_token = torch::multinomial(probs, /*num_samples=*/1).item<int64_t>();

        return next_token;
    }

    /**
     * Generate tokens autoregressively.
     *
     * @param prompt_ids Initial prompt token IDs
     * @param max_tokens Maximum number of tokens to generate
     * @param temperature Temperature for sampling
     * @param top_k Top-k filtering
     * @return Generated token IDs (including prompt)
     */
    std::vector<int64_t> generate(
        const std::vector<int64_t>& prompt_ids,
        int max_tokens = 100,
        float temperature = 1.0,
        int top_k = 50
    ) {
        std::vector<int64_t> generated_ids = prompt_ids;

        std::cout << "Generating " << max_tokens << " tokens..." << std::endl;
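
        // Note: each step re-runs the model on the full sequence generated so far
        // (there is no KV cache in this example), so the cost of generation grows
        // roughly quadratically with sequence length. A real chat loop would also
        // stop early when the model emits its end-of-sequence token.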
        for (int i = 0; i < max_tokens; ++i) {
            // Forward pass
            auto logits = forward(generated_ids);

            // Sample next token
            auto next_token = sample(logits, temperature, top_k);

            // Append to sequence
            generated_ids.push_back(next_token);

            // Print progress
            if ((i + 1) % 10 == 0) {
                std::cout << "  Generated " << (i + 1) << "/" << max_tokens << " tokens" << std::endl;
            }
        }

        return generated_ids;
    }

private:
    torch::jit::script::Module module_;
    torch::Device device_;
};

int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <model_path> [use_cuda]" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.pt 1" << std::endl;
        return 1;
    }

    std::string model_path = argv[1];
    bool use_cuda = (argc > 2 && std::string(argv[2]) == "1");

    // Setup device
    torch::Device device = torch::kCPU;
    if (use_cuda && torch::cuda::is_available()) {
        device = torch::kCUDA;
        std::cout << "Using CUDA device" << std::endl;
    } else {
        std::cout << "Using CPU device" << std::endl;
    }

    try {
        // Load model
        NanoChatInference model(model_path, device);

        // Example prompt (you would normally get these from a tokenizer)
        // These are just example token IDs - replace with actual tokenized text
        std::vector<int64_t> prompt_ids = {1, 464, 11742, 15150, 315, 3090, 374};
        // Corresponds roughly to: "The chemical formula of water is"

        std::cout << "\nPrompt token IDs: ";
        for (auto id : prompt_ids) {
            std::cout << id << " ";
        }
        std::cout << std::endl;

        // Run single forward pass
        std::cout << "\n--- Single Forward Pass ---" << std::endl;
        auto logits = model.forward(prompt_ids);
        std::cout << "Output shape: [" << logits.size(0) << ", "
                  << logits.size(1) << ", " << logits.size(2) << "]" << std::endl;

        // Sample next token
        auto next_token = model.sample_greedy(logits);
        std::cout << "Next token (greedy): " << next_token << std::endl;

        // Generate sequence
        std::cout << "\n--- Autoregressive Generation ---" << std::endl;
        auto generated_ids = model.generate(
            prompt_ids,
            /*max_tokens=*/20,
            /*temperature=*/0.8,
            /*top_k=*/50
        );

        std::cout << "\nGenerated token IDs: ";
        for (auto id : generated_ids) {
            std::cout << id << " ";
        }
        std::cout << std::endl;

        std::cout << "\n✓ Inference completed successfully!" << std::endl;
        std::cout << "\nNote: To decode tokens to text, you need to implement" << std::endl;
        std::cout << "      a tokenizer in C++ or use the Python tokenizer." << std::endl;

    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}