use std::{ffi::CString, fmt::Debug, path::Path};
#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "model-fetching")]
use std::env;
use ndarray::Array;
use tracing::{debug, error};
use onnxruntime_sys as sys;
use crate::{
char_p_to_string,
environment::Environment,
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
g_ort,
memory::MemoryInfo,
tensor::{
ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
OrtTensor,
},
AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
TypeToTensorElementDataType,
};
#[cfg(feature = "model-fetching")]
use crate::{download::AvailableOnnxModel, error::OrtDownloadError};
/// Builder used to configure and create a [`Session`].
///
/// Owns an `OrtSessionOptions` handle (created in `SessionBuilder::new`) that the
/// `with_*` methods configure; the handle is released when the builder is dropped.
#[derive(Debug)]
pub struct SessionBuilder<'a> {
    /// Environment the session is created in; borrowed for the builder's lifetime.
    env: &'a Environment,
    /// Owned ONNX Runtime session options handle (released in `Drop`).
    session_options_ptr: *mut sys::OrtSessionOptions,
    /// Allocator type requested through `with_allocator` (`Arena` after `new`).
    allocator: AllocatorType,
    /// Memory type requested through `with_memory_type` (`Default` after `new`).
    memory_type: MemType,
}
impl<'a> Drop for SessionBuilder<'a> {
    /// Frees the `OrtSessionOptions` handle owned by this builder.
    ///
    /// # Panics
    /// Panics if the handle is null.
    #[tracing::instrument]
    fn drop(&mut self) {
        debug!("Dropping the session options.");
        let options = self.session_options_ptr;
        assert_ne!(options, std::ptr::null_mut());
        let release_options = g_ort().ReleaseSessionOptions.unwrap();
        // SAFETY: `options` is the non-null handle obtained from `CreateSessionOptions`
        // when the builder was constructed, and it is released only here.
        unsafe { release_options(options) };
    }
}
impl<'a> SessionBuilder<'a> {
pub(crate) fn new(env: &'a Environment) -> Result<SessionBuilder<'a>> {
let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };
status_to_result(status).map_err(OrtError::SessionOptions)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_options_ptr, std::ptr::null_mut());
Ok(SessionBuilder {
env,
session_options_ptr,
allocator: AllocatorType::Arena,
memory_type: MemType::Default,
})
}
pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder<'a>> {
let num_threads = num_threads as i32;
let status =
unsafe { g_ort().SetIntraOpNumThreads.unwrap()(self.session_options_ptr, num_threads) };
status_to_result(status).map_err(OrtError::SessionOptions)?;
assert_eq!(status, std::ptr::null_mut());
Ok(self)
}
pub fn with_optimization_level(
self,
opt_level: GraphOptimizationLevel,
) -> Result<SessionBuilder<'a>> {
unsafe {
g_ort().SetSessionGraphOptimizationLevel.unwrap()(
self.session_options_ptr,
opt_level.into(),
)
};
Ok(self)
}
pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder<'a>> {
self.allocator = allocator;
Ok(self)
}
pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder<'a>> {
self.memory_type = memory_type;
Ok(self)
}
#[cfg(feature = "model-fetching")]
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
where
M: Into<AvailableOnnxModel>,
{
self.with_model_downloaded_monomorphized(model.into())
}
#[cfg(feature = "model-fetching")]
fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session<'a>> {
let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
let downloaded_path = model.download_to(download_dir)?;
self.with_model_from_file(downloaded_path)
}
pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> Result<Session<'a>>
where
P: AsRef<Path> + 'a,
{
let model_filepath = model_filepath_ref.as_ref();
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
if !model_filepath.exists() {
return Err(OrtError::FileDoesNotExists {
filename: model_filepath.to_path_buf(),
});
}
let model_path = std::ffi::OsString::from(model_filepath);
#[cfg(target_family = "windows")]
let model_path: Vec<u16> = model_path
.encode_wide()
.chain(std::iter::once(0)) .collect();
#[cfg(not(target_family = "windows"))]
let model_path: Vec<std::os::raw::c_char> = model_path
.as_bytes()
.iter()
.chain(std::iter::once(&b'\0')) .map(|b| *b as std::os::raw::c_char)
.collect();
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
let status = unsafe {
g_ort().CreateSession.unwrap()(
env_ptr,
model_path.as_ptr(),
self.session_options_ptr,
&mut session_ptr,
)
};
status_to_result(status).map_err(OrtError::Session)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_ptr, std::ptr::null_mut());
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
status_to_result(status).map_err(OrtError::Allocator)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(allocator_ptr, std::ptr::null_mut());
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Output>>>()?;
Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs,
})
}
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session<'a>>
where
B: AsRef<[u8]>,
{
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
}
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session<'a>> {
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
let status = unsafe {
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
let model_data_length = model_bytes.len();
g_ort().CreateSessionFromArray.unwrap()(
env_ptr,
model_data,
model_data_length,
self.session_options_ptr,
&mut session_ptr,
)
};
status_to_result(status).map_err(OrtError::Session)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_ptr, std::ptr::null_mut());
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
status_to_result(status).map_err(OrtError::Allocator)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(allocator_ptr, std::ptr::null_mut());
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Output>>>()?;
Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs,
})
}
}
/// A loaded ONNX Runtime inference session, created through [`SessionBuilder`].
#[derive(Debug)]
pub struct Session<'a> {
    /// Environment the session was created in.
    env: &'a Environment,
    /// Owned `OrtSession` handle, released in `Drop`.
    session_ptr: *mut sys::OrtSession,
    /// Allocator from `GetAllocatorWithDefaultOptions`; passed to ONNX Runtime when
    /// reading input/output names and when wrapping input arrays as tensors.
    allocator_ptr: *mut sys::OrtAllocator,
    /// Memory description used when wrapping input arrays and extracting outputs.
    memory_info: MemoryInfo,
    /// The model's inputs, ordered by their index in the model.
    pub inputs: Vec<Input>,
    /// The model's outputs, ordered by their index in the model.
    pub outputs: Vec<Output>,
}
/// Description of one model input, as reported by ONNX Runtime.
#[derive(Debug)]
pub struct Input {
    /// Input name declared in the model.
    pub name: String,
    /// Element type of the input tensor.
    pub input_type: TensorElementDataType,
    /// Declared shape; `None` marks a dimension ONNX Runtime reported as `-1` (unknown).
    pub dimensions: Vec<Option<u32>>,
}
/// Description of one model output, as reported by ONNX Runtime.
#[derive(Debug)]
pub struct Output {
    /// Output name declared in the model.
    pub name: String,
    /// Element type of the output tensor.
    pub output_type: TensorElementDataType,
    /// Declared shape; `None` marks a dimension ONNX Runtime reported as `-1` (unknown).
    pub dimensions: Vec<Option<u32>>,
}
impl Input {
    /// Iterates over the declared shape as `usize`s, yielding `None` for
    /// dimensions the model leaves unspecified.
    pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
        self.dimensions
            .iter()
            .copied()
            .map(|dim| dim.map(|extent| extent as usize))
    }
}
impl Output {
    /// Iterates over the declared shape as `usize`s, yielding `None` for
    /// dimensions the model leaves unspecified.
    pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
        self.dimensions
            .iter()
            .copied()
            .map(|dim| dim.map(|extent| extent as usize))
    }
}
impl<'a> Drop for Session<'a> {
    /// Releases the `OrtSession` and clears the raw handles held by this value.
    #[tracing::instrument]
    fn drop(&mut self) {
        debug!("Dropping the session.");
        let release_session = g_ort().ReleaseSession.unwrap();
        let session_ptr = std::mem::replace(&mut self.session_ptr, std::ptr::null_mut());
        // NOTE(review): the allocator comes from `GetAllocatorWithDefaultOptions` and is
        // only cleared here, never freed — presumably ONNX Runtime owns it; confirm.
        self.allocator_ptr = std::ptr::null_mut();
        // SAFETY: `session_ptr` is the handle this `Session` was built around and is
        // released exactly once, here.
        unsafe { release_session(session_ptr) };
    }
}
impl<'a> Session<'a> {
    /// Runs the model on `input_arrays` and returns every model output.
    ///
    /// Inputs are matched to [`Session::inputs`] by position and validated against the
    /// model's declared ranks and known dimensions before anything is sent to ONNX
    /// Runtime. Each output is extracted as `TOut` with a dynamic shape.
    ///
    /// # Errors
    /// - [`OrtError::NonMatchingDimensions`] if the inputs do not match the model;
    /// - [`OrtError::Run`] if ONNX Runtime fails to run the graph;
    /// - errors while creating input tensors or extracting output tensors.
    ///
    /// # Panics
    /// Panics if an input or output name contains an interior NUL byte.
    pub fn run<'s, 't, 'm, TIn, TOut, D>(
        &'s mut self,
        input_arrays: Vec<Array<TIn, D>>,
    ) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
    where
        TIn: TypeToTensorElementDataType + Debug + Clone,
        TOut: TypeToTensorElementDataType + Debug + Clone,
        D: ndarray::Dimension,
        'm: 't,
        's: 'm,
    {
        self.validate_input_shapes(&input_arrays)?;

        // The CStrings are owned locals, so they outlive `Run` and are freed on every
        // return path (including early `?` returns); only borrowed pointers go to ORT.
        let input_names_cstring: Vec<CString> = self
            .inputs
            .iter()
            .map(|input| CString::new(input.name.as_str()).unwrap())
            .collect();
        let input_names_ptr: Vec<*const std::os::raw::c_char> =
            input_names_cstring.iter().map(|n| n.as_ptr()).collect();
        let output_names_cstring: Vec<CString> = self
            .outputs
            .iter()
            .map(|output| CString::new(output.name.as_str()).unwrap())
            .collect();
        let output_names_ptr: Vec<*const std::os::raw::c_char> =
            output_names_cstring.iter().map(|n| n.as_ptr()).collect();

        let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
            vec![std::ptr::null_mut(); self.outputs.len()];

        // The tensors must stay alive until `Run` returns; the raw values borrow them.
        let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
            .into_iter()
            .map(|input_array| {
                OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
            })
            .collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
        let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
            .iter()
            .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
            .collect();

        let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
        let status = unsafe {
            g_ort().Run.unwrap()(
                self.session_ptr,
                run_options_ptr,
                input_names_ptr.as_ptr(),
                input_ort_values.as_ptr(),
                input_ort_values.len(),
                output_names_ptr.as_ptr(),
                output_names_ptr.len(),
                output_tensor_extractors_ptrs.as_mut_ptr(),
            )
        };
        status_to_result(status).map_err(OrtError::Run)?;

        let memory_info_ref = &self.memory_info;
        output_tensor_extractors_ptrs
            .into_iter()
            .map(|ptr| {
                let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
                    std::ptr::null_mut();
                let status = unsafe {
                    g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
                };
                // NOTE(review): if this fails, `ptr` (an ORT-allocated output) is not released.
                status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
                let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
                unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
                let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();
                let mut output_tensor_extractor =
                    OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
                output_tensor_extractor.tensor_ptr = ptr;
                output_tensor_extractor.extract::<TOut>()
            })
            .collect()
    }

    /// Checks `input_arrays` against the model's declared inputs.
    ///
    /// Verifies, in order: the number of inputs, the rank of each input, and every
    /// dimension the model fixes (dimensions declared as unknown match any size).
    ///
    /// # Errors
    /// Returns [`OrtError::NonMatchingDimensions`] describing the first failed check.
    fn validate_input_shapes<TIn, D>(&self, input_arrays: &[Array<TIn, D>]) -> Result<()>
    where
        TIn: TypeToTensorElementDataType + Debug + Clone,
        D: ndarray::Dimension,
    {
        // Shape listings attached to the errors; only built on the error paths.
        let inference_input = || {
            input_arrays
                .iter()
                .map(|input_array| input_array.shape().to_vec())
                .collect()
        };
        let model_input = || {
            self.inputs
                .iter()
                .map(|input| input.dimensions.clone())
                .collect()
        };

        if input_arrays.len() != self.inputs.len() {
            error!(
                "Non-matching number of inputs: {} (inference) vs {} (model)",
                input_arrays.len(),
                self.inputs.len()
            );
            return Err(OrtError::NonMatchingDimensions(
                NonMatchingDimensionsError::InputsCount {
                    inference_input_count: input_arrays.len(),
                    model_input_count: self.inputs.len(),
                    inference_input: inference_input(),
                    model_input: model_input(),
                },
            ));
        }

        let inputs_different_length = input_arrays
            .iter()
            .zip(self.inputs.iter())
            .any(|(l, r)| l.shape().len() != r.dimensions.len());
        if inputs_different_length {
            error!(
                "Different input lengths: {:?} vs {:?}",
                self.inputs, input_arrays
            );
            return Err(OrtError::NonMatchingDimensions(
                NonMatchingDimensionsError::InputsLength {
                    inference_input: inference_input(),
                    model_input: model_input(),
                },
            ));
        }

        let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
            l.shape()
                .iter()
                .zip(r.dimensions.iter())
                .any(|(l2, r2)| match r2 {
                    Some(r3) => *r3 as usize != *l2,
                    // Unknown model dimension: any size is accepted.
                    None => false,
                })
        });
        if inputs_different_shape {
            error!(
                "Different input shapes: {:?} vs {:?}",
                self.inputs, input_arrays
            );
            return Err(OrtError::NonMatchingDimensions(
                NonMatchingDimensionsError::InputsLength {
                    inference_input: inference_input(),
                    model_input: model_input(),
                },
            ));
        }

        Ok(())
    }
}
/// Returns the dimensions of the tensor described by `tensor_info_ptr`.
///
/// Scalar (rank-0) tensors yield an empty vector. Callers in this module treat a
/// value of `-1` as an unknown dimension.
///
/// # Errors
/// Returns [`OrtError::GetDimensionsCount`] or [`OrtError::GetDimensions`] if the
/// corresponding ONNX Runtime call fails.
///
/// # Safety
/// `tensor_info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo` for the
/// duration of the call.
unsafe fn get_tensor_dimensions(
    tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
) -> Result<Vec<i64>> {
    let mut num_dims = 0;
    let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
    status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
    // A rank of zero is legitimate (scalar tensors), so it must not be rejected.
    let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
    let status = g_ort().GetDimensions.unwrap()(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims);
    status_to_result(status).map_err(OrtError::GetDimensions)?;
    Ok(node_dims)
}
/// Low-level helpers that read a session's input/output metadata through raw ONNX
/// Runtime pointers. Every function expects `session_ptr` (and, where taken,
/// `allocator_ptr`) to be valid, live handles.
mod dangerous {
    use super::*;

