// onnxruntime/tensor/ort_owned_tensor.rs
use std::{fmt::Debug, ops::Deref};
use ndarray::{Array, ArrayView};
use tracing::debug;
use onnxruntime_sys as sys;
use crate::{
error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor,
OrtError, Result, TypeToTensorElementDataType,
};
/// A tensor backed by an ONNX Runtime `OrtValue` and exposed as an `ndarray` view.
///
/// The wrapper dereferences to its inner [`ArrayView`], and the `Drop`
/// implementation in this file hands `tensor_ptr` to the runtime's
/// `ReleaseValue`.
///
/// Lifetimes: `'t` is the lifetime of the array view and `'m` the lifetime of
/// the borrowed [`MemoryInfo`]; the `'m: 't` bound requires the memory info to
/// outlive the view.
#[derive(Debug)]
pub struct OrtOwnedTensor<'t, 'm, T, D>
where
T: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
'm: 't, {
/// Raw handle to the runtime value; passed to `ReleaseValue` when the tensor is dropped.
pub(crate) tensor_ptr: *mut sys::OrtValue,
/// View built in `OrtOwnedTensorExtractor::extract` from the data pointer
/// returned by `GetTensorMutableData`.
array_view: ArrayView<'t, T, D>,
// Not read anywhere in this file; presumably held only to tie the tensor's
// lifetime to the memory info — TODO confirm.
memory_info: &'m MemoryInfo,
}
/// Lets an [`OrtOwnedTensor`] be used wherever a shared reference to its
/// inner `ArrayView` is expected.
impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D>
where
    T: TypeToTensorElementDataType + Debug + Clone,
    D: ndarray::Dimension,
{
    type Target = ArrayView<'t, T, D>;

    /// Borrows the array view stored in this tensor.
    fn deref(&self) -> &ArrayView<'t, T, D> {
        &self.array_view
    }
}
impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D>
where
    T: TypeToTensorElementDataType + Debug + Clone,
    D: ndarray::Dimension,
{
    /// Runs `softmax` along `axis` on the tensor's array view and returns the
    /// result as a new owned [`Array`].
    ///
    /// This is a thin forwarder: the computation is whatever the `softmax`
    /// method on the inner view does (presumably supplied by the imported
    /// `NdArrayTensor` trait — confirm against that trait's definition).
    pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
    where
        D: ndarray::RemoveAxis,
        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
    {
        let view = &self.array_view;
        view.softmax(axis)
    }
}
/// Intermediate holder used to build an `OrtOwnedTensor`.
///
/// `new` creates it with a null `tensor_ptr`; the pointer has to be set to a
/// non-null value before `extract` is called, because `extract` asserts that
/// it is not null.
#[derive(Debug)]
pub(crate) struct OrtOwnedTensorExtractor<'m, D>
where
D: ndarray::Dimension,
{
/// Raw handle to the runtime value to extract; null until assigned by the caller.
pub(crate) tensor_ptr: *mut sys::OrtValue,
/// Memory info reference that is passed through to the resulting tensor.
memory_info: &'m MemoryInfo,
/// Shape given to the array view that `extract` creates.
shape: D,
}
impl<'m, D> OrtOwnedTensorExtractor<'m, D>
where
    D: ndarray::Dimension,
{
    /// Creates an extractor for a tensor of the given `shape`.
    ///
    /// `tensor_ptr` starts out null; it must be assigned before calling
    /// [`extract`](Self::extract), which asserts that it is non-null.
    pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> {
        Self {
            shape,
            memory_info,
            tensor_ptr: std::ptr::null_mut(),
        }
    }

    /// Consumes the extractor and wraps its runtime value in an `OrtOwnedTensor`.
    ///
    /// The returned tensor's array view is created directly from the data
    /// pointer reported by `GetTensorMutableData` together with the stored
    /// shape.
    ///
    /// # Errors
    ///
    /// Returns `OrtError::IsTensor` when either the `IsTensor` call or the
    /// `GetTensorMutableData` call reports a failing status.
    ///
    /// # Panics
    ///
    /// Panics if `tensor_ptr` is null, if the runtime reports that the value
    /// is not a tensor, or if the returned data pointer is null.
    ///
    /// Note that this function never calls `ReleaseValue`: when it returns an
    /// error or panics, the value behind `tensor_ptr` is not released here.
    pub(crate) fn extract<'t, T>(self) -> Result<OrtOwnedTensor<'t, 'm, T, D>>
    where
        T: TypeToTensorElementDataType + Debug + Clone,
    {
        let Self {
            tensor_ptr,
            memory_info,
            shape,
        } = self;

        assert_ne!(tensor_ptr, std::ptr::null_mut());

        // Ask the runtime whether the value is a tensor at all.
        let mut is_tensor = 0;
        // SAFETY: `tensor_ptr` was just checked to be non-null; that it points
        // to a live `OrtValue` is assumed from the caller — TODO confirm.
        let is_tensor_status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
        status_to_result(is_tensor_status).map_err(OrtError::IsTensor)?;
        assert_eq!(is_tensor, 1);

        // Fetch the address of the tensor's element buffer. The API writes it
        // through a `void**`, hence the pointer-to-pointer cast.
        let mut data_ptr: *mut T = std::ptr::null_mut();
        // SAFETY: same assumption on `tensor_ptr` as above; the out-pointer
        // refers to the local `data_ptr`, which outlives the call.
        let data_status = unsafe {
            g_ort().GetTensorMutableData.unwrap()(
                tensor_ptr,
                &mut data_ptr as *mut *mut T as *mut *mut std::ffi::c_void,
            )
        };
        // NOTE(review): this failure is also reported as `OrtError::IsTensor`
        // although the failing call is `GetTensorMutableData` — confirm
        // whether a dedicated error variant should be used instead.
        status_to_result(data_status).map_err(OrtError::IsTensor)?;
        assert_ne!(data_ptr, std::ptr::null_mut());

        // SAFETY: `data_ptr` is non-null (asserted above). Nothing here
        // verifies that the buffer holds elements of type `T` laid out as
        // `shape` describes; that is assumed from the caller — TODO confirm.
        let array_view = unsafe { ArrayView::from_shape_ptr(shape, data_ptr) };

        Ok(OrtOwnedTensor {
            tensor_ptr,
            array_view,
            memory_info,
        })
    }
}
/// Releases the runtime value owned by the tensor when it goes out of scope.
impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D>
where
    T: TypeToTensorElementDataType + Debug + Clone,
    D: ndarray::Dimension,
    'm: 't,
{
    /// Emits a debug event, passes `tensor_ptr` to the runtime's
    /// `ReleaseValue`, and then resets the pointer to null.
    #[tracing::instrument]
    fn drop(&mut self) {
        debug!("Dropping OrtOwnedTensor.");

        // SAFETY: assumes `tensor_ptr` is the value handed over by the
        // extractor and has not been released elsewhere — TODO confirm.
        unsafe {
            let release_value = g_ort().ReleaseValue.unwrap();
            release_value(self.tensor_ptr);
        }

        // Clear the handle so the released pointer is not left dangling in
        // the struct.
        self.tensor_ptr = std::ptr::null_mut();
    }
}