Skip to content

Commit

Permalink
rust: Add TensorRT and OpenVINO execution providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Nov 30, 2023
1 parent 094b04e commit a2bc8c5
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 16 deletions.
9 changes: 9 additions & 0 deletions rust/onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ pub enum OrtError {
/// Error occurred when appending CUDA execution provider
#[error("Failed to append CUDA execution provider: {0}")]
AppendExecutionProviderCuda(OrtApiError),
/// Error occurred when creating TensorRT provider options
#[error("Failed to create TensorRT provider options: {0}")]
TensorRtProviderOptions(OrtApiError),
/// Error occurred when appending TensorRT execution provider
#[error("Failed to append TensorRT execution provider: {0}")]
AppendExecutionProviderTensorRT(OrtApiError),
/// Error occurred when appending OpenVINO execution provider
#[error("Failed to append OpenVINO execution provider: {0}")]
AppendExecutionProviderOpenVino(OrtApiError),
}

/// Error used when dimensions of input (from model and from inference call)
Expand Down
132 changes: 116 additions & 16 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path, ptr::null_mut};
use std::{convert::TryFrom, ffi::{CString, c_char, CStr}, fmt::Debug, path::Path, ptr::{null_mut, null}};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -32,6 +32,47 @@ use tracing::{debug, error};
#[cfg(feature = "model-fetching")]
use crate::{download::AvailableOnnxModel, error::OrtDownloadError};

/// Device to run OpenVINO execution provider with.
pub enum OpenVinoDeviceType {
/// "CPU_FP32"
CpuFp32,
/// "CPU_FP16"
CpuFp16,
/// "GPU_FP32"
GpuFp32,
/// "GPU_FP16"
GpuFp16,
}

impl From<OpenVinoDeviceType> for &str {
fn from(value: OpenVinoDeviceType) -> Self {
match value {
OpenVinoDeviceType::CpuFp32 => "CPU_FP32",
OpenVinoDeviceType::CpuFp16 => "CPU_FP16",
OpenVinoDeviceType::GpuFp32 => "GPU_FP32",
OpenVinoDeviceType::GpuFp16 => "GPU_FP16",
}
}
}

/// Options for the OpenVINO execution provider.
pub struct OpenVinoProviderOptions {
/// Device type to run on.
pub device_type: OpenVinoDeviceType,
///
pub enable_vpu_fast_compile: bool,
///
pub device_id: Option<String>,
/// Number of threads to use. Set to 0 for default number of threads.
pub num_of_threads: usize,
///
pub cache_dir: Option<String>,
///
pub enable_opencl_throttling: bool,
///
pub enable_dynamic_shapes: bool,
}

/// Type used to create a session using the _builder pattern_
///
/// A `SessionBuilder` is created by calling the
Expand Down Expand Up @@ -164,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
/// Append a CUDA execution provider
pub fn with_execution_provider_cuda(self) -> Result<SessionBuilder<'a>> {
let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = null_mut();
// let status = unsafe {
// self.env
// .env()
// .api()
// .CreateCUDAProviderOptions
// .unwrap()(&mut cuda_options)
// };
// status_to_result(status).map_err(OrtError::CudaProviderOptions)?;
let status = unsafe {
self.env
.env()
.api()
.CreateCUDAProviderOptions
.unwrap()(&mut cuda_options)
};
status_to_result(status).map_err(OrtError::CudaProviderOptions)?;

let status = unsafe {
self.env
Expand All @@ -182,13 +223,72 @@ impl<'a> SessionBuilder<'a> {
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderCuda)?;

// unsafe {
// self.env
// .env()
// .api()
// .ReleaseCUDAProviderOptions
// .unwrap()(cuda_options);
// };
unsafe {
self.env
.env()
.api()
.ReleaseCUDAProviderOptions
.unwrap()(cuda_options);
};
Ok(self)
}

/// Append a TensorRT execution provider
pub fn with_execution_provider_tensorrt(self) -> Result<SessionBuilder<'a>> {
let mut tensorrt_options: *mut sys::OrtTensorRTProviderOptionsV2 = null_mut();
let status = unsafe {
self.env
.env()
.api()
.CreateTensorRTProviderOptions
.unwrap()(&mut tensorrt_options)
};
status_to_result(status).map_err(OrtError::TensorRtProviderOptions)?;

let status = unsafe {
self.env
.env()
.api()
.SessionOptionsAppendExecutionProvider_TensorRT_V2
.unwrap()(self.session_options_ptr, tensorrt_options)
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderTensorRT)?;

unsafe {
self.env
.env()
.api()
.ReleaseTensorRTProviderOptions
.unwrap()(tensorrt_options);
};
Ok(self)
}

/// Append a TensorRT execution provider
pub fn with_execution_provider_openvino(self, options: OpenVinoProviderOptions) -> Result<SessionBuilder<'a>> {
// For some reason there is no CreateOpenVINOProviderOptions?
let device_type = CString::new::<&str>(options.device_type.into()).unwrap();
let device_id = CString::new(options.device_id.unwrap_or_default().as_str()).unwrap();
let cache_dir = CString::new(options.cache_dir.unwrap_or_default().as_str()).unwrap();
let mut openvino_options = sys::OrtOpenVINOProviderOptions {
device_type: device_type.as_ptr(),
enable_vpu_fast_compile: options.enable_vpu_fast_compile as u8,
device_id: device_id.as_ptr(),
num_of_threads: options.num_of_threads,
cache_dir: cache_dir.as_ptr(),
context: null_mut(),
enable_opencl_throttling: options.enable_opencl_throttling as u8,
enable_dynamic_shapes: options.enable_dynamic_shapes as u8,
};

let status = unsafe {
self.env
.env()
.api()
.SessionOptionsAppendExecutionProvider_OpenVINO
.unwrap()(self.session_options_ptr, &mut openvino_options)
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderOpenVino)?;
Ok(self)
}

Expand Down

0 comments on commit a2bc8c5

Please sign in to comment.