// ndarray/indexes.rs

// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8use super::Dimension;
9use crate::dimension::IntoDimension;
10use crate::zip::Offset;
11use crate::split_at::SplitAt;
12use crate::Axis;
13use crate::Layout;
14use crate::NdProducer;
15use crate::{ArrayBase, Data};

/// An iterator over the indices of an array shape, in row-major ("C") order:
/// the last axis varies fastest.
///
/// Iterator element type is `D::Pattern` (for example `(usize, usize)` for a
/// two-dimensional shape).
///
/// Obtained by calling `.into_iter()` on an [`Indices`] value.
#[derive(Clone)]
pub struct IndicesIter<D> {
    /// The shape being iterated over.
    dim: D,
    /// The next index to yield, or `None` once iteration is exhausted.
    index: Option<D>,
}
25
26/// Create an iterable of the array shape `shape`.
27///
28/// *Note:* prefer higher order methods, arithmetic operations and
29/// non-indexed iteration before using indices.
30pub fn indices<E>(shape: E) -> Indices<E::Dim>
31where
32    E: IntoDimension,
33{
34    let dim = shape.into_dimension();
35    Indices {
36        start: E::Dim::zeros(dim.ndim()),
37        dim,
38    }
39}
40
41/// Return an iterable of the indices of the passed-in array.
42///
43/// *Note:* prefer higher order methods, arithmetic operations and
44/// non-indexed iteration before using indices.
45pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46where
47    S: Data,
48    D: Dimension,
49{
50    indices(array.dim())
51}
52
53impl<D> Iterator for IndicesIter<D>
54where
55    D: Dimension,
56{
57    type Item = D::Pattern;
58    #[inline]
59    fn next(&mut self) -> Option<Self::Item> {
60        let index = match self.index {
61            None => return None,
62            Some(ref ix) => ix.clone(),
63        };
64        self.index = self.dim.next_for(index.clone());
65        Some(index.into_pattern())
66    }
67
68    fn size_hint(&self) -> (usize, Option<usize>) {
69        let l = match self.index {
70            None => 0,
71            Some(ref ix) => {
72                let gone = self
73                    .dim
74                    .default_strides()
75                    .slice()
76                    .iter()
77                    .zip(ix.slice().iter())
78                    .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
79                self.dim.size() - gone
80            }
81        };
82        (l, Some(l))
83    }
84
85    fn fold<B, F>(self, init: B, mut f: F) -> B
86    where
87        F: FnMut(B, D::Pattern) -> B,
88    {
89        let IndicesIter { mut index, dim } = self;
90        let ndim = dim.ndim();
91        if ndim == 0 {
92            return match index {
93                Some(ix) => f(init, ix.into_pattern()),
94                None => init,
95            };
96        }
97        let inner_axis = ndim - 1;
98        let inner_len = dim[inner_axis];
99        let mut acc = init;
100        while let Some(mut ix) = index {
101            // unroll innermost axis
102            for i in ix[inner_axis]..inner_len {
103                ix[inner_axis] = i;
104                acc = f(acc, ix.clone().into_pattern());
105            }
106            index = dim.next_for(ix);
107        }
108        acc
109    }
110}
111
112impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114impl<D> IntoIterator for Indices<D>
115where
116    D: Dimension,
117{
118    type Item = D::Pattern;
119    type IntoIter = IndicesIter<D>;
120    fn into_iter(self) -> Self::IntoIter {
121        let sz = self.dim.size();
122        let index = if sz != 0 { Some(self.start) } else { None };
123        IndicesIter {
124            index,
125            dim: self.dim,
126        }
127    }
128}
129
130/// Indices producer and iterable.
131///
132/// `Indices` is an `NdProducer` that produces the indices of an array shape.
133#[derive(Copy, Clone, Debug)]
134pub struct Indices<D>
135where
136    D: Dimension,
137{
138    start: D,
139    dim: D,
140}

