ndarray/stacking.rs
1// Copyright 2014-2020 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::error::{from_kind, ErrorKind, ShapeError};
10use crate::imp_prelude::*;
11
12/// Stack arrays along the new axis.
13///
14/// ***Errors*** if the arrays have mismatching shapes.
15/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
16/// if the result is larger than is possible to represent.
17///
18/// ```
19/// extern crate ndarray;
20///
21/// use ndarray::{arr2, arr3, stack, Axis};
22///
23/// # fn main() {
24///
25/// let a = arr2(&[[2., 2.],
26/// [3., 3.]]);
27/// assert!(
28/// stack(Axis(0), &[a.view(), a.view()])
29/// == Ok(arr3(&[[[2., 2.],
30/// [3., 3.]],
31/// [[2., 2.],
32/// [3., 3.]]]))
33/// );
34/// # }
35/// ```
36pub fn stack<A, D>(
37 axis: Axis,
38 arrays: &[ArrayView<A, D>],
39) -> Result<Array<A, D::Larger>, ShapeError>
40where
41 A: Copy,
42 D: Dimension,
43 D::Larger: RemoveAxis,
44{
45 #[allow(deprecated)]
46 stack_new_axis(axis, arrays)
47}
48
49/// Concatenate arrays along the given axis.
50///
51/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
52/// (may be made more flexible in the future).<br>
53/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
54/// if the result is larger than is possible to represent.
55///
56/// ```
57/// use ndarray::{arr2, Axis, concatenate};
58///
59/// let a = arr2(&[[2., 2.],
60/// [3., 3.]]);
61/// assert!(
62/// concatenate(Axis(0), &[a.view(), a.view()])
63/// == Ok(arr2(&[[2., 2.],
64/// [3., 3.],
65/// [2., 2.],
66/// [3., 3.]]))
67/// );
68/// ```
69pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
70where
71 A: Copy,
72 D: RemoveAxis,
73{
74 if arrays.is_empty() {
75 return Err(from_kind(ErrorKind::Unsupported));
76 }
77 let mut res_dim = arrays[0].raw_dim();
78 if axis.index() >= res_dim.ndim() {
79 return Err(from_kind(ErrorKind::OutOfBounds));
80 }
81 let common_dim = res_dim.remove_axis(axis);
82 if arrays
83 .iter()
84 .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
85 {
86 return Err(from_kind(ErrorKind::IncompatibleShape));
87 }
88
89 let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
90 res_dim.set_axis(axis, stacked_dim);
91
92 // we can safely use uninitialized values here because we will
93 // overwrite every one of them.
94 let mut res = Array::uninit(res_dim);
95
96 {
97 let mut assign_view = res.view_mut();
98 for array in arrays {
99 let len = array.len_of(axis);
100 let (front, rest) = assign_view.split_at(axis, len);
101 array.assign_to(front);
102 assign_view = rest;
103 }
104 debug_assert_eq!(assign_view.len(), 0);
105 }
106 unsafe {
107 Ok(res.assume_init())
108 }
109}
110
111#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
112/// Stack arrays along the new axis.
113///
114/// ***Errors*** if the arrays have mismatching shapes.
115/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
116/// if the result is larger than is possible to represent.
117///
118/// ```
119/// extern crate ndarray;
120///
121/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
122///
123/// # fn main() {
124///
125/// let a = arr2(&[[2., 2.],
126/// [3., 3.]]);
127/// assert!(
128/// stack_new_axis(Axis(0), &[a.view(), a.view()])
129/// == Ok(arr3(&[[[2., 2.],
130/// [3., 3.]],
131/// [[2., 2.],
132/// [3., 3.]]]))
133/// );
134/// # }
135/// ```
136pub fn stack_new_axis<A, D>(
137 axis: Axis,
138 arrays: &[ArrayView<A, D>],
139) -> Result<Array<A, D::Larger>, ShapeError>
140where
141 A: Copy,
142 D: Dimension,
143 D::Larger: RemoveAxis,
144{
145 if arrays.is_empty() {
146 return Err(from_kind(ErrorKind::Unsupported));
147 }
148 let common_dim = arrays[0].raw_dim();
149 // Avoid panic on `insert_axis` call, return an Err instead of it.
150 if axis.index() > common_dim.ndim() {
151 return Err(from_kind(ErrorKind::OutOfBounds));
152 }
153 let mut res_dim = common_dim.insert_axis(axis);
154
155 if arrays.iter().any(|a| a.raw_dim() != common_dim) {
156 return Err(from_kind(ErrorKind::IncompatibleShape));
157 }
158
159 res_dim.set_axis(axis, arrays.len());
160
161 // we can safely use uninitialized values here because we will
162 // overwrite every one of them.
163 let mut res = Array::uninit(res_dim);
164
165 res.axis_iter_mut(axis)
166 .zip(arrays.iter())
167 .for_each(|(assign_view, array)| {
168 // assign_view is D::Larger::Smaller which is usually == D
169 // (but if D is Ix6, we have IxD != Ix6 here; differing types
170 // but same number of axes).
171 let assign_view = assign_view.into_dimensionality::<D>()
172 .expect("same-dimensionality cast");
173 array.assign_to(assign_view);
174 });
175
176 unsafe {
177 Ok(res.assume_init())
178 }
179}
180
/// Stack arrays along the new axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`, so the arrays are borrowed rather than moved. A trailing
/// comma after the last array is accepted.
///
/// [1]: fn.stack.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// // A trailing comma is accepted.
/// assert_eq!(stack![Axis(1), a, a,].shape(), &[2, 2, 2]);
/// # }
/// ```
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}
214
/// Concatenate arrays along the given axis.
///
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`, so the arrays are borrowed rather than moved. A trailing
/// comma after the last array is accepted.
///
/// [1]: fn.concatenate.html
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, concatenate, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     concatenate![Axis(0), a, a]
///     == arr2(&[[2., 2.],
///               [3., 3.],
///               [2., 2.],
///               [3., 3.]])
/// );
/// // A trailing comma is accepted.
/// assert_eq!(concatenate![Axis(1), a, a,].shape(), &[2, 4]);
/// # }
/// ```
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}
248
249/// Stack arrays along the new axis.
250///
251/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
252/// argument `a`.
253///
254/// [1]: fn.stack_new_axis.html
255///
256/// ***Panics*** if the `stack` function would return an error.
257///
258/// ```
259/// extern crate ndarray;
260///
261/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
262///
263/// # fn main() {
264///
265/// let a = arr2(&[[2., 2.],
266/// [3., 3.]]);
267/// assert!(
268/// stack_new_axis![Axis(0), a, a]
269/// == arr3(&[[[2., 2.],
270/// [3., 3.]],
271/// [[2., 2.],
272/// [3., 3.]]])
273/// );
274/// # }
275/// ```
276#[macro_export]
277#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
278macro_rules! stack_new_axis {
279 ($axis:expr, $( $array:expr ),+ ) => {
280 $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
281 }
282}