nalgebra/linalg/
bidiagonal.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use crate::allocator::Allocator;
5use crate::base::{DefaultAllocator, Matrix, OMatrix, OVector, Unit};
6use crate::dimension::{Const, Dim, DimDiff, DimMin, DimMinimum, DimSub, U1};
7use simba::scalar::ComplexField;
8
9use crate::geometry::Reflection;
10use crate::linalg::householder;
11use std::mem::MaybeUninit;
12
13/// The bidiagonalization of a general matrix.
14#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
15#[cfg_attr(
16    feature = "serde-serialize-no-std",
17    serde(bound(serialize = "DimMinimum<R, C>: DimSub<U1>,
18         DefaultAllocator: Allocator<T, R, C>             +
19                           Allocator<T, DimMinimum<R, C>> +
20                           Allocator<T, DimDiff<DimMinimum<R, C>, U1>>,
21         OMatrix<T, R, C>: Serialize,
22         OVector<T, DimMinimum<R, C>>: Serialize,
23         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Serialize"))
24)]
25#[cfg_attr(
26    feature = "serde-serialize-no-std",
27    serde(bound(deserialize = "DimMinimum<R, C>: DimSub<U1>,
28         DefaultAllocator: Allocator<T, R, C>             +
29                           Allocator<T, DimMinimum<R, C>> +
30                           Allocator<T, DimDiff<DimMinimum<R, C>, U1>>,
31         OMatrix<T, R, C>: Deserialize<'de>,
32         OVector<T, DimMinimum<R, C>>: Deserialize<'de>,
33         OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Deserialize<'de>"))
34)]
35#[derive(Clone, Debug)]
36pub struct Bidiagonal<T: ComplexField, R: DimMin<C>, C: Dim>
37where
38    DimMinimum<R, C>: DimSub<U1>,
39    DefaultAllocator: Allocator<T, R, C>
40        + Allocator<T, DimMinimum<R, C>>
41        + Allocator<T, DimDiff<DimMinimum<R, C>, U1>>,
42{
43    // TODO: perhaps we should pack the axes into different vectors so that axes for `v_t` are
44    // contiguous. This prevents some useless copies.
45    uv: OMatrix<T, R, C>,
46    /// The diagonal elements of the decomposed matrix.
47    diagonal: OVector<T, DimMinimum<R, C>>,
48    /// The off-diagonal elements of the decomposed matrix.
49    off_diagonal: OVector<T, DimDiff<DimMinimum<R, C>, U1>>,
50    upper_diagonal: bool,
51}
52
53impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for Bidiagonal<T, R, C>
54where
55    DimMinimum<R, C>: DimSub<U1>,
56    DefaultAllocator: Allocator<T, R, C>
57        + Allocator<T, DimMinimum<R, C>>
58        + Allocator<T, DimDiff<DimMinimum<R, C>, U1>>,
59    OMatrix<T, R, C>: Copy,
60    OVector<T, DimMinimum<R, C>>: Copy,
61    OVector<T, DimDiff<DimMinimum<R, C>, U1>>: Copy,
62{
63}
64
65impl<T: ComplexField, R: DimMin<C>, C: Dim> Bidiagonal<T, R, C>
66where
67    DimMinimum<R, C>: DimSub<U1>,
68    DefaultAllocator: Allocator<T, R, C>
69        + Allocator<T, C>
70        + Allocator<T, R>
71        + Allocator<T, DimMinimum<R, C>>
72        + Allocator<T, DimDiff<DimMinimum<R, C>, U1>>,
73{
74    /// Computes the Bidiagonal decomposition using householder reflections.
75    pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
76        let (nrows, ncols) = matrix.shape_generic();
77        let min_nrows_ncols = nrows.min(ncols);
78        let dim = min_nrows_ncols.value();
79        assert!(
80            dim != 0,
81            "Cannot compute the bidiagonalization of an empty matrix."
82        );
83
84        let mut diagonal = Matrix::uninit(min_nrows_ncols, Const::<1>);
85        let mut off_diagonal = Matrix::uninit(min_nrows_ncols.sub(Const::<1>), Const::<1>);
86        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
87        let mut work = Matrix::zeros_generic(nrows, Const::<1>);
88
89        let upper_diagonal = nrows.value() >= ncols.value();
90        if upper_diagonal {
91            for ite in 0..dim - 1 {
92                diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
93                    &mut matrix,
94                    ite,
95                    0,
96                    None,
97                ));
98                off_diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
99                    &mut matrix,
100                    &mut axis_packed,
101                    &mut work,
102                    ite,
103                    1,
104                ));
105            }
106
107            diagonal[dim - 1] = MaybeUninit::new(householder::clear_column_unchecked(
108                &mut matrix,
109                dim - 1,
110                0,
111                None,
112            ));
113        } else {
114            for ite in 0..dim - 1 {
115                diagonal[ite] = MaybeUninit::new(householder::clear_row_unchecked(
116                    &mut matrix,
117                    &mut axis_packed,
118                    &mut work,
119                    ite,
120                    0,
121                ));
122                off_diagonal[ite] = MaybeUninit::new(householder::clear_column_unchecked(
123                    &mut matrix,
124                    ite,
125                    1,
126                    None,
127                ));
128            }
129
130            diagonal[dim - 1] = MaybeUninit::new(householder::clear_row_unchecked(
131                &mut matrix,
132                &mut axis_packed,
133                &mut work,
134                dim - 1,
135                0,
136            ));
137        }
138
139        // Safety: diagonal and off_diagonal have been fully initialized.
140        let (diagonal, off_diagonal) =
141            unsafe { (diagonal.assume_init(), off_diagonal.assume_init()) };
142
143        Bidiagonal {
144            uv: matrix,
145            diagonal,
146            off_diagonal,
147            upper_diagonal,
148        }
149    }
150
151    /// Indicates whether this decomposition contains an upper-diagonal matrix.
152    #[inline]
153    #[must_use]
154    pub fn is_upper_diagonal(&self) -> bool {
155        self.upper_diagonal
156    }
157
158    #[inline]
159    fn axis_shift(&self) -> (usize, usize) {
160        if self.upper_diagonal {
161            (0, 1)
162        } else {
163            (1, 0)
164        }
165    }
166
167    /// Unpacks this decomposition into its three matrix factors `(U, D, V^t)`.
168    ///
169    /// The decomposed matrix `M` is equal to `U * D * V^t`.
170    #[inline]
171    pub fn unpack(
172        self,
173    ) -> (
174        OMatrix<T, R, DimMinimum<R, C>>,
175        OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>,
176        OMatrix<T, DimMinimum<R, C>, C>,
177    )
178    where
179        DefaultAllocator: Allocator<T, DimMinimum<R, C>, DimMinimum<R, C>>
180            + Allocator<T, R, DimMinimum<R, C>>
181            + Allocator<T, DimMinimum<R, C>, C>,
182    {
183        // TODO: optimize by calling a reallocator.
184        (self.u(), self.d(), self.v_t())
185    }
186
187    /// Retrieves the upper trapezoidal submatrix `R` of this decomposition.
188    #[inline]
189    #[must_use]
190    pub fn d(&self) -> OMatrix<T, DimMinimum<R, C>, DimMinimum<R, C>>
191    where
192        DefaultAllocator: Allocator<T, DimMinimum<R, C>, DimMinimum<R, C>>,
193    {
194        let (nrows, ncols) = self.uv.shape_generic();
195
196        let d = nrows.min(ncols);
197        let mut res = OMatrix::identity_generic(d, d);
198        res.set_partial_diagonal(
199            self.diagonal
200                .iter()
201                .map(|e| T::from_real(e.clone().modulus())),
202        );
203
204        let start = self.axis_shift();
205        res.slice_mut(start, (d.value() - 1, d.value() - 1))
206            .set_partial_diagonal(
207                self.off_diagonal
208                    .iter()
209                    .map(|e| T::from_real(e.clone().modulus())),
210            );
211        res
212    }
213
214    /// Computes the orthogonal matrix `U` of this `U * D * V` decomposition.
215    // TODO: code duplication with householder::assemble_q.
216    // Except that we are returning a rectangular matrix here.
217    #[must_use]
218    pub fn u(&self) -> OMatrix<T, R, DimMinimum<R, C>>
219    where
220        DefaultAllocator: Allocator<T, R, DimMinimum<R, C>>,
221    {
222        let (nrows, ncols) = self.uv.shape_generic();
223
224        let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
225        let dim = self.diagonal.len();
226        let shift = self.axis_shift().0;
227
228        for i in (0..dim - shift).rev() {
229            let axis = self.uv.slice_range(i + shift.., i);
230            // TODO: sometimes, the axis might have a zero magnitude.
231            let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
232
233            let mut res_rows = res.slice_range_mut(i + shift.., i..);
234
235            let sign = if self.upper_diagonal {
236                self.diagonal[i].clone().signum()
237            } else {
238                self.off_diagonal[i].clone().signum()
239            };
240
241            refl.reflect_with_sign(&mut res_rows, sign);
242        }
243
244        res
245    }
246
247    /// Computes the orthogonal matrix `V_t` of this `U * D * V_t` decomposition.
248    #[must_use]
249    pub fn v_t(&self) -> OMatrix<T, DimMinimum<R, C>, C>
250    where
251        DefaultAllocator: Allocator<T, DimMinimum<R, C>, C>,
252    {
253        let (nrows, ncols) = self.uv.shape_generic();
254        let min_nrows_ncols = nrows.min(ncols);
255
256        let mut res = Matrix::identity_generic(min_nrows_ncols, ncols);
257        let mut work = Matrix::zeros_generic(min_nrows_ncols, Const::<1>);
258        let mut axis_packed = Matrix::zeros_generic(ncols, Const::<1>);
259
260        let shift = self.axis_shift().1;
261
262        for i in (0..min_nrows_ncols.value() - shift).rev() {
263            let axis = self.uv.slice_range(i, i + shift..);
264            let mut axis_packed = axis_packed.rows_range_mut(i + shift..);
265            axis_packed.tr_copy_from(&axis);
266            // TODO: sometimes, the axis might have a zero magnitude.
267            let refl = Reflection::new(Unit::new_unchecked(axis_packed), T::zero());
268
269            let mut res_rows = res.slice_range_mut(i.., i + shift..);
270
271            let sign = if self.upper_diagonal {
272                self.off_diagonal[i].clone().signum()
273            } else {
274                self.diagonal[i].clone().signum()
275            };
276
277            refl.reflect_rows_with_sign(&mut res_rows, &mut work.rows_range_mut(i..), sign);
278        }
279
280        res
281    }
282
283    /// The diagonal part of this decomposed matrix.
284    #[must_use]
285    pub fn diagonal(&self) -> OVector<T::RealField, DimMinimum<R, C>>
286    where
287        DefaultAllocator: Allocator<T::RealField, DimMinimum<R, C>>,
288    {
289        self.diagonal.map(|e| e.modulus())
290    }
291
292    /// The off-diagonal part of this decomposed matrix.
293    #[must_use]
294    pub fn off_diagonal(&self) -> OVector<T::RealField, DimDiff<DimMinimum<R, C>, U1>>
295    where
296        DefaultAllocator: Allocator<T::RealField, DimDiff<DimMinimum<R, C>, U1>>,
297    {
298        self.off_diagonal.map(|e| e.modulus())
299    }
300
301    #[doc(hidden)]
302    pub fn uv_internal(&self) -> &OMatrix<T, R, C> {
303        &self.uv
304    }
305}
306
307// impl<T: ComplexField, D: DimMin<D, Output = D> + DimSub<Dynamic>> Bidiagonal<T, D, D>
308//     where DefaultAllocator: Allocator<T, D, D> +
309//                             Allocator<T, D> {
310//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
311//     pub fn solve<R2: Dim, C2: Dim, S2>(&self, b: &Matrix<T, R2, C2, S2>) -> OMatrix<T, R2, C2>
312//         where S2: StorageMut<T, R2, C2>,
313//               ShapeConstraint: SameNumberOfRows<R2, D> {
314//         let mut res = b.clone_owned();
315//         self.solve_mut(&mut res);
316//         res
317//     }
318//
319//     /// Solves the linear system `self * x = b`, where `x` is the unknown to be determined.
320//     pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
321//         where S2: StorageMut<T, R2, C2>,
322//               ShapeConstraint: SameNumberOfRows<R2, D> {
323//
324//         assert_eq!(self.uv.nrows(), b.nrows(), "Bidiagonal solve matrix dimension mismatch.");
325//         assert!(self.uv.is_square(), "Bidiagonal solve: unable to solve a non-square system.");
326//
327//         self.q_tr_mul(b);
328//         self.solve_upper_triangular_mut(b);
329//     }
330//
331//     // TODO: duplicate code from the `solve` module.
332//     fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>)
333//         where S2: StorageMut<T, R2, C2>,
334//               ShapeConstraint: SameNumberOfRows<R2, D> {
335//
336//         let dim  = self.uv.nrows();
337//
338//         for k in 0 .. b.ncols() {
339//             let mut b = b.column_mut(k);
340//             for i in (0 .. dim).rev() {
341//                 let coeff;
342//
343//                 unsafe {
344//                     let diag = *self.diag.vget_unchecked(i);
345//                     coeff = *b.vget_unchecked(i) / diag;
346//                     *b.vget_unchecked_mut(i) = coeff;
347//                 }
348//
349//                 b.rows_range_mut(.. i).axpy(-coeff, &self.uv.slice_range(.. i, i), T::one());
350//             }
351//         }
352//     }
353//
354//     /// Computes the inverse of the decomposed matrix.
355//     pub fn inverse(&self) -> OMatrix<T, D, D> {
356//         assert!(self.uv.is_square(), "Bidiagonal inverse: unable to compute the inverse of a non-square matrix.");
357//
358//         // TODO: is there a less naive method ?
359//         let (nrows, ncols) = self.uv.shape_generic();
360//         let mut res = OMatrix::identity_generic(nrows, ncols);
361//         self.solve_mut(&mut res);
362//         res
363//     }
364//
365//     // /// Computes the determinant of the decomposed matrix.
366//     // pub fn determinant(&self) -> T {
367//     //     let dim = self.uv.nrows();
368//     //     assert!(self.uv.is_square(), "Bidiagonal determinant: unable to compute the determinant of a non-square matrix.");
369//
370//     //     let mut res = T::one();
371//     //     for i in 0 .. dim {
372//     //         res *= unsafe { *self.diag.vget_unchecked(i) };
373//     //     }
374//
375//     //     res self.q_determinant()
376//     // }
377// }