1use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
// Implements the elementwise binary operator trait `$trt` (method `$mth`,
// operator token `$operator`) for arrays. Six impls are generated:
//
// * owned array `op` owned array
// * owned array `op` &array
// * &array      `op` owned array
// * &array      `op` &array
// * owned array `op` scalar (`B: ScalarOperand`)
// * &array      `op` scalar (`B: ScalarOperand`)
//
// `$doc` names the operation in the generated rustdoc. `$iop` (the matching
// compound-assignment token) is accepted so call sites stay uniform, but it
// is not used by this macro.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// Both operands are owned. This forwards to the `owned op &array` impl, so
/// `self`'s storage is reused when the shapes are equal; otherwise the
/// operands are broadcast together and a new array is allocated.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// When the shapes are equal, the result is written into `self`'s storage.
/// Otherwise the operands are broadcast together (cloning elements as
/// needed) and the result is a newly allocated array.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: reuse `self`. The shapes are equal, so the
            // dimension conversion is expected to succeed.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
            Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// When the shapes are equal, the result is written into `rhs`'s storage.
/// Otherwise the operands are broadcast together (cloning elements as
/// needed) and the result is a newly allocated array.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: reuse `rhs`. `clone_iopf_rev` keeps `self` as the
            // left operand even though `rhs` is the element being updated.
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            // Broadcast starting from `rhs` so the views have dimension type
            // `<E as DimMax<D>>::Output`, matching `Self::Output`.
            let (rhs, lhs) = rhs.broadcast_with(self).unwrap();
            Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast together,
/// cloning elements as needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
        Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// modifying `self` in place and returning it.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
213
// Chooses between two expressions using a commutativity flag token:
// `Commute { a } or { b }` expands to `a`, and
// `Ordered { a } or { b }` expands to `b`. The unselected expression is
// discarded without being expanded into the output.
macro_rules! if_commutative {
    (Ordered { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_ordered
    };
    (Commute { $when_commutative:expr } or { $when_ordered:expr }) => {
        $when_commutative
    };
}
223
// Implements `scalar op array` for one primitive scalar type `$scalar`,
// with the array on the right-hand side either owned or borrowed.
//
// For a `Commute` operator, the work is forwarded to the `array op scalar`
// impl with the operands swapped. For an `Ordered` operator, the scalar is
// kept on the left of `$operator` for every element. `$doc` is accepted for
// call-site uniformity but is not used.
macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where
    D: Dimension,
    S: DataOwned<Elem=$scalar> + DataMut,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            // Reuse the owned array's storage; the scalar stays on the left.
            let mut out = rhs;
            out.map_inplace(move |item| *item = self $operator *item);
            out
        }})
    }
}

impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
where
    D: Dimension,
    S: Data<Elem=$scalar>,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            rhs.map(move |item| self.clone() $operator item.clone())
        })
    }
}
    );
}
268
269mod arithmetic_ops {
270 use super::*;
271 use crate::imp_prelude::*;
272
273 use num_complex::Complex;
274 use std::ops::*;
275
276 fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C {
277 move |x, y| f(x.clone(), y.clone())
278 }
279
280 fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) {
281 move |x, y| *x = f(x.clone(), y.clone())
282 }
283
284 fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) {
285 move |x, y| *x = f(y.clone(), x.clone())
286 }
287
288 impl_binary_op!(Add, +, add, +=, "addition");
289 impl_binary_op!(Sub, -, sub, -=, "subtraction");
290 impl_binary_op!(Mul, *, mul, *=, "multiplication");
291 impl_binary_op!(Div, /, div, /=, "division");
292 impl_binary_op!(Rem, %, rem, %=, "remainder");
293 impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
294 impl_binary_op!(BitOr, |, bitor, |=, "bit or");
295 impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
296 impl_binary_op!(Shl, <<, shl, <<=, "left shift");
297 impl_binary_op!(Shr, >>, shr, >>=, "right shift");
298
299 macro_rules! all_scalar_ops {
300 ($int_scalar:ty) => (
301 impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
302 impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
303 impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
304 impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
305 impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
306 impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
307 impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
308 impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
309 impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
310 impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
311 );
312 }
313 all_scalar_ops!(i8);
314 all_scalar_ops!(u8);
315 all_scalar_ops!(i16);
316 all_scalar_ops!(u16);
317 all_scalar_ops!(i32);
318 all_scalar_ops!(u32);
319 all_scalar_ops!(i64);
320 all_scalar_ops!(u64);
321 all_scalar_ops!(i128);
322 all_scalar_ops!(u128);
323
324 impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
325 impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
326 impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
327
328 impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
329 impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
330 impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
331 impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
332 impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
333
334 impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
335 impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
336 impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
337 impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
338 impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
339
340 impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
341 impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
342 impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
343 impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
344
345 impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
346 impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
347 impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
348 impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
349
350 impl<A, S, D> Neg for ArrayBase<S, D>
351 where
352 A: Clone + Neg<Output = A>,
353 S: DataOwned<Elem = A> + DataMut,
354 D: Dimension,
355 {
356 type Output = Self;
357 fn neg(mut self) -> Self {
359 self.map_inplace(|elt| {
360 *elt = -elt.clone();
361 });
362 self
363 }
364 }
365
366 impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
367 where
368 &'a A: 'a + Neg<Output = A>,
369 S: Data<Elem = A>,
370 D: Dimension,
371 {
372 type Output = Array<A, D>;
373 fn neg(self) -> Array<A, D> {
376 self.map(Neg::neg)
377 }
378 }
379
380 impl<A, S, D> Not for ArrayBase<S, D>
381 where
382 A: Clone + Not<Output = A>,
383 S: DataOwned<Elem = A> + DataMut,
384 D: Dimension,
385 {
386 type Output = Self;
387 fn not(mut self) -> Self {
389 self.map_inplace(|elt| {
390 *elt = !elt.clone();
391 });
392 self
393 }
394 }
395
396 impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
397 where
398 &'a A: 'a + Not<Output = A>,
399 S: Data<Elem = A>,
400 D: Dimension,
401 {
402 type Output = Array<A, D>;
403 fn not(self) -> Array<A, D> {
406 self.map(Not::not)
407 }
408 }
409}
410
411mod assign_ops {
412 use super::*;
413 use crate::imp_prelude::*;
414
415 macro_rules! impl_assign_op {
416 ($trt:ident, $method:ident, $doc:expr) => {
417 use std::ops::$trt;
418
419 #[doc=$doc]
420 impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
424 where
425 A: Clone + $trt<A>,
426 S: DataMut<Elem = A>,
427 S2: Data<Elem = A>,
428 D: Dimension,
429 E: Dimension,
430 {
431 fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
432 self.zip_mut_with(rhs, |x, y| {
433 x.$method(y.clone());
434 });
435 }
436 }
437
438 #[doc=$doc]
439 impl<A, S, D> $trt<A> for ArrayBase<S, D>
440 where
441 A: ScalarOperand + $trt<A>,
442 S: DataMut<Elem = A>,
443 D: Dimension,
444 {
445 fn $method(&mut self, rhs: A) {
446 self.map_inplace(move |elt| {
447 elt.$method(rhs.clone());
448 });
449 }
450 }
451 };
452 }
453
454 impl_assign_op!(
455 AddAssign,
456 add_assign,
457 "Perform `self += rhs` as elementwise addition (in place).\n"
458 );
459 impl_assign_op!(
460 SubAssign,
461 sub_assign,
462 "Perform `self -= rhs` as elementwise subtraction (in place).\n"
463 );
464 impl_assign_op!(
465 MulAssign,
466 mul_assign,
467 "Perform `self *= rhs` as elementwise multiplication (in place).\n"
468 );
469 impl_assign_op!(
470 DivAssign,
471 div_assign,
472 "Perform `self /= rhs` as elementwise division (in place).\n"
473 );
474 impl_assign_op!(
475 RemAssign,
476 rem_assign,
477 "Perform `self %= rhs` as elementwise remainder (in place).\n"
478 );
479 impl_assign_op!(
480 BitAndAssign,
481 bitand_assign,
482 "Perform `self &= rhs` as elementwise bit and (in place).\n"
483 );
484 impl_assign_op!(
485 BitOrAssign,
486 bitor_assign,
487 "Perform `self |= rhs` as elementwise bit or (in place).\n"
488 );
489 impl_assign_op!(
490 BitXorAssign,
491 bitxor_assign,
492 "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
493 );
494 impl_assign_op!(
495 ShlAssign,
496 shl_assign,
497 "Perform `self <<= rhs` as elementwise left shift (in place).\n"
498 );
499 impl_assign_op!(
500 ShrAssign,
501 shr_assign,
502 "Perform `self >>= rhs` as elementwise right shift (in place).\n"
503 );
504}