ndarray/numeric/impl_numeric.rs
1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::{self, FromPrimitive, Zero};
12use std::ops::{Add, Div, Mul};
13
14use crate::imp_prelude::*;
15use crate::itertools::enumerate;
16use crate::numeric_util;
17
18/// # Numerical Methods for Arrays
19impl<A, S, D> ArrayBase<S, D>
20where
21 S: Data<Elem = A>,
22 D: Dimension,
23{
24 /// Return the sum of all elements in the array.
25 ///
26 /// ```
27 /// use ndarray::arr2;
28 ///
29 /// let a = arr2(&[[1., 2.],
30 /// [3., 4.]]);
31 /// assert_eq!(a.sum(), 10.);
32 /// ```
33 pub fn sum(&self) -> A
34 where
35 A: Clone + Add<Output = A> + num_traits::Zero,
36 {
37 if let Some(slc) = self.as_slice_memory_order() {
38 return numeric_util::unrolled_fold(slc, A::zero, A::add);
39 }
40 let mut sum = A::zero();
41 for row in self.inner_rows() {
42 if let Some(slc) = row.as_slice() {
43 sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
44 } else {
45 sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
46 }
47 }
48 sum
49 }
50
51 /// Return the sum of all elements in the array.
52 ///
53 /// *This method has been renamed to `.sum()`*
54 #[deprecated(note="renamed to `sum`", since="0.15.0")]
55 pub fn scalar_sum(&self) -> A
56 where
57 A: Clone + Add<Output = A> + num_traits::Zero,
58 {
59 self.sum()
60 }
61
62 /// Returns the [arithmetic mean] x̅ of all elements in the array:
63 ///
64 /// ```text
65 /// 1 n
66 /// x̅ = ― ∑ xᵢ
67 /// n i=1
68 /// ```
69 ///
70 /// If the array is empty, `None` is returned.
71 ///
72 /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
73 ///
74 /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
75 pub fn mean(&self) -> Option<A>
76 where
77 A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero,
78 {
79 let n_elements = self.len();
80 if n_elements == 0 {
81 None
82 } else {
83 let n_elements = A::from_usize(n_elements)
84 .expect("Converting number of elements to `A` must not fail.");
85 Some(self.sum() / n_elements)
86 }
87 }
88
89 /// Return the product of all elements in the array.
90 ///
91 /// ```
92 /// use ndarray::arr2;
93 ///
94 /// let a = arr2(&[[1., 2.],
95 /// [3., 4.]]);
96 /// assert_eq!(a.product(), 24.);
97 /// ```
98 pub fn product(&self) -> A
99 where
100 A: Clone + Mul<Output = A> + num_traits::One,
101 {
102 if let Some(slc) = self.as_slice_memory_order() {
103 return numeric_util::unrolled_fold(slc, A::one, A::mul);
104 }
105 let mut sum = A::one();
106 for row in self.inner_rows() {
107 if let Some(slc) = row.as_slice() {
108 sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
109 } else {
110 sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
111 }
112 }
113 sum
114 }
115
116 /// Return variance of elements in the array.
117 ///
118 /// The variance is computed using the [Welford one-pass
119 /// algorithm](https://www.jstor.org/stable/1266577).
120 ///
121 /// The parameter `ddof` specifies the "delta degrees of freedom". For
122 /// example, to calculate the population variance, use `ddof = 0`, or to
123 /// calculate the sample variance, use `ddof = 1`.
124 ///
125 /// The variance is defined as:
126 ///
127 /// ```text
128 /// 1 n
129 /// variance = ―――――――― ∑ (xᵢ - x̅)²
130 /// n - ddof i=1
131 /// ```
132 ///
133 /// where
134 ///
135 /// ```text
136 /// 1 n
137 /// x̅ = ― ∑ xᵢ
138 /// n i=1
139 /// ```
140 ///
141 /// and `n` is the length of the array.
142 ///
143 /// **Panics** if `ddof` is less than zero or greater than `n`
144 ///
145 /// # Example
146 ///
147 /// ```
148 /// use ndarray::array;
149 /// use approx::assert_abs_diff_eq;
150 ///
151 /// let a = array![1., -4.32, 1.14, 0.32];
152 /// let var = a.var(1.);
153 /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
154 /// ```
155 #[cfg(feature = "std")]
156 pub fn var(&self, ddof: A) -> A
157 where
158 A: Float + FromPrimitive,
159 {
160 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
161 let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
162 assert!(
163 !(ddof < zero || ddof > n),
164 "`ddof` must not be less than zero or greater than the length of \
165 the axis",
166 );
167 let dof = n - ddof;
168 let mut mean = A::zero();
169 let mut sum_sq = A::zero();
170 let mut i = 0;
171 self.for_each(|&x| {
172 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
173 let delta = x - mean;
174 mean = mean + delta / count;
175 sum_sq = (x - mean).mul_add(delta, sum_sq);
176 i += 1;
177 });
178 sum_sq / dof
179 }
180
181 /// Return standard deviation of elements in the array.
182 ///
183 /// The standard deviation is computed from the variance using
184 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
185 ///
186 /// The parameter `ddof` specifies the "delta degrees of freedom". For
187 /// example, to calculate the population standard deviation, use `ddof = 0`,
188 /// or to calculate the sample standard deviation, use `ddof = 1`.
189 ///
190 /// The standard deviation is defined as:
191 ///
192 /// ```text
193 /// ⎛ 1 n ⎞
194 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
195 /// ⎝ n - ddof i=1 ⎠
196 /// ```
197 ///
198 /// where
199 ///
200 /// ```text
201 /// 1 n
202 /// x̅ = ― ∑ xᵢ
203 /// n i=1
204 /// ```
205 ///
206 /// and `n` is the length of the array.
207 ///
208 /// **Panics** if `ddof` is less than zero or greater than `n`
209 ///
210 /// # Example
211 ///
212 /// ```
213 /// use ndarray::array;
214 /// use approx::assert_abs_diff_eq;
215 ///
216 /// let a = array![1., -4.32, 1.14, 0.32];
217 /// let stddev = a.std(1.);
218 /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
219 /// ```
220 #[cfg(feature = "std")]
221 pub fn std(&self, ddof: A) -> A
222 where
223 A: Float + FromPrimitive,
224 {
225 self.var(ddof).sqrt()
226 }
227
228 /// Return sum along `axis`.
229 ///
230 /// ```
231 /// use ndarray::{aview0, aview1, arr2, Axis};
232 ///
233 /// let a = arr2(&[[1., 2., 3.],
234 /// [4., 5., 6.]]);
235 /// assert!(
236 /// a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
237 /// a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
238 ///
239 /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
240 /// );
241 /// ```
242 ///
243 /// **Panics** if `axis` is out of bounds.
244 pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
245 where
246 A: Clone + Zero + Add<Output = A>,
247 D: RemoveAxis,
248 {
249 let n = self.len_of(axis);
250 let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
251 let stride = self.strides()[axis.index()];
252 if self.ndim() == 2 && stride == 1 {
253 // contiguous along the axis we are summing
254 let ax = axis.index();
255 for (i, elt) in enumerate(&mut res) {
256 *elt = self.index_axis(Axis(1 - ax), i).sum();
257 }
258 } else {
259 for i in 0..n {
260 let view = self.index_axis(axis, i);
261 res = res + &view;
262 }
263 }
264 res
265 }
266
267 /// Return mean along `axis`.
268 ///
269 /// Return `None` if the length of the axis is zero.
270 ///
271 /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
272 /// fails for the axis length.
273 ///
274 /// ```
275 /// use ndarray::{aview0, aview1, arr2, Axis};
276 ///
277 /// let a = arr2(&[[1., 2., 3.],
278 /// [4., 5., 6.]]);
279 /// assert!(
280 /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
281 /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
282 ///
283 /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
284 /// );
285 /// ```
286 pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
287 where
288 A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
289 D: RemoveAxis,
290 {
291 let axis_length = self.len_of(axis);
292 if axis_length == 0 {
293 None
294 } else {
295 let axis_length =
296 A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
297 let sum = self.sum_axis(axis);
298 Some(sum / aview0(&axis_length))
299 }
300 }
301
302 /// Return variance along `axis`.
303 ///
304 /// The variance is computed using the [Welford one-pass
305 /// algorithm](https://www.jstor.org/stable/1266577).
306 ///
307 /// The parameter `ddof` specifies the "delta degrees of freedom". For
308 /// example, to calculate the population variance, use `ddof = 0`, or to
309 /// calculate the sample variance, use `ddof = 1`.
310 ///
311 /// The variance is defined as:
312 ///
313 /// ```text
314 /// 1 n
315 /// variance = ―――――――― ∑ (xᵢ - x̅)²
316 /// n - ddof i=1
317 /// ```
318 ///
319 /// where
320 ///
321 /// ```text
322 /// 1 n
323 /// x̅ = ― ∑ xᵢ
324 /// n i=1
325 /// ```
326 ///
327 /// and `n` is the length of the axis.
328 ///
329 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
330 /// is out of bounds, or if `A::from_usize()` fails for any any of the
331 /// numbers in the range `0..=n`.
332 ///
333 /// # Example
334 ///
335 /// ```
336 /// use ndarray::{aview1, arr2, Axis};
337 ///
338 /// let a = arr2(&[[1., 2.],
339 /// [3., 4.],
340 /// [5., 6.]]);
341 /// let var = a.var_axis(Axis(0), 1.);
342 /// assert_eq!(var, aview1(&[4., 4.]));
343 /// ```
344 #[cfg(feature = "std")]
345 pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
346 where
347 A: Float + FromPrimitive,
348 D: RemoveAxis,
349 {
350 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
351 let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
352 assert!(
353 !(ddof < zero || ddof > n),
354 "`ddof` must not be less than zero or greater than the length of \
355 the axis",
356 );
357 let dof = n - ddof;
358 let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
359 let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
360 for (i, subview) in self.axis_iter(axis).enumerate() {
361 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
362 azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
363 let delta = x - *mean;
364 *mean = *mean + delta / count;
365 *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
366 });
367 }
368 sum_sq.mapv_into(|s| s / dof)
369 }
370
371 /// Return standard deviation along `axis`.
372 ///
373 /// The standard deviation is computed from the variance using
374 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
375 ///
376 /// The parameter `ddof` specifies the "delta degrees of freedom". For
377 /// example, to calculate the population standard deviation, use `ddof = 0`,
378 /// or to calculate the sample standard deviation, use `ddof = 1`.
379 ///
380 /// The standard deviation is defined as:
381 ///
382 /// ```text
383 /// ⎛ 1 n ⎞
384 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
385 /// ⎝ n - ddof i=1 ⎠
386 /// ```
387 ///
388 /// where
389 ///
390 /// ```text
391 /// 1 n
392 /// x̅ = ― ∑ xᵢ
393 /// n i=1
394 /// ```
395 ///
396 /// and `n` is the length of the axis.
397 ///
398 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
399 /// is out of bounds, or if `A::from_usize()` fails for any any of the
400 /// numbers in the range `0..=n`.
401 ///
402 /// # Example
403 ///
404 /// ```
405 /// use ndarray::{aview1, arr2, Axis};
406 ///
407 /// let a = arr2(&[[1., 2.],
408 /// [3., 4.],
409 /// [5., 6.]]);
410 /// let stddev = a.std_axis(Axis(0), 1.);
411 /// assert_eq!(stddev, aview1(&[2., 2.]));
412 /// ```
413 #[cfg(feature = "std")]
414 pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
415 where
416 A: Float + FromPrimitive,
417 D: RemoveAxis,
418 {
419 self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
420 }
421}