// ndarray/linalg/impl_linalg.rs

// Copyright 2014-2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9use crate::imp_prelude::*;
10use crate::numeric_util;
11
12use crate::{LinalgScalar, Zip};
13
14use std::any::TypeId;
15use alloc::vec::Vec;
16
17#[cfg(feature = "blas")]
18use std::cmp;
19#[cfg(feature = "blas")]
20use std::mem::swap;
21#[cfg(feature = "blas")]
22use libc::c_int;
23
24#[cfg(feature = "blas")]
25use cblas_sys as blas_sys;
26#[cfg(feature = "blas")]
27use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
28
/// Minimum vector length at which `dot` switches to the BLAS `?dot` routine.
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
/// Matrix products where every dimension is at most this value skip BLAS.
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
/// Integer type BLAS uses for lengths, strides and leading dimensions.
///
/// Lower-case to mirror the C typedef it stands in for.
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int;
38
39impl<A, S> ArrayBase<S, Ix1>
40where
41    S: Data<Elem = A>,
42{
43    /// Perform dot product or matrix multiplication of arrays `self` and `rhs`.
44    ///
45    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
46    ///
47    /// If `Rhs` is one-dimensional, then the operation is a vector dot
48    /// product, which is the sum of the elementwise products (no conjugation
49    /// of complex operands, and thus not their inner product). In this case,
50    /// `self` and `rhs` must be the same length.
51    ///
52    /// If `Rhs` is two-dimensional, then the operation is matrix
53    /// multiplication, where `self` is treated as a row vector. In this case,
54    /// if `self` is shape *M*, then `rhs` is shape *M* × *N* and the result is
55    /// shape *N*.
56    ///
57    /// **Panics** if the array shapes are incompatible.<br>
58    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
59    /// layout allows.
60    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
61    where
62        Self: Dot<Rhs>,
63    {
64        Dot::dot(self, rhs)
65    }
66
67    fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
68    where
69        S2: Data<Elem = A>,
70        A: LinalgScalar,
71    {
72        debug_assert_eq!(self.len(), rhs.len());
73        assert!(self.len() == rhs.len());
74        if let Some(self_s) = self.as_slice() {
75            if let Some(rhs_s) = rhs.as_slice() {
76                return numeric_util::unrolled_dot(self_s, rhs_s);
77            }
78        }
79        let mut sum = A::zero();
80        for i in 0..self.len() {
81            unsafe {
82                sum = sum + *self.uget(i) * *rhs.uget(i);
83            }
84        }
85        sum
86    }
87
88    #[cfg(not(feature = "blas"))]
89    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
90    where
91        S2: Data<Elem = A>,
92        A: LinalgScalar,
93    {
94        self.dot_generic(rhs)
95    }
96
97    #[cfg(feature = "blas")]
98    fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
99    where
100        S2: Data<Elem = A>,
101        A: LinalgScalar,
102    {
103        // Use only if the vector is large enough to be worth it
104        if self.len() >= DOT_BLAS_CUTOFF {
105            debug_assert_eq!(self.len(), rhs.len());
106            assert!(self.len() == rhs.len());
107            macro_rules! dot {
108                ($ty:ty, $func:ident) => {{
109                    if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
110                        unsafe {
111                            let (lhs_ptr, n, incx) =
112                                blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
113                            let (rhs_ptr, _, incy) =
114                                blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
115                            let ret = blas_sys::$func(
116                                n,
117                                lhs_ptr as *const $ty,
118                                incx,
119                                rhs_ptr as *const $ty,
120                                incy,
121                            );
122                            return cast_as::<$ty, A>(&ret);
123                        }
124                    }
125                }};
126            }
127
128            dot! {f32, cblas_sdot};
129            dot! {f64, cblas_ddot};
130        }
131        self.dot_generic(rhs)
132    }
133}
134
/// Convert a 1-D array's pointer, length and stride into BLAS arguments.
///
/// BLAS addresses a strided vector starting from its lowest-addressed element.
/// For a non-negative stride, that is the element `ptr` already points to. For a
/// negative stride, it is the last logical element, `len - 1` strides away.
///
/// Returns `(pointer, length, increment)`. Length and stride are cast with `as`,
/// so callers are expected to have checked that they fit in `blas_index`.
///
/// # Safety
///
/// `ptr`, `len` and `stride` must describe a valid array, so that the offset to
/// its last element stays within the same allocation.
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
    ptr: *const A,
    len: usize,
    stride: isize,
) -> (*const A, blas_index, blas_index) {
    //   memory: [x x x x]
    //                  ^-- ptr (logical element 0), stride = -1
    //            ^-- lowest address = ptr + (len - 1) * stride
    let low_ptr = if stride < 0 && len != 0 {
        ptr.offset((len - 1) as isize * stride)
    } else {
        ptr
    };
    (low_ptr, len as blas_index, stride as blas_index)
}
157
/// Matrix Multiplication
///
/// Implemented in this module for vector · vector, vector · matrix,
/// matrix · vector and matrix · matrix products. For two-dimensional arrays,
/// `dot` computes the matrix multiplication.
pub trait Dot<Rhs> {
    /// The type produced by `dot`.
    ///
    /// A scalar for vector · vector, a one-dimensional array when one operand is
    /// a vector and the other a matrix, and a rectangular array for
    /// matrix · matrix.
    type Output;

