ndarray/zip/
mod.rs

1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::AssignElem;
18use crate::IntoDimension;
19use crate::Layout;
20use crate::partial::Partial;
21
22use crate::indexes::{indices, Indices};
23use crate::layout::{CORDER, FORDER};
24use crate::split_at::{SplitPreference, SplitAt};
25
26pub use self::ndproducer::{NdProducer, IntoNdProducer, Offset};
27
28/// Return if the expression is a break value.
/// Return if the expression is a break value.
///
/// Evaluates `$e` (a `FoldWhile<T>`); yields the inner value when it is
/// `FoldWhile::Continue`, and otherwise returns the `Done` value from the
/// enclosing function (propagating early termination of a fold).
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(x) => x,
            x => return x,
        }
    };
}
37
/// Broadcast an array so that it acts like a larger size and/or shape array.
///
/// See [broadcasting][1] for more information.
///
/// [1]: struct.ArrayBase.html#broadcasting
trait Broadcast<E>
where
    E: IntoDimension,
{
    /// The broadcast producer, carrying the target dimensionality `E::Dim`.
    type Output: NdProducer<Dim = E::Dim>;
    /// Broadcast the array to the new dimensions `shape`.
    ///
    /// ***Panics*** if broadcasting isn’t possible.
    fn broadcast_unwrap(self, shape: E) -> Self::Output;
    // Seal the trait so it can only be implemented inside this crate.
    private_decl! {}
}
54
55impl<S, D> ArrayBase<S, D>
56where
57    S: RawData,
58    D: Dimension,
59{
60    pub(crate) fn layout_impl(&self) -> Layout {
61        let n = self.ndim();
62        if self.is_standard_layout() {
63            // effectively one-dimensional => C and F layout compatible
64            if n <= 1 || self.shape().iter().filter(|&&len| len > 1).count() <= 1 {
65                Layout::one_dimensional()
66            } else {
67                Layout::c()
68            }
69        } else if n > 1 && self.raw_view().reversed_axes().is_standard_layout() {
70            Layout::f()
71        } else if n > 1 {
72            if self.stride_of(Axis(0)) == 1 {
73                Layout::fpref()
74            } else if self.stride_of(Axis(n - 1)) == 1 {
75                Layout::cpref()
76            } else {
77                Layout::none()
78            }
79        } else {
80            Layout::none()
81        }
82    }
83}
84
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    /// Broadcast by delegating to the inherent `broadcast_unwrap` on a borrow
    /// of `self`, then rebuilding a view that keeps the original lifetime `'a`.
    fn broadcast_unwrap(self, shape: E) -> Self::Output {
        let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
        // SAFETY: `res` was produced by a successful broadcast of a view of
        // `self`, whose data is valid for 'a; pointer, dims and strides are
        // taken unchanged from that broadcast result.
        unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
    }
    private_impl! {}
}
97
/// Abstraction over tuples of producers: lets `Zip` treat an n-tuple of
/// `NdProducer`s as one lock-stepped producer with tuple-valued pointers,
/// strides and items.
trait ZippableTuple: Sized {
    /// Tuple of the producers' item types.
    type Item;
    /// Tuple of the producers' element pointers.
    type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
    /// The common dimension type of all producers.
    type Dim: Dimension;
    /// Tuple of the producers' stride types.
    type Stride: Copy;
    /// Base pointers (first element) of every producer.
    fn as_ptr(&self) -> Self::Ptr;
    /// Dereference a tuple of pointers into a tuple of items.
    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
    /// Pointers for the element at (unchecked) index `i` in every producer.
    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
    /// Stride of each producer along axis `index`.
    fn stride_of(&self, index: usize) -> Self::Stride;
    /// The stride tuple to use when every producer is contiguous.
    fn contiguous_stride(&self) -> Self::Stride;
    /// Split every producer at `index` along `axis`.
    fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
110
/// Lock step function application across several arrays or other producers.
///
/// Zip allows matching several producers to each other elementwise and applying
/// a function over all tuples of elements (one item from each input at
/// a time).
///
/// In general, the zip uses a tuple of producers
/// ([`NdProducer`](trait.NdProducer.html) trait) that all have to be of the
/// same shape. The NdProducer implementation defines what its item type is
/// (for example if it's a shared reference, mutable reference or an array
/// view etc).
///
/// If all the input arrays are of the same memory layout the zip performs much
/// better and the compiler can usually vectorize the loop (if applicable).
///
/// The order elements are visited is not specified. The producers don’t have to
/// have the same item type.
///
/// The `Zip` has two methods for function application: `for_each` and
/// `fold_while`. The zip object can be split, which allows parallelization.
/// A read-only zip object (no mutable producers) can be cloned.
///
/// See also the [`azip!()` macro][az] which offers a convenient shorthand
/// to common ways to use `Zip`.
///
/// [az]: macro.azip.html
///
/// ```
/// use ndarray::Zip;
/// use ndarray::Array2;
///
/// type M = Array2<f64>;
///
/// // Create four 2d arrays of the same size
/// let mut a = M::zeros((64, 32));
/// let b = M::from_elem(a.dim(), 1.);
/// let c = M::from_elem(a.dim(), 2.);
/// let d = M::from_elem(a.dim(), 3.);
///
/// // Example 1: Perform an elementwise arithmetic operation across
/// // the four arrays a, b, c, d.
///
/// Zip::from(&mut a)
///     .and(&b)
///     .and(&c)
///     .and(&d)
///     .for_each(|w, &x, &y, &z| {
///         *w += x + y * z;
///     });
///
/// // Example 2: Create a new array `totals` with one entry per row of `a`.
/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
/// //  entry in `totals` with the sum across each row.
/// //  This is possible because the producer for `totals` and the row producer
/// //  for `a` have the same shape and dimensionality.
/// //  The rows producer yields one array view (`row`) per iteration.
///
/// use ndarray::{Array1, Axis};
///
/// let mut totals = Array1::zeros(a.nrows());
///
/// Zip::from(&mut totals)
///     .and(a.rows())
///     .for_each(|totals, row| *totals = row.sum());
///
/// // Check the result against the built in `.sum_axis()` along axis 1.
/// assert_eq!(totals, a.sum_axis(Axis(1)));
///
///
/// // Example 3: Recreate Example 2 using map_collect to make a new array
///
/// let mut totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
///
/// // Check the result against the previous example.
/// assert_eq!(totals, totals2);
/// ```
#[derive(Debug, Clone)]
pub struct Zip<Parts, D> {
    /// Tuple of producers, all with dimension type `D`.
    parts: Parts,
    /// The common shape shared by every part.
    dimension: D,
    /// Intersection of the parts' layout flags (see `build_and`).
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
196
197
198impl<P, D> Zip<(P,), D>
199where
200    D: Dimension,
201    P: NdProducer<Dim = D>,
202{
203    /// Create a new `Zip` from the input array or other producer `p`.
204    ///
205    /// The Zip will take the exact dimension of `p` and all inputs
206    /// must have the same dimensions (or be broadcast to them).
207    pub fn from<IP>(p: IP) -> Self
208    where
209        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
210    {
211        let array = p.into_producer();
212        let dim = array.raw_dim();
213        let layout = array.layout();
214        Zip {
215            dimension: dim,
216            layout,
217            parts: (array,),
218            layout_tendency: layout.tendency(),
219        }
220    }
221}
222impl<P, D> Zip<(Indices<D>, P), D>
223where
224    D: Dimension + Copy,
225    P: NdProducer<Dim = D>,
226{
227    /// Create a new `Zip` with an index producer and the producer `p`.
228    ///
229    /// The Zip will take the exact dimension of `p` and all inputs
230    /// must have the same dimensions (or be broadcast to them).
231    ///
232    /// *Note:* Indexed zip has overhead.
233    pub fn indexed<IP>(p: IP) -> Self
234    where
235        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
236    {
237        let array = p.into_producer();
238        let dim = array.raw_dim();
239        Zip::from(indices(dim)).and(array)
240    }
241}
242
243impl<Parts, D> Zip<Parts, D>
244where
245    D: Dimension,
246{
247    fn check<P>(&self, part: &P)
248    where
249        P: NdProducer<Dim = D>,
250    {
251        ndassert!(
252            part.equal_dim(&self.dimension),
253            "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
254            self.dimension,
255            part.raw_dim()
256        );
257    }
258
259    /// Return a the number of element tuples in the Zip
260    pub fn size(&self) -> usize {
261        self.dimension.size()
262    }
263
264    /// Return the length of `axis`
265    ///
266    /// ***Panics*** if `axis` is out of bounds.
267    fn len_of(&self, axis: Axis) -> usize {
268        self.dimension[axis.index()]
269    }
270
271    fn prefer_f(&self) -> bool {
272        !self.layout.is(CORDER) && (self.layout.is(FORDER) || self.layout_tendency < 0)
273    }
274
275    /// Return an *approximation* to the max stride axis; if
276    /// component arrays disagree, there may be no choice better than the
277    /// others.
278    fn max_stride_axis(&self) -> Axis {
279        let i = if self.prefer_f() {
280            self
281                .dimension
282                .slice()
283                .iter()
284                .rposition(|&len| len > 1)
285                .unwrap_or(self.dimension.ndim() - 1)
286        } else {
287            /* corder or default */
288            self
289                .dimension
290                .slice()
291                .iter()
292                .position(|&len| len > 1)
293                .unwrap_or(0)
294        };
295        Axis(i)
296    }
297}
298
impl<P, D> Zip<P, D>
where
    D: Dimension,
{
    /// Fold `function` over every element tuple, dispatching to the fastest
    /// loop available for the current combined layout.
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            // Zero-dimensional: exactly one element tuple to visit.
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(CORDER | FORDER) {
            // All parts share a contiguous layout: one flat inner loop.
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Fold over a Zip where every part is contiguous (in C or F order),
    /// traversing all `size` elements in a single stretch.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(CORDER | FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        // SAFETY: the layout flags guarantee every part is contiguous, so all
        // `size` offsets from the base pointers are in bounds.
        unsafe {
            self.inner(acc, ptrs, inner_strides, size, &mut function)
        }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
                            len: usize, function: &mut F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            // Stop early (returning Done) if the closure says so.
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }


    /// Fold over a non-contiguous Zip, choosing the unroll axis from the
    /// overall layout tendency (c-leaning or one-dimensional => last axis,
    /// f-leaning => first axis).
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            // ndim == 0 is handled by the contiguous path in for_each_core.
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unroll axis to 1 so the index iteration below visits
        // each "row" base index once; the inner loop covers the axis itself.
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        // As in the C variant: collapse the unroll axis; iterate the rest
        // of the index space in F order (next_for_f).
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Allocate an uninitialized output array matching this Zip's shape and
    /// preferred (c/f) memory order.
    ///
    /// NOTE(review): the name has a typo ("uninitalized") but it is
    /// `pub(crate)` and called from the rayon support code; renaming needs a
    /// coordinated change at those call sites.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitalized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
429
430/*
431trait Offset : Copy {
432    unsafe fn offset(self, off: isize) -> Self;
433    unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
434        self.offset(index as isize * stride)
435    }
436}
437
438impl<T> Offset for *mut T {
439    unsafe fn offset(self, off: isize) -> Self {
440        self.offset(off)
441    }
442}
443*/
444
/// Offset a (tuple of) raw pointer(s) by a stride and index; the base
/// case is a single `*mut T`, tuples are covered by `offset_impl!` below.
trait OffsetTuple {
    /// The stride type matching the pointer (tuple): `isize` per pointer.
    type Args;
    /// Offset `self` by `index` elements of `stride` each.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T {
    type Args = isize;
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
        // Callers guarantee the resulting pointer stays in bounds.
        self.offset(index as isize * stride)
    }
}
456
// Implement `OffsetTuple` for pointer tuples of arity 1..=6: each component
// pointer is offset by its own stride, in lock step.
macro_rules! offset_impl {
    ($([$($param:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
            type Args = ($($param::Stride,)*);
            unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
                // Destructure the pointer tuple and the stride tuple, then
                // offset each pointer by its matching stride.
                let ($($param, )*) = self;
                let ($($q, )*) = stride;
                ($(Offset::stride_offset($param, $q, index),)*)
            }
        }
        )+
    }
}

offset_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
481
// Implement `ZippableTuple` for tuples of 1..=6 producers sharing a
// dimension type: every operation simply applies elementwise across the
// tuple components.
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
            type Item = ($($p::Item, )*);
            type Ptr = ($($p::Ptr, )*);
            type Dim = Dim;
            type Stride = ($($p::Stride,)* );

