// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::dimension;
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

12/// Stack arrays along the new axis.
13///
14/// ***Errors*** if the arrays have mismatching shapes.
15/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
16/// if the result is larger than is possible to represent.
17///
18/// ```
19/// extern crate ndarray;
20///
21/// use ndarray::{arr2, arr3, stack, Axis};
22///
23/// # fn main() {
24///
25/// let a = arr2(&[[2., 2.],
26///                [3., 3.]]);
27/// assert!(
28///     stack(Axis(0), &[a.view(), a.view()])
29///     == Ok(arr3(&[[[2., 2.],
30///                   [3., 3.]],
31///                  [[2., 2.],
32///                   [3., 3.]]]))
33/// );
34/// # }
35/// ```
36pub fn stack<A, D>(
37    axis: Axis,
38    arrays: &[ArrayView<A, D>],
39) -> Result<Array<A, D::Larger>, ShapeError>
40where
41    A: Copy,
42    D: Dimension,
43    D::Larger: RemoveAxis,
44{
45    #[allow(deprecated)]
46    stack_new_axis(axis, arrays)
47}
48
49/// Concatenate arrays along the given axis.
50///
51/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
52/// (may be made more flexible in the future).<br>
53/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
54/// if the result is larger than is possible to represent.
55///
56/// ```
57/// use ndarray::{arr2, Axis, concatenate};
58///
59/// let a = arr2(&[[2., 2.],
60///                [3., 3.]]);
61/// assert!(
62///     concatenate(Axis(0), &[a.view(), a.view()])
63///     == Ok(arr2(&[[2., 2.],
64///                  [3., 3.],
65///                  [2., 2.],
66///                  [3., 3.]]))
67/// );
68/// ```
69pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
70where
71    A: Copy,
72    D: RemoveAxis,
73{
74    if arrays.is_empty() {
75        return Err(from_kind(ErrorKind::Unsupported));
76    }
77    let mut res_dim = arrays[0].raw_dim();
78    if axis.index() >= res_dim.ndim() {
79        return Err(from_kind(ErrorKind::OutOfBounds));
80    }
81    let common_dim = res_dim.remove_axis(axis);
82    if arrays
83        .iter()
84        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
85    {
86        return Err(from_kind(ErrorKind::IncompatibleShape));
87    }
88
89    let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
90    res_dim.set_axis(axis, stacked_dim);
91
92    // we can safely use uninitialized values here because we will
93    // overwrite every one of them.
94    let mut res = Array::uninit(res_dim);
95
96    {
97        let mut assign_view = res.view_mut();
98        for array in arrays {
99            let len = array.len_of(axis);
100            let (front, rest) = assign_view.split_at(axis, len);
101            array.assign_to(front);
102            assign_view = rest;
103        }
104        debug_assert_eq!(assign_view.len(), 0);
105    }
106    unsafe {
107        Ok(res.assume_init())
108    }
109}
110
111#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
112/// Stack arrays along the new axis.
113///
114/// ***Errors*** if the arrays have mismatching shapes.
115/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
116/// if the result is larger than is possible to represent.
117///
118/// ```
119/// extern crate ndarray;
120///
121/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
122///
123/// # fn main() {
124///
125/// let a = arr2(&[[2., 2.],
126///                [3., 3.]]);
127/// assert!(
128///     stack_new_axis(Axis(0), &[a.view(), a.view()])
129///     == Ok(arr3(&[[[2., 2.],
130///                   [3., 3.]],
131///                  [[2., 2.],
132///                   [3., 3.]]]))
133/// );
134/// # }
135/// ```
136pub fn stack_new_axis<A, D>(
137    axis: Axis,
138    arrays: &[ArrayView<A, D>],
139) -> Result<Array<A, D::Larger>, ShapeError>
140where
141    A: Copy,
142    D: Dimension,
143    D::Larger: RemoveAxis,
144{
145    if arrays.is_empty() {
146        return Err(from_kind(ErrorKind::Unsupported));
147    }
148    let common_dim = arrays[0].raw_dim();
149    // Avoid panic on `insert_axis` call, return an Err instead of it.
150    if axis.index() > common_dim.ndim() {
151        return Err(from_kind(ErrorKind::OutOfBounds));
152    }
153    let mut res_dim = common_dim.insert_axis(axis);
154
155    if arrays.iter().any(|a| a.raw_dim() != common_dim) {
156        return Err(from_kind(ErrorKind::IncompatibleShape));
157    }
158
159    res_dim.set_axis(axis, arrays.len());
160
161    // we can safely use uninitialized values here because we will
162    // overwrite every one of them.
163    let mut res = Array::uninit(res_dim);
164
165    res.axis_iter_mut(axis)
166        .zip(arrays.iter())
167        .for_each(|(assign_view, array)| {
168            // assign_view is D::Larger::Smaller which is usually == D
169            // (but if D is Ix6, we have IxD != Ix6 here; differing types
170            // but same number of axes).
171            let assign_view = assign_view.into_dimensionality::<D>()
172                .expect("same-dimensionality cast");
173            array.assign_to(assign_view);
174        });
175
176    unsafe {
177        Ok(res.assume_init())
178    }
179}
180
/// Stack arrays along the new axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`. A trailing comma after the last array is accepted.
///
/// [1]: fn.stack.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// ```
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`. A trailing comma after the last array is accepted.
///
/// [1]: fn.concatenate.html
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     concatenate![Axis(0), a, a]
///     == arr2(&[[2., 2.],
///               [3., 3.],
///               [2., 2.],
///               [3., 3.]])
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}

/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`. A trailing comma after the last array is accepted.
///
/// [1]: fn.stack_new_axis.html
///
/// ***Panics*** if the `stack_new_axis` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack_new_axis![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// ```
#[macro_export]
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
macro_rules! stack_new_axis {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}