    /// Compute the product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}
169
170impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
171where
172    S: Data<Elem = A>,
173    S2: Data<Elem = A>,
174    A: LinalgScalar,
175{
176    type Output = A;
177
178    /// Compute the dot product of one-dimensional arrays.
179    ///
180    /// The dot product is a sum of the elementwise products (no conjugation
181    /// of complex operands, and thus not their inner product).
182    ///
183    /// **Panics** if the arrays are not of the same length.<br>
184    /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory
185    /// layout allows.
186    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
187        self.dot_impl(rhs)
188    }
189}
190
191impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
192where
193    S: Data<Elem = A>,
194    S2: Data<Elem = A>,
195    A: LinalgScalar,
196{
197    type Output = Array<A, Ix1>;
198
199    /// Perform the matrix multiplication of the row vector `self` and
200    /// rectangular matrix `rhs`.
201    ///
202    /// The array shapes must agree in the way that
203    /// if `self` is *M*, then `rhs` is *M* × *N*.
204    ///
205    /// Return a result array with shape *N*.
206    ///
207    /// **Panics** if shapes are incompatible.
208    fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
209        rhs.t().dot(self)
210    }
211}
212
213impl<A, S> ArrayBase<S, Ix2>
214where
215    S: Data<Elem = A>,
216{
217    /// Perform matrix multiplication of rectangular arrays `self` and `rhs`.
218    ///
219    /// `Rhs` may be either a one-dimensional or a two-dimensional array.
220    ///
221    /// If Rhs is two-dimensional, they array shapes must agree in the way that
222    /// if `self` is *M* × *N*, then `rhs` is *N* × *K*.
223    ///
224    /// Return a result array with shape *M* × *K*.
225    ///
226    /// **Panics** if shapes are incompatible or the number of elements in the
227    /// result would overflow `isize`.
228    ///
229    /// *Note:* If enabled, uses blas `gemv/gemm` for elements of `f32, f64`
230    /// when memory layout allows. The default matrixmultiply backend
231    /// is otherwise used for `f32, f64` for all memory layouts.
232    ///
233    /// ```
234    /// use ndarray::arr2;
235    ///
236    /// let a = arr2(&[[1., 2.],
237    ///                [0., 1.]]);
238    /// let b = arr2(&[[1., 2.],
239    ///                [2., 3.]]);
240    ///
241    /// assert!(
242    ///     a.dot(&b) == arr2(&[[5., 8.],
243    ///                         [2., 3.]])
244    /// );
245    /// ```
246    pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
247    where
248        Self: Dot<Rhs>,
249    {
250        Dot::dot(self, rhs)
251    }
252}
253
254impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
255where
256    S: Data<Elem = A>,
257    S2: Data<Elem = A>,
258    A: LinalgScalar,
259{
260    type Output = Array2<A>;
261    fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
262        let a = self.view();
263        let b = b.view();
264        let ((m, k), (k2, n)) = (a.dim(), b.dim());
265        if k != k2 || m.checked_mul(n).is_none() {
266            dot_shape_error(m, k, k2, n);
267        }
268
269        let lhs_s0 = a.strides()[0];
270        let rhs_s0 = b.strides()[0];
271        let column_major = lhs_s0 == 1 && rhs_s0 == 1;
272        // A is Copy so this is safe
273        let mut v = Vec::with_capacity(m * n);
274        let mut c;
275        unsafe {
276            v.set_len(m * n);
277            c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
278        }
279        mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
280        c
281    }
282}
283
/// Panic, explaining why an *M* × *K* by *K2* × *N* matrix product is invalid.
///
/// If the *M* × *N* result would not fit in `isize`, that is reported in
/// preference to a mismatch of the inner dimensions.
///
/// Assumes that `m` and `n` are ≤ `isize::MAX`.
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
    let fits_isize = match m.checked_mul(n) {
        Some(len) => len <= ::std::isize::MAX as usize,
        None => false,
    };
    if !fits_isize {
        panic!("ndarray: shape {} × {} overflows isize", m, n);
    }
    panic!(
        "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
        m, k, k2, n
    );
}
297
/// Panic with the operand and output shapes of an incompatible general matrix
/// (or matrix-vector) multiplication.
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
    panic!(
        "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
        m, k, k2, n, c1, c2
    );
}
304
305/// Perform the matrix multiplication of the rectangular array `self` and
306/// column vector `rhs`.
307///
308/// The array shapes must agree in the way that
309/// if `self` is *M* × *N*, then `rhs` is *N*.
310///
311/// Return a result array with shape *M*.
312///
313/// **Panics** if shapes are incompatible.
314impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
315where
316    S: Data<Elem = A>,
317    S2: Data<Elem = A>,
318    A: LinalgScalar,
319{
320    type Output = Array<A, Ix1>;
321    fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
322        let ((m, a), n) = (self.dim(), rhs.dim());
323        if a != n {
324            dot_shape_error(m, a, n, 1);
325        }
326
327        // Avoid initializing the memory in vec -- set it during iteration
328        unsafe {
329            let mut c = Array1::uninit(m);
330            general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
331            c.assume_init()
332        }
333    }
334}
335
336impl<A, S, D> ArrayBase<S, D>
337where
338    S: Data<Elem = A>,
339    D: Dimension,
340{
341    /// Perform the operation `self += alpha * rhs` efficiently, where
342    /// `alpha` is a scalar and `rhs` is another array. This operation is
343    /// also known as `axpy` in BLAS.
344    ///
345    /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
346    ///
347    /// **Panics** if broadcasting isn’t possible.
348    pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
349    where
350        S: DataMut,
351        S2: Data<Elem = A>,
352        A: LinalgScalar,
353        E: Dimension,
354    {
355        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
356    }
357}
358
// `mat_mul_impl` takes `ArrayView` arguments so that every array kind is funneled
// into the same instantiated implementation.
361#[cfg(not(feature = "blas"))]
362use self::mat_mul_general as mat_mul_impl;
363
/// C ← α A B + β C, using BLAS `?gemm` where possible.
///
/// Small problems, where every dimension is ≤ `GEMM_BLAS_CUTOFF`, and element
/// types other than `f32`/`f64` go straight to `mat_mul_general`. Otherwise the
/// operands are rearranged into row-major views where possible:
///
/// * f × f is computed as the transposed problem Bᵀ Aᵀ = Cᵀ;
/// * a column-major *square* `lhs` or `rhs` is passed to BLAS as transposed.
///
/// If any view still is not BLAS row-major, `mat_mul_general` is used.
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
    alpha: A,
    lhs: &ArrayView2<'_, A>,
    rhs: &ArrayView2<'_, A>,
    beta: A,
    c: &mut ArrayViewMut2<'_, A>,
) where
    A: LinalgScalar,
{
    let ((mut m, inner), (_, mut n)) = (lhs.dim(), rhs.dim());
    let worth_blas = m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || inner > GEMM_BLAS_CUTOFF;
    let blas_elem = same_type::<A, f32>() || same_type::<A, f64>();
    if !worth_blas || !blas_elem {
        return mat_mul_general(alpha, lhs, rhs, beta, c);
    }

    {
        let mut lhs_v = lhs.view();
        let mut rhs_v = rhs.view();
        let mut c_v = c.view_mut();
        let mut lhs_op = CblasNoTrans;
        let mut rhs_op = CblasNoTrans;

        let lhs_col_major = lhs_v.strides()[0] == 1;
        let rhs_col_major = rhs_v.strides()[0] == 1;
        if lhs_col_major && rhs_col_major {
            // C = A B  <=>  Cᵀ = Bᵀ Aᵀ, and the transposes are row major.
            let lhs_t = lhs_v.reversed_axes();
            lhs_v = rhs_v.reversed_axes();
            rhs_v = lhs_t;
            c_v = c_v.reversed_axes();
            swap(&mut m, &mut n);
        } else if lhs_col_major && m == inner {
            // Square column-major lhs: store it row major and ask BLAS to transpose.
            lhs_v = lhs_v.reversed_axes();
            lhs_op = CblasTrans;
        } else if rhs_col_major && inner == n {
            // Square column-major rhs: same trick on the right-hand side.
            rhs_v = rhs_v.reversed_axes();
            rhs_op = CblasTrans;
        }

        macro_rules! gemm {
            ($ty:ty, $gemm:ident) => {
                if blas_row_major_2d::<$ty, _>(&lhs_v)
                    && blas_row_major_2d::<$ty, _>(&rhs_v)
                    && blas_row_major_2d::<$ty, _>(&c_v)
                {
                    // Op(lhs) is m × k; Op(rhs) has n columns.
                    let (lhs_rows, lhs_cols) = lhs_v.dim();
                    let (m, k) = match lhs_op {
                        CblasNoTrans => (lhs_rows, lhs_cols),
                        _ => (lhs_cols, lhs_rows),
                    };
                    let (rhs_rows, rhs_cols) = rhs_v.dim();
                    let n = match rhs_op {
                        CblasNoTrans => rhs_cols,
                        _ => rhs_rows,
                    };
                    // Leading dimensions. Row strides may be [1, 1] for column
                    // matrices, so raise them to at least the row length.
                    let lda = cmp::max(lhs_v.strides()[0] as blas_index, k as blas_index);
                    let ldb = cmp::max(rhs_v.strides()[0] as blas_index, n as blas_index);
                    let ldc = cmp::max(c_v.strides()[0] as blas_index, n as blas_index);

                    // gemm is C ← αA^Op B^Op + βC, where Op is notrans/trans/conjtrans.
                    // SAFETY: all three views passed `blas_row_major_2d` for `$ty`,
                    // so they hold `$ty` elements in a row-major layout whose
                    // dimensions and strides fit in `blas_index`.
                    unsafe {
                        blas_sys::$gemm(
                            CblasRowMajor,
                            lhs_op,
                            rhs_op,
                            m as blas_index,                // m, rows of Op(a)
                            n as blas_index,                // n, cols of Op(b)
                            k as blas_index,                // k, cols of Op(a)
                            cast_as(&alpha),                // alpha
                            lhs_v.ptr.as_ptr() as *const _, // a
                            lda,                            // lda
                            rhs_v.ptr.as_ptr() as *const _, // b
                            ldb,                            // ldb
                            cast_as(&beta),                 // beta
                            c_v.ptr.as_ptr() as *mut _,     // c
                            ldc,                            // ldc
                        );
                    }
                    return;
                }
            };
        }
        gemm!(f32, cblas_sgemm);
        gemm!(f64, cblas_dgemm);
    }
    mat_mul_general(alpha, lhs, rhs, beta, c)
}
458
459/// C ← α A B + β C
460fn mat_mul_general<A>(
461    alpha: A,
462    lhs: &ArrayView2<'_, A>,
463    rhs: &ArrayView2<'_, A>,
464    beta: A,
465    c: &mut ArrayViewMut2<'_, A>,
466) where
467    A: LinalgScalar,
468{
469    let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
470
471    // common parameters for gemm
472    let ap = lhs.as_ptr();
473    let bp = rhs.as_ptr();
474    let cp = c.as_mut_ptr();
475    let (rsc, csc) = (c.strides()[0], c.strides()[1]);
476    if same_type::<A, f32>() {
477        unsafe {
478            ::matrixmultiply::sgemm(
479                m,
480                k,
481                n,
482                cast_as(&alpha),
483                ap as *const _,
484                lhs.strides()[0],
485                lhs.strides()[1],
486                bp as *const _,
487                rhs.strides()[0],
488                rhs.strides()[1],
489                cast_as(&beta),
490                cp as *mut _,
491                rsc,
492                csc,
493            );
494        }
495    } else if same_type::<A, f64>() {
496        unsafe {
497            ::matrixmultiply::dgemm(
498                m,
499                k,
500                n,
501                cast_as(&alpha),
502                ap as *const _,
503                lhs.strides()[0],
504                lhs.strides()[1],
505                bp as *const _,
506                rhs.strides()[0],
507                rhs.strides()[1],
508                cast_as(&beta),
509                cp as *mut _,
510                rsc,
511                csc,
512            );
513        }
514    } else {
515        // It's a no-op if `c` has zero length.
516        if c.is_empty() {
517            return;
518        }
519
520        // initialize memory if beta is zero
521        if beta.is_zero() {
522            c.fill(beta);
523        }
524
525        let mut i = 0;
526        let mut j = 0;
527        loop {
528            unsafe {
529                let elt = c.uget_mut((i, j));
530                *elt = *elt * beta
531                    + alpha
532                        * (0..k).fold(A::zero(), move |s, x| {
533                            s + *lhs.uget((i, x)) * *rhs.uget((x, j))
534                        });
535            }
536            j += 1;
537            if j == n {
538                j = 0;
539                i += 1;
540                if i == m {
541                    break;
542                }
543            }
544        }
545    }
546}
547
548/// General matrix-matrix multiplication.
549///
550/// Compute C ← α A B + β C
551///
552/// The array shapes must agree in the way that
553/// if `a` is *M* × *N*, then `b` is *N* × *K* and `c` is *M* × *K*.
554///
555/// ***Panics*** if array shapes are not compatible<br>
556/// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory
557/// layout allows.  The default matrixmultiply backend is otherwise used for
558/// `f32, f64` for all memory layouts.
559pub fn general_mat_mul<A, S1, S2, S3>(
560    alpha: A,
561    a: &ArrayBase<S1, Ix2>,
562    b: &ArrayBase<S2, Ix2>,
563    beta: A,
564    c: &mut ArrayBase<S3, Ix2>,
565) where
566    S1: Data<Elem = A>,
567    S2: Data<Elem = A>,
568    S3: DataMut<Elem = A>,
569    A: LinalgScalar,
570{
571    let ((m, k), (k2, n)) = (a.dim(), b.dim());
572    let (m2, n2) = c.dim();
573    if k != k2 || m != m2 || n != n2 {
574        general_dot_shape_error(m, k, k2, n, m2, n2);
575    } else {
576        mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
577    }
578}
579
580/// General matrix-vector multiplication.
581///
582/// Compute y ← α A x + β y
583///
584/// where A is a *M* × *N* matrix and x is an *N*-element column vector and
585/// y an *M*-element column vector (one dimensional arrays).
586///
587/// ***Panics*** if array shapes are not compatible<br>
588/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory
589/// layout allows.
590#[allow(clippy::collapsible_if)]
591pub fn general_mat_vec_mul<A, S1, S2, S3>(
592    alpha: A,
593    a: &ArrayBase<S1, Ix2>,
594    x: &ArrayBase<S2, Ix1>,
595    beta: A,
596    y: &mut ArrayBase<S3, Ix1>,
597) where
598    S1: Data<Elem = A>,
599    S2: Data<Elem = A>,
600    S3: DataMut<Elem = A>,
601    A: LinalgScalar,
602{
603    unsafe {
604        general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
605    }
606}
607
608/// General matrix-vector multiplication
609///
610/// Use a raw view for the destination vector, so that it can be uninitalized.
611///
612/// ## Safety
613///
614/// The caller must ensure that the raw view is valid for writing.
615/// the destination may be uninitialized iff beta is zero.
616#[allow(clippy::collapsible_else_if)]
617unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
618    alpha: A,
619    a: &ArrayBase<S1, Ix2>,
620    x: &ArrayBase<S2, Ix1>,
621    beta: A,
622    y: RawArrayViewMut<A, Ix1>,
623) where
624    S1: Data<Elem = A>,
625    S2: Data<Elem = A>,
626    A: LinalgScalar,
627{
628    let ((m, k), k2) = (a.dim(), x.dim());
629    let m2 = y.dim();
630    if k != k2 || m != m2 {
631        general_dot_shape_error(m, k, k2, 1, m2, 1);
632    } else {
633        #[cfg(feature = "blas")]
634        macro_rules! gemv {
635            ($ty:ty, $gemv:ident) => {
636                if let Some(layout) = blas_layout::<$ty, _>(&a) {
637                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
638                        // Determine stride between rows or columns. Note that the stride is
639                        // adjusted to at least `k` or `m` to handle the case of a matrix with a
640                        // trivial (length 1) dimension, since the stride for the trivial dimension
641                        // may be arbitrary.
642                        let a_trans = CblasNoTrans;
643                        let a_stride = match layout {
644                            CBLAS_LAYOUT::CblasRowMajor => {
645                                a.strides()[0].max(k as isize) as blas_index
646                            }
647                            CBLAS_LAYOUT::CblasColMajor => {
648                                a.strides()[1].max(m as isize) as blas_index
649                            }
650                        };
651
652                        let x_stride = x.strides()[0] as blas_index;
653                        let y_stride = y.strides()[0] as blas_index;
654
655                        blas_sys::$gemv(
656                            layout,
657                            a_trans,
658                            m as blas_index,            // m, rows of Op(a)
659                            k as blas_index,            // n, cols of Op(a)
660                            cast_as(&alpha),            // alpha
661                            a.ptr.as_ptr() as *const _, // a
662                            a_stride,                   // lda
663                            x.ptr.as_ptr() as *const _, // x
664                            x_stride,
665                            cast_as(&beta),           // beta
666                            y.ptr.as_ptr() as *mut _, // x
667                            y_stride,
668                        );
669                        return;
670                    }
671                }
672            };
673        }
674        #[cfg(feature = "blas")]
675        gemv!(f32, cblas_sgemv);
676        #[cfg(feature = "blas")]
677        gemv!(f64, cblas_dgemv);
678
679        /* general */
680
681        if beta.is_zero() {
682            // when beta is zero, c may be uninitialized
683            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
684                elt.write(row.dot(x) * alpha);
685            });
686        } else {
687            Zip::from(a.outer_iter()).and(y).for_each(|row, elt| {
688                *elt = *elt * beta + row.dot(x) * alpha;
689            });
690        }
691    }
692}
693
/// Return `true` if `A` and `B` are the same type
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool {
    let (id_a, id_b) = (TypeId::of::<A>(), TypeId::of::<B>());
    id_a == id_b
}
699
700// Read pointer to type `A` as type `B`.
701//
702// **Panics** if `A` and `B` are not the same type
703fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
704    assert!(same_type::<A, B>());
705    unsafe { ::std::ptr::read(a as *const _ as *const B) }
706}
707
/// Return `true` if the 1-D array `a` can be handed to a BLAS routine that
/// expects elements of type `A`.
///
/// Requires that:
/// * the element type is `A`;
/// * the length fits in `blas_index`;
/// * the stride is non-zero and fits in `blas_index`.
///
/// A zero stride (for example a broadcast view) is rejected because BLAS
/// routines such as `?gemv` treat `incx == 0` / `incy == 0` as an invalid
/// argument. Such vectors take the generic code path instead.
#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
    S: RawData,
    A: 'static,
    S::Elem: 'static,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }
    if a.len() > blas_index::max_value() as usize {
        return false;
    }
    let stride = a.strides()[0];
    if stride == 0 {
        return false;
    }
    if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize {
        return false;
    }
    true
}
727
/// Memory order a 2-D array is checked against by `is_blas_2d`.
///
/// The order determines which axis must be contiguous: the last axis for `C`,
/// the first axis for `F`.
#[cfg(feature = "blas")]
enum MemoryOrder {
    /// Row-major ("C") order.
    C,
    /// Column-major ("F", Fortran) order.
    F,
}
733
/// Return `true` if `a` holds elements of type `A` in a layout that BLAS can
/// consume as row-major (see `is_blas_2d`).
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}
746
/// Return `true` if `a` holds elements of type `A` in a layout that BLAS can
/// consume as column-major (see `is_blas_2d`).
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    same_type::<A, S::Elem>() && is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
759
/// Check whether a 2-D array with dimensions `dim` and strides `stride` can be
/// passed to BLAS in the given memory `order`.
///
/// The array qualifies when all of the following hold:
/// * the axis that must be contiguous (the last for `C`, the first for `F`) has
///   unit stride, unless its length is 1 and the stride is therefore irrelevant;
/// * both strides are positive;
/// * both strides and both dimensions fit in `blas_index`.
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
    let (rows, cols) = dim.into_pattern();
    let (row_stride, col_stride) = (stride[0] as isize, stride[1] as isize);
    let (contig_stride, contig_len) = match order {
        MemoryOrder::C => (col_stride, cols),
        MemoryOrder::F => (row_stride, rows),
    };
    let idx_max = blas_index::max_value() as isize;
    let idx_min = blas_index::min_value() as isize;
    let stride_fits = |s: isize| s >= idx_min && s <= idx_max;
    let dim_fits = |d: usize| d <= blas_index::max_value() as usize;

    (contig_stride == 1 || contig_len == 1)
        && row_stride >= 1
        && col_stride >= 1
        && stride_fits(row_stride)
        && stride_fits(col_stride)
        && dim_fits(rows)
        && dim_fits(cols)
}
785
/// Pick the CBLAS layout in which `a` can be passed to BLAS.
///
/// Row-major is preferred when both layouts apply. Returns `None` if neither
/// does.
#[cfg(feature = "blas")]
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
    S: Data,
    A: 'static,
    S::Elem: 'static,
{
    if blas_row_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasRowMajor);
    }
    if blas_column_major_2d::<A, _>(a) {
        return Some(CBLAS_LAYOUT::CblasColMajor);
    }
    None
}
801
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
    use super::*;

    /// Assert how `m` is classified for `f32` BLAS calls.
    fn assert_layout(m: ArrayView2<'_, f32>, row_major: bool, column_major: bool) {
        assert_eq!(blas_row_major_2d::<f32, _>(&m), row_major);
        assert_eq!(blas_column_major_2d::<f32, _>(&m), column_major);
    }

    #[test]
    fn blas_row_major_2d_normal_matrix() {
        let m: Array2<f32> = Array2::zeros((3, 5));
        assert_layout(m.view(), true, false);
    }

    #[test]
    fn blas_row_major_2d_row_matrix() {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert_layout(m.view(), true, true);
    }

    #[test]
    fn blas_row_major_2d_column_matrix() {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert_layout(m.view(), true, true);
    }

    #[test]
    fn blas_row_major_2d_transposed_row_matrix() {
        let m: Array2<f32> = Array2::zeros((1, 5));
        assert_layout(m.t(), true, true);
    }

    #[test]
    fn blas_row_major_2d_transposed_column_matrix() {
        let m: Array2<f32> = Array2::zeros((5, 1));
        assert_layout(m.t(), true, true);
    }

    #[test]
    fn blas_column_major_2d_normal_matrix() {
        let m: Array2<f32> = Array2::zeros((3, 5).f());
        assert_layout(m.view(), false, true);
    }
}