onnxruntime/tensor/ndarray_tensor.rs
//! Module containing a tensor trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)

use ndarray::{Array, ArrayBase};

/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
/// with useful tensor operations.
///
/// # Generic
///
/// The trait is generic over:
/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
/// * `T`: Type contained inside the tensor (for example `f32`)
/// * `D`: Tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
pub trait NdArrayTensor<S, T, D> {
    /// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis.
    ///
    /// # Trait Bounds
    ///
    /// The function is generic and thus has some trait bounds:
    ///
    /// * `D: ndarray::RemoveAxis`: The summation over an axis reduces the dimension of the tensor; a 0-D tensor
    ///   thus cannot have a softmax calculated.
    /// * `S: ndarray::Data<Elem = T>`: The storage of the tensor must be a readable container of `T` elements,
    ///   i.e. an owned array ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an
    ///   array view ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)).
    /// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must be
    ///   floating-point values that support the `-=` and `/=` operations.
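    ///
    /// # Example
    ///
    /// A minimal usage sketch. It assumes this module is reachable at
    /// `onnxruntime::tensor::ndarray_tensor` (adjust the `use` path to wherever
    /// the trait actually lives in your crate):
    ///
    /// ```
    /// use ndarray::{arr1, Axis};
    /// use onnxruntime::tensor::ndarray_tensor::NdArrayTensor;
    ///
    /// let a = arr1(&[1.0_f32, 2.0, 3.0]);
    /// let s = a.softmax(Axis(0));
    /// // Softmax output is a probability distribution along the axis:
    /// // positive values summing to (approximately) 1.
    /// assert!((s.sum() - 1.0).abs() < 1.0e-6);
    /// ```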
    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
    where
        D: ndarray::RemoveAxis,
        S: ndarray::Data<Elem = T>,
        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
}

impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::Data<Elem = T>,
    T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
{
    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
        let mut new_array: Array<T, D> = self.to_owned();
        // FIXME: Switch to the non-overflowing "max trick" formulation, in NumPy terms:
        //   e = np.exp(a - np.max(a, axis=axis, keepdims=True))
        //   softmax = e / np.sum(e, axis=axis, keepdims=True)
        // (see the `stable_softmax` sketch below). The naive formula used here,
        // np.exp(a) / np.sum(np.exp(a)), overflows for large inputs.
        new_array.map_inplace(|v| *v = v.exp());
        let sum = new_array.sum_axis(axis).insert_axis(axis);
        new_array /= &sum;
        new_array
    }
}
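
// A minimal sketch of the non-overflowing formulation the FIXME above asks for:
// softmax is shift-invariant (softmax(x) == softmax(x - c) for any constant c),
// so subtracting each lane's maximum before exponentiating keeps the argument
// of `exp` non-positive and avoids overflow. `stable_softmax` is a hypothetical
// free function added here for illustration; it is not part of this module's API.
#[allow(dead_code)]
fn stable_softmax<S, T, D>(array: &ArrayBase<S, D>, axis: ndarray::Axis) -> Array<T, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::Data<Elem = T>,
    T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
{
    let mut new_array: Array<T, D> = array.to_owned();
    // Per-lane maximum, with the reduced axis re-inserted so it broadcasts back
    // against the original shape.
    let max = new_array
        .fold_axis(axis, T::neg_infinity(), |&acc, &v| acc.max(v))
        .insert_axis(axis);
    new_array -= &max;
    new_array.map_inplace(|v| *v = v.exp());
    let sum = new_array.sum_axis(axis).insert_axis(axis);
    new_array /= &sum;
    new_array
}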

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, arr3};
    use test_env_log::test;

    #[test]
    fn softmax_1d() {
        let array = arr1(&[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]);

        let expected_softmax = arr1(&[
            0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
        ]);

        let softmax = array.softmax(ndarray::Axis(0));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;
        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }

    #[test]
    fn softmax_2d() {
        let array = arr2(&[
            [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
        ]);

        let expected_softmax = arr2(&[
            [
                0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
            ],
            [
                0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
            ],
        ]);

        let softmax = array.softmax(ndarray::Axis(1));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;
        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }

    #[test]
    fn softmax_3d() {
        let array = arr3(&[
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
        ]);

        let expected_softmax = arr3(&[
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
        ]);

        let softmax = array.softmax(ndarray::Axis(2));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;
        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }
}