/// The "pointer" used by the `Indices` producer: it simply holds the current
/// index value, and offsetting it moves that index along one axis.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D> {
    /// The index this pointer currently designates.
    index: D,
}
146
147impl<D> Offset for IndexPtr<D>
148where
149    D: Dimension + Copy,
150{
151    // stride: The axis to increment
152    type Stride = usize;
153
154    unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self {
155        self.index[stride] += index;
156        self
157    }
158    private_impl! {}
159}
160
161// How the NdProducer for Indices works.
162//
163// NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced
164// item (Item).
165//
166// Instead of Ptr, there is `IndexPtr<D>` which is an index value, like [0, 0, 0]
167// for the three dimensional case.
168//
169// The stride is simply which axis is currently being incremented. The stride for axis 1, is 1.
170//
171// .stride_offset(stride, index) simply computes the new index along that axis, for example:
172// [0, 0, 0].stride_offset(1, 10) => [0, 10, 0]  axis 1 is incremented by 10.
173//
174// .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0)
175impl<D: Dimension + Copy> NdProducer for Indices<D> {
176    type Item = D::Pattern;
177    type Dim = D;
178    type Ptr = IndexPtr<D>;
179    type Stride = usize;
180
181    private_impl! {}
182
183    #[doc(hidden)]
184    fn raw_dim(&self) -> Self::Dim {
185        self.dim
186    }
187
188    #[doc(hidden)]
189    fn equal_dim(&self, dim: &Self::Dim) -> bool {
190        self.dim.equal(dim)
191    }
192
193    #[doc(hidden)]
194    fn as_ptr(&self) -> Self::Ptr {
195        IndexPtr { index: self.start }
196    }
197
198    #[doc(hidden)]
199    fn layout(&self) -> Layout {
200        if self.dim.ndim() <= 1 {
201            Layout::one_dimensional()
202        } else {
203            Layout::none()
204        }
205    }
206
207    #[doc(hidden)]
208    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
209        ptr.index.into_pattern()
210    }
211
212    #[doc(hidden)]
213    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
214        let mut index = *i;
215        index += &self.start;
216        IndexPtr { index }
217    }
218
219    #[doc(hidden)]
220    fn stride_of(&self, axis: Axis) -> Self::Stride {
221        axis.index()
222    }
223
224    #[inline(always)]
225    fn contiguous_stride(&self) -> Self::Stride {
226        0
227    }
228
229    #[doc(hidden)]
230    fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
231        let start_a = self.start;
232        let mut start_b = start_a;
233        let (a, b) = self.dim.split_at(axis, index);
234        start_b[axis.index()] += index;
235        (
236            Indices {
237                start: start_a,
238                dim: a,
239            },
240            Indices {
241                start: start_b,
242                dim: b,
243            },
244        )
245    }
246}

/// An iterator over the indices of an array shape in column-major
/// ("Fortran") order: the first axis varies fastest.
///
/// Iterator element type is `D::Pattern`.
///
/// Created by [`indices_iter_f`].
#[derive(Clone)]
pub struct IndicesIterF<D> {
    /// The shape being iterated over.
    dim: D,
    /// The next index to yield; only meaningful while `has_remaining` is true.
    index: D,
    /// Whether `index` still holds an index that has not been yielded yet.
    has_remaining: bool,
}
257
258pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
259where
260    E: IntoDimension,
261{
262    let dim = shape.into_dimension();
263    let zero = E::Dim::zeros(dim.ndim());
264    IndicesIterF {
265        has_remaining: dim.size_checked() != Some(0),
266        index: zero,
267        dim,
268    }
269}
270
271impl<D> Iterator for IndicesIterF<D>
272where
273    D: Dimension,
274{
275    type Item = D::Pattern;
276    #[inline]
277    fn next(&mut self) -> Option<Self::Item> {
278        if !self.has_remaining {
279            None
280        } else {
281            let elt = self.index.clone().into_pattern();
282            self.has_remaining = self.dim.next_for_f(&mut self.index);
283            Some(elt)
284        }
285    }
286
287    fn size_hint(&self) -> (usize, Option<usize>) {
288        if !self.has_remaining {
289            return (0, Some(0));
290        }
291        let gone = self
292            .dim
293            .fortran_strides()
294            .slice()
295            .iter()
296            .zip(self.index.slice().iter())
297            .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
298        let l = self.dim.size() - gone;
299        (l, Some(l))
300    }
301}
302
303impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}

#[cfg(test)]
mod tests {
    use super::indices;
    use super::indices_iter_f;

    #[test]
    fn test_indices_iter_c_size_hint() {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    /// Row-major order: the last axis varies fastest.
    #[test]
    fn test_indices_iter_c_order() {
        let got: Vec<_> = indices((2, 3)).into_iter().collect();
        assert_eq!(got, vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);

        // A zero-dimensional shape has exactly one (empty) index.
        let zero_d: Vec<_> = indices(()).into_iter().collect();
        assert_eq!(zero_d, vec![()]);
    }

    #[test]
    fn test_indices_iter_c_fold() {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut it = indices($dim).into_iter();
                    for _ in 0..num_consume {
                        it.next();
                    }
                    let clone = it.clone();
                    let len = it.len();
                    let acc = clone.fold(0, |acc, ix| {
                        assert_eq!(ix, it.next().unwrap());
                        acc + 1
                    });
                    assert_eq!(acc, len);
                    assert!(it.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    #[test]
    fn test_indices_iter_f_size_hint() {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    /// Column-major order: the first axis varies fastest.
    #[test]
    fn test_indices_iter_f_order() {
        let got: Vec<_> = indices_iter_f((2, 3)).collect();
        assert_eq!(got, vec![(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]);
    }

    /// A shape with a zero-length axis yields no indices at all.
    #[test]
    fn test_indices_iter_f_empty() {
        let mut it = indices_iter_f((2, 0, 3));
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }
}