onnxruntime/error.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
//! Module containing error definitions.
use std::{io, path::PathBuf};
use thiserror::Error;
use onnxruntime_sys as sys;
use crate::{char_p_to_string, g_ort};
/// Crate-wide `Result` alias whose error type is [`OrtError`].
pub type Result<T> = core::result::Result<T, OrtError>;
/// Error type centralizing all possible errors
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtError {
/// The C API can message to the caller using a C `char *` which needs to be converted
/// to Rust's `String`. This operation can fail.
#[error("Failed to construct String")]
StringConversion(OrtApiError),
// FIXME: Move these to another enum (they are C API calls errors)
/// An error occurred when creating an ONNX environment
#[error("Failed to create environment: {0}")]
Environment(OrtApiError),
/// Error occurred when creating an ONNX session options
#[error("Failed to create session options: {0}")]
SessionOptions(OrtApiError),
/// Error occurred when creating an ONNX session
#[error("Failed to create session: {0}")]
Session(OrtApiError),
/// Error occurred when creating an ONNX allocator
#[error("Failed to get allocator: {0}")]
Allocator(OrtApiError),
/// Error occurred when counting ONNX input or output count
#[error("Failed to get input or output count: {0}")]
InOutCount(OrtApiError),
/// Error occurred when getting ONNX input name
#[error("Failed to get input name: {0}")]
InputName(OrtApiError),
/// Error occurred when getting ONNX type information
#[error("Failed to get type info: {0}")]
GetTypeInfo(OrtApiError),
/// Error occurred when casting ONNX type information to tensor information
#[error("Failed to cast type info to tensor info: {0}")]
CastTypeInfoToTensorInfo(OrtApiError),
/// Error occurred when getting tensor elements type
#[error("Failed to get tensor element type: {0}")]
TensorElementType(OrtApiError),
/// Error occurred when getting ONNX dimensions count
#[error("Failed to get dimensions count: {0}")]
GetDimensionsCount(OrtApiError),
/// Error occurred when getting ONNX dimensions
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
/// Error occurred when creating CPU memory information
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
/// Error occurred when creating ONNX tensor
#[error("Failed to create tensor: {0}")]
CreateTensor(OrtApiError),
/// Error occurred when creating ONNX tensor with specific data
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
/// Error occurred when filling a tensor with string data
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(OrtApiError),
/// Error occurred when checking if ONNX tensor was properly initialized
#[error("Failed to check if tensor: {0}")]
IsTensor(OrtApiError),
/// Error occurred when getting tensor type and shape
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(OrtApiError),
/// Error occurred when ONNX inference operation was called
#[error("Failed to run: {0}")]
Run(OrtApiError),
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] OrtDownloadError),
/// Dimensions of input data and ONNX model loaded from file do not match
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
/// File does not exists
#[error("File {filename:?} does not exists")]
FileDoesNotExists {
/// Path which does not exists
filename: PathBuf,
},
/// Path is an invalid UTF-8
#[error("Path {path:?} cannot be converted to UTF-8")]
NonUtf8Path {
/// Path with invalid UTF-8
path: PathBuf,
},
/// Attempt to build a Rust `CString` from a null pointer
#[error("Failed to build CString when original contains null: {0}")]
CStringNulError(#[from] std::ffi::NulError),
}
/// Error used when dimensions of input (from model and from inference call)
/// do not match (as they should).
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match number of inputs from inference call
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>,
},
/// Inputs length from model does not match the expected input from inference call
#[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
InputsLength {
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>,
},
}
/// Error details when ONNX C API fail
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtApiError {
/// Details as reported by the ONNX C API in case of error
#[error("Error calling ONNX Runtime C function: {0}")]
Msg(String),
/// Details as reported by the ONNX C API in case of error cannot be converted to UTF-8
#[error("Error calling ONNX Runtime C function and failed to convert error message to UTF-8")]
IntoStringError(std::ffi::IntoStringError),
}
/// Error from downloading pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum OrtDownloadError {
/// Generic input/output error
#[error("Error downloading data to file: {0}")]
IoError(#[from] io::Error),
#[cfg(feature = "model-fetching")]
/// Download error by ureq
#[error("Error downloading data to file: {0}")]
UreqError(#[from] Box<ureq::Error>),
/// Error getting content-length from an HTTP GET request
#[error("Error getting content-length")]
ContentLengthError,
/// Mismatch between amount of downloaded and expected bytes
#[error("Error copying data to file: expected {expected} length, received {io}")]
CopyError {
/// Expected amount of bytes to download
expected: u64,
/// Number of bytes read from network and written to file
io: u64,
},
}
/// Thin newtype around a raw `OrtStatus` pointer returned by the ONNX Runtime C API.
///
/// It exists so that the raw pointer can be turned into a Rust `Result` through
/// the standard `From`/`Into` conversions.
pub struct OrtStatusWrapper(*const sys::OrtStatus);

impl From<*const sys::OrtStatus> for OrtStatusWrapper {
    fn from(raw_status: *const sys::OrtStatus) -> Self {
        Self(raw_status)
    }
}
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
    /// Converts a raw `OrtStatus` into a Rust result.
    ///
    /// A null status means success and maps to `Ok(())`. A non-null status carries an error
    /// message, which is copied into an [`OrtApiError`]. The status object is then released
    /// with `ReleaseStatus`, because the C API hands ownership of returned statuses to the
    /// caller; without this, every failing call leaked its status.
    fn from(status: OrtStatusWrapper) -> Self {
        if status.0.is_null() {
            return Ok(());
        }

        let api = g_ort();

        // SAFETY: `status.0` is non-null and is assumed to be a live status returned by an
        // ONNX Runtime C API call that has not been released yet.
        let raw = unsafe { api.GetErrorMessage.unwrap()(status.0) };

        // `cast()` adapts the pointer to whatever `char_p_to_string` expects. This keeps the
        // code compiling on targets where `c_char` is `u8` rather than `i8` (e.g. aarch64 Linux).
        let message = char_p_to_string(raw.cast());

        // SAFETY: `message` holds only owned data (`String` / `IntoStringError`), so the
        // buffer behind `raw` is no longer referenced. The status is released exactly once
        // here and is not touched afterwards.
        unsafe { api.ReleaseStatus.unwrap()(status.0 as _) };

        match message {
            Ok(msg) => Err(OrtApiError::Msg(msg)),
            Err(OrtError::StringConversion(OrtApiError::IntoStringError(e))) => {
                Err(OrtApiError::IntoStringError(e))
            }
            Err(_) => unreachable!(
                "char_p_to_string is only expected to fail with StringConversion(IntoStringError)"
            ),
        }
    }
}
pub(crate) fn status_to_result(
status: *const sys::OrtStatus,
) -> std::result::Result<(), OrtApiError> {
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}
/// Calls `f` with the global `OrtApi` table and maps the `OrtStatus` pointer it returns
/// into a [`OrtApiError`] result.
///
/// A null status maps to `Ok(())`; any other status becomes `Err`.
///
/// # Safety
///
/// This function performs no unsafe operation itself. NOTE(review): it is presumably marked
/// `unsafe` because `f` is expected to invoke raw C function pointers from `sys::OrtApi`.
/// Callers must uphold the contract of whichever C API function `f` calls. The pointer
/// returned by `f` must be either null or a valid, unreleased `OrtStatus`, because it is
/// handed to the C API to read the error message.
pub(crate) unsafe fn call_ort<F>(f: F) -> std::result::Result<(), OrtApiError>
where
    // `FnOnce` is the weakest bound that suffices: `f` is invoked exactly once. Every
    // `FnMut` closure accepted before still satisfies it.
    F: FnOnce(sys::OrtApi) -> *const sys::OrtStatus,
{
    status_to_result(f(g_ort()))
}