    /// Returns the number of inputs declared by the model behind `session_ptr`.
    pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<usize> {
        let f = g_ort().SessionGetInputCount.unwrap();
        extract_io_count(f, session_ptr)
    }

    /// Returns the number of outputs declared by the model behind `session_ptr`.
    pub(super) fn extract_outputs_count(session_ptr: *mut sys::OrtSession) -> Result<usize> {
        let f = g_ort().SessionGetOutputCount.unwrap();
        extract_io_count(f, session_ptr)
    }

    /// Calls a `SessionGet{Input,Output}Count`-shaped API function.
    ///
    /// A count of zero is returned as-is rather than panicking, so models without
    /// inputs (for example ones producing only constant outputs) can still be loaded.
    fn extract_io_count(
        f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus },
        session_ptr: *mut sys::OrtSession,
    ) -> Result<usize> {
        let mut num_nodes: usize = 0;
        let status = unsafe { f(session_ptr, &mut num_nodes) };
        status_to_result(status).map_err(OrtError::InOutCount)?;
        assert_eq!(status, std::ptr::null_mut());
        Ok(num_nodes)
    }

    /// Returns the name of input `i`.
    fn extract_input_name(
        session_ptr: *mut sys::OrtSession,
        allocator_ptr: *mut sys::OrtAllocator,
        i: usize,
    ) -> Result<String> {
        let f = g_ort().SessionGetInputName.unwrap();
        extract_io_name(f, session_ptr, allocator_ptr, i)
    }

    /// Returns the name of output `i`.
    fn extract_output_name(
        session_ptr: *mut sys::OrtSession,
        allocator_ptr: *mut sys::OrtAllocator,
        i: usize,
    ) -> Result<String> {
        let f = g_ort().SessionGetOutputName.unwrap();
        extract_io_name(f, session_ptr, allocator_ptr, i)
    }

    /// Calls a `SessionGet{Input,Output}Name`-shaped API function and copies the
    /// name into an owned `String`.
    ///
    /// The C string is allocated by ONNX Runtime through `allocator_ptr`. It is freed
    /// with the same allocator after copying, whether or not the copy succeeded.
    fn extract_io_name(
        f: extern_system_fn! { unsafe fn(
            *const sys::OrtSession,
            usize,
            *mut sys::OrtAllocator,
            *mut *mut i8,
        ) -> *mut sys::OrtStatus },
        session_ptr: *mut sys::OrtSession,
        allocator_ptr: *mut sys::OrtAllocator,
        i: usize,
    ) -> Result<String> {
        let mut name_bytes: *mut i8 = std::ptr::null_mut();
        let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
        status_to_result(status).map_err(OrtError::InputName)?;
        assert_ne!(name_bytes, std::ptr::null_mut());
        let name = char_p_to_string(name_bytes);
        let free_status = unsafe {
            g_ort().AllocatorFree.unwrap()(allocator_ptr, name_bytes as *mut std::ffi::c_void)
        };
        let name = name?;
        status_to_result(free_status).map_err(OrtError::Allocator)?;
        Ok(name)
    }

    /// Builds the [`Input`] description for input `i`.
    pub(super) fn extract_input(
        session_ptr: *mut sys::OrtSession,
        allocator_ptr: *mut sys::OrtAllocator,
        i: usize,
    ) -> Result<Input> {
        let input_name = extract_input_name(session_ptr, allocator_ptr, i)?;
        let f = g_ort().SessionGetInputTypeInfo.unwrap();
        let (input_type, dimensions) = extract_io(f, session_ptr, i)?;
        Ok(Input {
            name: input_name,
            input_type,
            dimensions,
        })
    }

    /// Builds the [`Output`] description for output `i`.
    pub(super) fn extract_output(
        session_ptr: *mut sys::OrtSession,
        allocator_ptr: *mut sys::OrtAllocator,
        i: usize,
    ) -> Result<Output> {
        let output_name = extract_output_name(session_ptr, allocator_ptr, i)?;
        let f = g_ort().SessionGetOutputTypeInfo.unwrap();
        let (output_type, dimensions) = extract_io(f, session_ptr, i)?;
        Ok(Output {
            name: output_name,
            output_type,
            dimensions,
        })
    }

    /// Calls a `SessionGet{Input,Output}TypeInfo`-shaped API function and returns the
    /// element type and declared dimensions (`None` for `-1`) of the `i`-th tensor.
    fn extract_io(
        f: extern_system_fn! { unsafe fn(
            *const sys::OrtSession,
            usize,
            *mut *mut sys::OrtTypeInfo,
        ) -> *mut sys::OrtStatus },
        session_ptr: *mut sys::OrtSession,
        i: usize,
    ) -> Result<(TensorElementDataType, Vec<Option<u32>>)> {
        let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
        let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
        status_to_result(status).map_err(OrtError::GetTypeInfo)?;
        assert_ne!(typeinfo_ptr, std::ptr::null_mut());
        // Read everything first, then release the type info on success *and* error paths.
        let result = extract_tensor_type_and_dims(typeinfo_ptr);
        unsafe { g_ort().ReleaseTypeInfo.unwrap()(typeinfo_ptr) };
        result
    }

    /// Reads the element type and dimensions out of an `OrtTypeInfo` describing a tensor.
    ///
    /// The tensor info obtained from `CastTypeInfoToTensorInfo` is owned by
    /// `typeinfo_ptr`, so it is not released here.
    fn extract_tensor_type_and_dims(
        typeinfo_ptr: *mut sys::OrtTypeInfo,
    ) -> Result<(TensorElementDataType, Vec<Option<u32>>)> {
        let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
        let status = unsafe {
            g_ort().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr)
        };
        status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
        assert_ne!(tensor_info_ptr, std::ptr::null_mut());

        let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
        let status =
            unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
        status_to_result(status).map_err(OrtError::TensorElementType)?;
        assert_ne!(
            type_sys,
            sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
        );
        // NOTE(review): assumes `TensorElementDataType` mirrors every variant and
        // discriminant of `sys::ONNXTensorElementDataType`; otherwise this is UB — confirm.
        let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };

        let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
        Ok((
            io_type,
            node_dims
                .into_iter()
                .map(|d| if d == -1 { None } else { Some(d as u32) })
                .collect(),
        ))
    }
}