1use crate::dimension::IntoDimension;
2use crate::Dimension;
3
/// An array shape whose memory layout is contiguous: a dimension plus a
/// choice between C order and F order.
///
/// This type cannot carry custom strides; shapes with explicit strides are
/// represented by [`StrideShape`].
#[derive(Copy, Clone, Debug)]
pub struct Shape<D> {
    /// The dimension of the shape.
    pub(crate) dim: D,
    /// Memory order of the shape. Because `Contiguous` has no variants, the
    /// `Strides::Custom` variant can never be constructed for this field, so
    /// the value is always `Strides::C` or `Strides::F`.
    pub(crate) strides: Strides<Contiguous>,
}
14
/// Uninhabited marker type used as the `Strides::Custom` payload inside
/// [`Shape`].
///
/// No value of this type can exist, so `Strides<Contiguous>` can only ever be
/// `Strides::C` or `Strides::F`: the type system rules out custom strides
/// on a contiguous `Shape`.
#[derive(Copy, Clone, Debug)]
pub(crate) enum Contiguous { }
17
18impl<D> Shape<D> {
19 pub(crate) fn is_c(&self) -> bool {
20 matches!(self.strides, Strides::C)
21 }
22}
23
24
/// An array shape together with its stride description: C order, F order,
/// or explicit custom strides.
#[derive(Copy, Clone, Debug)]
pub struct StrideShape<D> {
    /// The dimension of the shape.
    pub(crate) dim: D,
    /// How the strides are determined. `Strides::strides_for_dim` resolves
    /// this against a dimension and debug-asserts that custom strides have
    /// the same number of axes as that dimension.
    pub(crate) strides: Strides<D>,
}
31
/// Description of how an array's strides are chosen.
///
/// `C` and `F` are turned into concrete strides for a given dimension by
/// `Strides::strides_for_dim`; `Custom` carries explicit stride values.
#[derive(Copy, Clone, Debug)]
pub(crate) enum Strides<D> {
    /// C order: strides come from the dimension's `default_strides`.
    C,
    /// F order: strides come from the dimension's `fortran_strides`.
    F,
    /// Explicit strides supplied through `ShapeBuilder::strides`.
    Custom(D)
}
42
43impl<D> Strides<D> {
44 pub(crate) fn strides_for_dim(self, dim: &D) -> D
46 where D: Dimension
47 {
48 match self {
49 Strides::C => dim.default_strides(),
50 Strides::F => dim.fortran_strides(),
51 Strides::Custom(c) => {
52 debug_assert_eq!(c.ndim(), dim.ndim(),
53 "Custom strides given with {} dimensions, expected {}",
54 c.ndim(), dim.ndim());
55 c
56 }
57 }
58 }
59
60 pub(crate) fn is_custom(&self) -> bool {
61 matches!(*self, Strides::Custom(_))
62 }
63}
64
/// Conversion into an array shape, with control over the memory layout.
///
/// It is implemented for every `T: IntoDimension`, where the default layout is
/// C order, and for [`Shape`] itself.
pub trait ShapeBuilder {
    /// The dimension type of the resulting shape.
    type Dim: Dimension;
    /// The type accepted as explicit strides by [`ShapeBuilder::strides`].
    type Strides;

    /// Converts `self` into a [`Shape`] without choosing a new memory order.
    fn into_shape(self) -> Shape<Self::Dim>;
    /// Converts `self` into a [`Shape`] in F order (same as `set_f(true)`).
    fn f(self) -> Shape<Self::Dim>;
    /// Converts `self` into a [`Shape`] in F order if `is_f` is `true`,
    /// otherwise in C order.
    fn set_f(self, is_f: bool) -> Shape<Self::Dim>;
    /// Converts `self` into a [`StrideShape`] that uses `strides` as its
    /// explicit custom strides.
    fn strides(self, strides: Self::Strides) -> StrideShape<Self::Dim>;
}
79
80impl<D> From<D> for Shape<D>
81where
82 D: Dimension,
83{
84 fn from(dimension: D) -> Shape<D> {
86 dimension.into_shape()
87 }
88}
89
90impl<T, D> From<T> for StrideShape<D>
91where
92 D: Dimension,
93 T: ShapeBuilder<Dim = D>,
94{
95 fn from(value: T) -> Self {
96 let shape = value.into_shape();
97 let st = if shape.is_c() {
98 Strides::C
99 } else {
100 Strides::F
101 };
102 StrideShape {
103 strides: st,
104 dim: shape.dim,
105 }
106 }
107}
108
109impl<T> ShapeBuilder for T
110where
111 T: IntoDimension,
112{
113 type Dim = T::Dim;
114 type Strides = T;
115 fn into_shape(self) -> Shape<Self::Dim> {
116 Shape {
117 dim: self.into_dimension(),
118 strides: Strides::C,
119 }
120 }
121 fn f(self) -> Shape<Self::Dim> {
122 self.set_f(true)
123 }
124 fn set_f(self, is_f: bool) -> Shape<Self::Dim> {
125 self.into_shape().set_f(is_f)
126 }
127 fn strides(self, st: T) -> StrideShape<Self::Dim> {
128 self.into_shape().strides(st.into_dimension())
129 }
130}
131
132impl<D> ShapeBuilder for Shape<D>
133where
134 D: Dimension,
135{
136 type Dim = D;
137 type Strides = D;
138
139 fn into_shape(self) -> Shape<D> {
140 self
141 }
142
143 fn f(self) -> Self {
144 self.set_f(true)
145 }
146
147 fn set_f(mut self, is_f: bool) -> Self {
148 self.strides = if !is_f { Strides::C } else { Strides::F };
149 self
150 }
151
152 fn strides(self, st: D) -> StrideShape<D> {
153 StrideShape {
154 dim: self.dim,
155 strides: Strides::Custom(st),
156 }
157 }
158}
159
160impl<D> Shape<D>
161where
162 D: Dimension,
163{
164 pub fn size(&self) -> usize {
168 self.dim.size()
169 }
170}