ndarray/
shape_builder.rs

1use crate::dimension::IntoDimension;
2use crate::Dimension;
3
4/// A contiguous array shape of n dimensions.
5///
6/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
7#[derive(Copy, Clone, Debug)]
8pub struct Shape<D> {
9    /// Shape (axis lengths)
10    pub(crate) dim: D,
11    /// Strides can only be C or F here
12    pub(crate) strides: Strides<Contiguous>,
13}
14
/// Marker type used as the custom-stride payload of `Shape`.
///
/// It has no variants and therefore no values, so a `Strides<Contiguous>`
/// can never hold the `Custom` variant.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Contiguous {}
17
18impl<D> Shape<D> {
19    pub(crate) fn is_c(&self) -> bool {
20        matches!(self.strides, Strides::C)
21    }
22}
23
24
25/// An array shape of n dimensions in c-order, f-order or custom strides.
26#[derive(Copy, Clone, Debug)]
27pub struct StrideShape<D> {
28    pub(crate) dim: D,
29    pub(crate) strides: Strides<D>,
30}
31
/// Description of how the strides of an array shape are determined.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Strides<D> {
    /// Row-major ("C"-order)
    C,
    /// Column-major ("F"-order)
    F,
    /// Custom strides
    Custom(D),
}
42
43impl<D> Strides<D> {
44    /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45    pub(crate) fn strides_for_dim(self, dim: &D) -> D
46        where D: Dimension
47    {
48        match self {
49            Strides::C => dim.default_strides(),
50            Strides::F => dim.fortran_strides(),
51            Strides::Custom(c) => {
52                debug_assert_eq!(c.ndim(), dim.ndim(),
53                    "Custom strides given with {} dimensions, expected {}",
54                    c.ndim(), dim.ndim());
55                c
56            }
57        }
58    }
59
60    pub(crate) fn is_custom(&self) -> bool {
61        matches!(*self, Strides::Custom(_))
62    }
63}
64
/// A trait for `Shape` and `D where D: Dimension` that allows
/// customizing the memory layout (strides) of an array shape.
///
/// This trait is used together with array constructor methods like
/// `Array::from_shape_vec`.
pub trait ShapeBuilder {
    /// The dimension type of the shapes produced by this builder.
    type Dim: Dimension;
    /// The type accepted by `strides` to describe custom strides.
    type Strides;

    /// Convert `self` into a `Shape`.
    ///
    /// For the implementations in this file: an existing `Shape` is returned
    /// unchanged, and a bare dimension becomes a c-order shape.
    fn into_shape(self) -> Shape<Self::Dim>;
    /// Return a shape with f-order (column-major) memory layout.
    ///
    /// Both implementations in this file define this as `set_f(true)`.
    fn f(self) -> Shape<Self::Dim>;
    /// Return a shape with f-order layout if `is_f` is true, and c-order
    /// layout otherwise.
    fn set_f(self, is_f: bool) -> Shape<Self::Dim>;
    /// Return a `StrideShape` that uses the given custom `strides`.
    fn strides(self, strides: Self::Strides) -> StrideShape<Self::Dim>;
}
79
80impl<D> From<D> for Shape<D>
81where
82    D: Dimension,
83{
84    /// Create a `Shape` from `dimension`, using the default memory layout.
85    fn from(dimension: D) -> Shape<D> {
86        dimension.into_shape()
87    }
88}
89
90impl<T, D> From<T> for StrideShape<D>
91where
92    D: Dimension,
93    T: ShapeBuilder<Dim = D>,
94{
95    fn from(value: T) -> Self {
96        let shape = value.into_shape();
97        let st = if shape.is_c() {
98            Strides::C
99        } else {
100            Strides::F
101        };
102        StrideShape {
103            strides: st,
104            dim: shape.dim,
105        }
106    }
107}
108
109impl<T> ShapeBuilder for T
110where
111    T: IntoDimension,
112{
113    type Dim = T::Dim;
114    type Strides = T;
115    fn into_shape(self) -> Shape<Self::Dim> {
116        Shape {
117            dim: self.into_dimension(),
118            strides: Strides::C,
119        }
120    }
121    fn f(self) -> Shape<Self::Dim> {
122        self.set_f(true)
123    }
124    fn set_f(self, is_f: bool) -> Shape<Self::Dim> {
125        self.into_shape().set_f(is_f)
126    }
127    fn strides(self, st: T) -> StrideShape<Self::Dim> {
128        self.into_shape().strides(st.into_dimension())
129    }
130}
131
132impl<D> ShapeBuilder for Shape<D>
133where
134    D: Dimension,
135{
136    type Dim = D;
137    type Strides = D;
138
139    fn into_shape(self) -> Shape<D> {
140        self
141    }
142
143    fn f(self) -> Self {
144        self.set_f(true)
145    }
146
147    fn set_f(mut self, is_f: bool) -> Self {
148        self.strides = if !is_f { Strides::C } else { Strides::F };
149        self
150    }
151
152    fn strides(self, st: D) -> StrideShape<D> {
153        StrideShape {
154            dim: self.dim,
155            strides: Strides::Custom(st),
156        }
157    }
158}
159
impl<D> Shape<D>
where
    D: Dimension,
{
    // Return a reference to the dimension
    // (this accessor is currently commented out)
    //pub fn dimension(&self) -> &D { &self.dim }
    /// Return the size of the shape in number of elements.
    ///
    /// This delegates to `size()` on the underlying dimension value.
    pub fn size(&self) -> usize {
        self.dim.size()
    }
}