            fn stride_of(&self, index: usize) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.stride_of(Axis(index)), )*)
            }

            fn contiguous_stride(&self) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.contiguous_stride(), )*)
            }

            fn as_ptr(&self) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.as_ptr(), )*)
            }
            unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                // `$q` binds the producers, `$p` the pointer components.
                let ($(ref $q ,)*) = *self;
                let ($($p,)*) = ptr;
                ($($q.as_ref($p),)*)
            }

            unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.uget_ptr(i), )*)
            }

            fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                // Split every producer, then regroup the left and right
                // halves into two tuples.
                let ($($p,)*) = self;
                let ($($p,)*) = (
                    $($p.split_at(axis, index), )*
                );
                (
                    ($($p.0,)*),
                    ($($p.1,)*)
                )
            }
        }
        )+
    }
}

zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
540
// Generate the public `Zip` API for every supported tuple arity (1..=6
// producers). `$notlast` is `true` for every arity except the largest and
// gates the methods that grow the tuple (`and`, `and_broadcast` and the
// collect family), so the tuple can never exceed the maximum size.
macro_rules! map_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<D, $($p),*> Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            pub fn for_each<F>(mut self, mut function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each_core((), move |(), args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function($($p),*))
                });
            }

            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            #[deprecated(note="Renamed to .for_each()", since="0.15.0")]
            pub fn apply<F>(self, function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each(function)
            }

            /// Apply a fold function to all elements of the input arrays,
            /// visiting elements in lock step.
            ///
            /// # Example
            ///
            /// The expression `tr(AᵀB)` can be more efficiently computed as
            /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
            /// elements of the entry-wise product). It would be possible to
            /// evaluate this expression by first computing the entry-wise
            /// product, `A∘B`, and then computing the elementwise sum of that
            /// product, but it's possible to do this in a single loop (and
            /// avoid an extra heap allocation if `A` and `B` can't be
            /// consumed) by using `Zip`:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            ///
            /// let a = array![[1, 5], [3, 7]];
            /// let b = array![[2, 4], [8, 6]];
            ///
            /// // Without using `Zip`. This involves two loops and an extra
            /// // heap allocation for the result of `&a * &b`.
            /// let sum_prod_nonzip = (&a * &b).sum();
            /// // Using `Zip`. This is a single loop without any heap allocations.
            /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
            ///
            /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
            /// ```
            pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
            where
                F: FnMut(Acc, $($p::Item),*) -> Acc,
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function(acc, $($p),*))
                }).into_inner()
            }

            /// Apply a fold function to the input arrays while the return
            /// value is `FoldWhile::Continue`, visiting elements in lock step.
            ///
            pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
                -> FoldWhile<Acc>
                where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    function(acc, $($p),*)
                })
            }

            /// Tests if every element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for all elements.
            /// Returns `true` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
            /// ```
            pub fn all<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits on the first element failing the predicate;
                // the fold is Done iff some element failed.
                !self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Continue(())
                    } else {
                        FoldWhile::Done(())
                    }
                }).is_done()
            }

            expand_if!(@bool [$notlast]

            /// Include the producer `p` in the Zip.
            ///
            /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
            pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                let part = p.into_producer();
                self.check(&part);
                self.build_and(part)
            }

            /// Include the producer `p` in the Zip, broadcasting if needed.
            ///
            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
            ///
            /// ***Panics*** if broadcasting isn’t possible.
            pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
                -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
                where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
                      D2: Dimension,
            {
                let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
                self.build_and(part)
            }

            /// Append `part` to the tuple, merging its layout information.
            fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
                where P: NdProducer<Dim=D>,
            {
                let part_layout = part.layout();
                let ($($p,)*) = self.parts;
                Zip {
                    parts: ($($p,)* part, ),
                    layout: self.layout.intersect(part_layout),
                    dimension: self.dimension,
                    layout_tendency: self.layout_tendency + part_layout.tendency(),
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect_owned(f)
            }

            pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
                -> ArrayBase<S, D>
                where S: DataOwned<Elem = R>
            {
                // safe because: all elements are written before the array is completed

                let shape = self.dimension.clone().set_f(self.prefer_f());
                let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
                    // Use partial to count the number of filled elements, and can drop the right
                    // number of elements on unwinding (if it happens during apply/collect).
                    unsafe {
                        let output_view = output.cast::<R>();
                        self.and(output_view)
                            .collect_with_partial(f)
                            .release_ownership();
                    }
                });
                unsafe {
                    output.assume_init()
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            #[deprecated(note="Renamed to .map_collect()", since="0.15.0")]
            pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect(f)
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.and(into)
                    .for_each(move |$($p, )* output_| {
                        output_.assign_elem(f($($p ),*));
                    });
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            #[deprecated(note="Renamed to .map_assign_into()", since="0.15.0")]
            pub fn apply_assign_into<R, Q>(self, into: Q, f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.map_assign_into(into, f)
            }


            );

            /// Split the `Zip` evenly in two.
            ///
            /// It will be split in the way that best preserves element locality.
            pub fn split(self) -> (Self, Self) {
                debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
                debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
                SplitPreference::split(self)
            }
        }

        expand_if!(@bool [$notlast]
            // For collect; Last producer is a RawViewMut
            #[allow(non_snake_case)]
            impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
                      PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
            {
                /// The inner workings of map_collect and par_map_collect
                ///
                /// Apply the function and collect the results into the output (last producer)
                /// which should be a raw array view; a Partial that owns the written
                /// elements is returned.
                ///
                /// Elements will be overwritten in place (in the sense of std::ptr::write).
                ///
                /// ## Safety
                ///
                /// The last producer is a RawArrayViewMut and must be safe to write into.
                /// The producer must be c- or f-contig and have the same layout tendency
                /// as the whole Zip.
                ///
                /// The returned Partial's proxy ownership of the elements must be handled,
                /// before the array the raw view points to realizes its ownership.
                pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
                    where F: FnMut($($p::Item,)* ) -> R
                {
                    // Get the last producer; and make a Partial that aliases its data pointer
                    let (.., ref output) = &self.parts;

                    // debug assert that the output is contiguous in the memory layout we need
                    if cfg!(debug_assertions) {
                        let out_layout = output.layout();
                        assert!(out_layout.is(CORDER | FORDER));
                        assert!(
                            (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
                            (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
                            "layout tendency violation for self layout {:?}, output layout {:?},\
                            output shape {:?}",
                            self.layout, out_layout, output.raw_dim());
                    }

                    let mut partial = Partial::new(output.as_ptr());

                    // Apply the mapping function on this zip
                    // if we panic with unwinding; Partial will drop the written elements.
                    let partial_len = &mut partial.len;
                    self.for_each(move |$($p,)* output_elem: *mut R| {
                        output_elem.write(f($($p),*));
                        if std::mem::needs_drop::<R>() {
                            *partial_len += 1;
                        }
                    });

                    partial
                }
            }
        );

        impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn can_split(&self) -> bool { self.size() > 1 }

            fn split_preference(&self) -> (Axis, usize) {
                // Always split in a way that preserves layout (if any)
                let axis = self.max_stride_axis();
                let index = self.len_of(axis) / 2;
                (axis, index)
            }
        }

        impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
                // Split parts and dimension consistently; layout flags and
                // tendency carry over to both halves.
                let (p1, p2) = self.parts.split_at(axis, index);
                let (d1, d2) = self.dimension.split_at(axis, index);
                (Zip {
                    dimension: d1,
                    layout: self.layout,
                    parts: p1,
                    layout_tendency: self.layout_tendency,
                },
                Zip {
                    dimension: d2,
                    layout: self.layout,
                    parts: p2,
                    layout_tendency: self.layout_tendency,
                })
            }

        }

        )+
    }
}

map_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}
874
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T> {
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}

impl<T> FoldWhile<T> {
    /// Unwrap and return the inner value, whichever variant holds it.
    pub fn into_inner(self) -> T {
        match self {
            FoldWhile::Done(value) => value,
            FoldWhile::Continue(value) => value,
        }
    }

    /// Return `true` if it is `Done`, `false` if `Continue`.
    pub fn is_done(&self) -> bool {
        matches!(*self, FoldWhile::Done(_))
    }
}