ndarray/numeric/
impl_numeric.rs

// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::{self, FromPrimitive, Zero};
12use std::ops::{Add, Div, Mul};
13
14use crate::imp_prelude::*;
15use crate::itertools::enumerate;
16use crate::numeric_util;
17
18/// # Numerical Methods for Arrays
19impl<A, S, D> ArrayBase<S, D>
20where
21    S: Data<Elem = A>,
22    D: Dimension,
23{
24    /// Return the sum of all elements in the array.
25    ///
26    /// ```
27    /// use ndarray::arr2;
28    ///
29    /// let a = arr2(&[[1., 2.],
30    ///                [3., 4.]]);
31    /// assert_eq!(a.sum(), 10.);
32    /// ```
33    pub fn sum(&self) -> A
34    where
35        A: Clone + Add<Output = A> + num_traits::Zero,
36    {
37        if let Some(slc) = self.as_slice_memory_order() {
38            return numeric_util::unrolled_fold(slc, A::zero, A::add);
39        }
40        let mut sum = A::zero();
41        for row in self.inner_rows() {
42            if let Some(slc) = row.as_slice() {
43                sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
44            } else {
45                sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
46            }
47        }
48        sum
49    }
50
51    /// Return the sum of all elements in the array.
52    ///
53    /// *This method has been renamed to `.sum()`*
54    #[deprecated(note="renamed to `sum`", since="0.15.0")]
55    pub fn scalar_sum(&self) -> A
56    where
57        A: Clone + Add<Output = A> + num_traits::Zero,
58    {
59        self.sum()
60    }
61
62    /// Returns the [arithmetic mean] x̅ of all elements in the array:
63    ///
64    /// ```text
65    ///     1   n
66    /// x̅ = ―   ∑ xᵢ
67    ///     n  i=1
68    /// ```
69    ///
70    /// If the array is empty, `None` is returned.
71    ///
72    /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
73    ///
74    /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
75    pub fn mean(&self) -> Option<A>
76    where
77        A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero,
78    {
79        let n_elements = self.len();
80        if n_elements == 0 {
81            None
82        } else {
83            let n_elements = A::from_usize(n_elements)
84                .expect("Converting number of elements to `A` must not fail.");
85            Some(self.sum() / n_elements)
86        }
87    }
88
89    /// Return the product of all elements in the array.
90    ///
91    /// ```
92    /// use ndarray::arr2;
93    ///
94    /// let a = arr2(&[[1., 2.],
95    ///                [3., 4.]]);
96    /// assert_eq!(a.product(), 24.);
97    /// ```
98    pub fn product(&self) -> A
99    where
100        A: Clone + Mul<Output = A> + num_traits::One,
101    {
102        if let Some(slc) = self.as_slice_memory_order() {
103            return numeric_util::unrolled_fold(slc, A::one, A::mul);
104        }
105        let mut sum = A::one();
106        for row in self.inner_rows() {
107            if let Some(slc) = row.as_slice() {
108                sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
109            } else {
110                sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
111            }
112        }
113        sum
114    }
115
116    /// Return variance of elements in the array.
117    ///
118    /// The variance is computed using the [Welford one-pass
119    /// algorithm](https://www.jstor.org/stable/1266577).
120    ///
121    /// The parameter `ddof` specifies the "delta degrees of freedom". For
122    /// example, to calculate the population variance, use `ddof = 0`, or to
123    /// calculate the sample variance, use `ddof = 1`.
124    ///
125    /// The variance is defined as:
126    ///
127    /// ```text
128    ///               1       n
129    /// variance = ――――――――   ∑ (xᵢ - x̅)²
130    ///            n - ddof  i=1
131    /// ```
132    ///
133    /// where
134    ///
135    /// ```text
136    ///     1   n
137    /// x̅ = ―   ∑ xᵢ
138    ///     n  i=1
139    /// ```
140    ///
141    /// and `n` is the length of the array.
142    ///
143    /// **Panics** if `ddof` is less than zero or greater than `n`
144    ///
145    /// # Example
146    ///
147    /// ```
148    /// use ndarray::array;
149    /// use approx::assert_abs_diff_eq;
150    ///
151    /// let a = array![1., -4.32, 1.14, 0.32];
152    /// let var = a.var(1.);
153    /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
154    /// ```
155    #[cfg(feature = "std")]
156    pub fn var(&self, ddof: A) -> A
157    where
158        A: Float + FromPrimitive,
159    {
160        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
161        let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
162        assert!(
163            !(ddof < zero || ddof > n),
164            "`ddof` must not be less than zero or greater than the length of \
165             the axis",
166        );
167        let dof = n - ddof;
168        let mut mean = A::zero();
169        let mut sum_sq = A::zero();
170        let mut i = 0;
171        self.for_each(|&x| {
172            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
173            let delta = x - mean;
174            mean = mean + delta / count;
175            sum_sq = (x - mean).mul_add(delta, sum_sq);
176            i += 1;
177        });
178        sum_sq / dof
179    }
180
181    /// Return standard deviation of elements in the array.
182    ///
183    /// The standard deviation is computed from the variance using
184    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
185    ///
186    /// The parameter `ddof` specifies the "delta degrees of freedom". For
187    /// example, to calculate the population standard deviation, use `ddof = 0`,
188    /// or to calculate the sample standard deviation, use `ddof = 1`.
189    ///
190    /// The standard deviation is defined as:
191    ///
192    /// ```text
193    ///               ⎛    1       n          ⎞
194    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
195    ///               ⎝ n - ddof  i=1         ⎠
196    /// ```
197    ///
198    /// where
199    ///
200    /// ```text
201    ///     1   n
202    /// x̅ = ―   ∑ xᵢ
203    ///     n  i=1
204    /// ```
205    ///
206    /// and `n` is the length of the array.
207    ///
208    /// **Panics** if `ddof` is less than zero or greater than `n`
209    ///
210    /// # Example
211    ///
212    /// ```
213    /// use ndarray::array;
214    /// use approx::assert_abs_diff_eq;
215    ///
216    /// let a = array![1., -4.32, 1.14, 0.32];
217    /// let stddev = a.std(1.);
218    /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
219    /// ```
220    #[cfg(feature = "std")]
221    pub fn std(&self, ddof: A) -> A
222    where
223        A: Float + FromPrimitive,
224    {
225        self.var(ddof).sqrt()
226    }
227
228    /// Return sum along `axis`.
229    ///
230    /// ```
231    /// use ndarray::{aview0, aview1, arr2, Axis};
232    ///
233    /// let a = arr2(&[[1., 2., 3.],
234    ///                [4., 5., 6.]]);
235    /// assert!(
236    ///     a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
237    ///     a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
238    ///
239    ///     a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
240    /// );
241    /// ```
242    ///
243    /// **Panics** if `axis` is out of bounds.
244    pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
245    where
246        A: Clone + Zero + Add<Output = A>,
247        D: RemoveAxis,
248    {
249        let n = self.len_of(axis);
250        let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
251        let stride = self.strides()[axis.index()];
252        if self.ndim() == 2 && stride == 1 {
253            // contiguous along the axis we are summing
254            let ax = axis.index();
255            for (i, elt) in enumerate(&mut res) {
256                *elt = self.index_axis(Axis(1 - ax), i).sum();
257            }
258        } else {
259            for i in 0..n {
260                let view = self.index_axis(axis, i);
261                res = res + &view;
262            }
263        }
264        res
265    }
266
267    /// Return mean along `axis`.
268    ///
269    /// Return `None` if the length of the axis is zero.
270    ///
271    /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
272    /// fails for the axis length.
273    ///
274    /// ```
275    /// use ndarray::{aview0, aview1, arr2, Axis};
276    ///
277    /// let a = arr2(&[[1., 2., 3.],
278    ///                [4., 5., 6.]]);
279    /// assert!(
280    ///     a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
281    ///     a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
282    ///
283    ///     a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
284    /// );
285    /// ```
286    pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
287    where
288        A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
289        D: RemoveAxis,
290    {
291        let axis_length = self.len_of(axis);
292        if axis_length == 0 {
293            None
294        } else {
295            let axis_length =
296                A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
297            let sum = self.sum_axis(axis);
298            Some(sum / aview0(&axis_length))
299        }
300    }
301
302    /// Return variance along `axis`.
303    ///
304    /// The variance is computed using the [Welford one-pass
305    /// algorithm](https://www.jstor.org/stable/1266577).
306    ///
307    /// The parameter `ddof` specifies the "delta degrees of freedom". For
308    /// example, to calculate the population variance, use `ddof = 0`, or to
309    /// calculate the sample variance, use `ddof = 1`.
310    ///
311    /// The variance is defined as:
312    ///
313    /// ```text
314    ///               1       n
315    /// variance = ――――――――   ∑ (xᵢ - x̅)²
316    ///            n - ddof  i=1
317    /// ```
318    ///
319    /// where
320    ///
321    /// ```text
322    ///     1   n
323    /// x̅ = ―   ∑ xᵢ
324    ///     n  i=1
325    /// ```
326    ///
327    /// and `n` is the length of the axis.
328    ///
329    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
330    /// is out of bounds, or if `A::from_usize()` fails for any any of the
331    /// numbers in the range `0..=n`.
332    ///
333    /// # Example
334    ///
335    /// ```
336    /// use ndarray::{aview1, arr2, Axis};
337    ///
338    /// let a = arr2(&[[1., 2.],
339    ///                [3., 4.],
340    ///                [5., 6.]]);
341    /// let var = a.var_axis(Axis(0), 1.);
342    /// assert_eq!(var, aview1(&[4., 4.]));
343    /// ```
344    #[cfg(feature = "std")]
345    pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
346    where
347        A: Float + FromPrimitive,
348        D: RemoveAxis,
349    {
350        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
351        let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
352        assert!(
353            !(ddof < zero || ddof > n),
354            "`ddof` must not be less than zero or greater than the length of \
355             the axis",
356        );
357        let dof = n - ddof;
358        let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
359        let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
360        for (i, subview) in self.axis_iter(axis).enumerate() {
361            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
362            azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
363                let delta = x - *mean;
364                *mean = *mean + delta / count;
365                *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
366            });
367        }
368        sum_sq.mapv_into(|s| s / dof)
369    }
370
371    /// Return standard deviation along `axis`.
372    ///
373    /// The standard deviation is computed from the variance using
374    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
375    ///
376    /// The parameter `ddof` specifies the "delta degrees of freedom". For
377    /// example, to calculate the population standard deviation, use `ddof = 0`,
378    /// or to calculate the sample standard deviation, use `ddof = 1`.
379    ///
380    /// The standard deviation is defined as:
381    ///
382    /// ```text
383    ///               ⎛    1       n          ⎞
384    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
385    ///               ⎝ n - ddof  i=1         ⎠
386    /// ```
387    ///
388    /// where
389    ///
390    /// ```text
391    ///     1   n
392    /// x̅ = ―   ∑ xᵢ
393    ///     n  i=1
394    /// ```
395    ///
396    /// and `n` is the length of the axis.
397    ///
398    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
399    /// is out of bounds, or if `A::from_usize()` fails for any any of the
400    /// numbers in the range `0..=n`.
401    ///
402    /// # Example
403    ///
404    /// ```
405    /// use ndarray::{aview1, arr2, Axis};
406    ///
407    /// let a = arr2(&[[1., 2.],
408    ///                [3., 4.],
409    ///                [5., 6.]]);
410    /// let stddev = a.std_axis(Axis(0), 1.);
411    /// assert_eq!(stddev, aview1(&[2., 2.]));
412    /// ```
413    #[cfg(feature = "std")]
414    pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
415    where
416        A: Float + FromPrimitive,
417        D: RemoveAxis,
418    {
419        self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
420    }
421}