1use num::Zero;
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use crate::allocator::{Allocator, Reallocator};
6use crate::base::{Const, DefaultAllocator, Matrix, OMatrix, OVector, Unit};
7use crate::constraint::{SameNumberOfRows, ShapeConstraint};
8use crate::dimension::{Dim, DimMin, DimMinimum};
9use crate::storage::StorageMut;
10use crate::ComplexField;
11
12use crate::geometry::Reflection;
13use crate::linalg::{householder, PermutationSequence};
14use std::mem::MaybeUninit;
15
16#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
18#[cfg_attr(
19 feature = "serde-serialize-no-std",
20 serde(bound(serialize = "DefaultAllocator: Allocator<T, R, C> +
21 Allocator<T, DimMinimum<R, C>>,
22 OMatrix<T, R, C>: Serialize,
23 PermutationSequence<DimMinimum<R, C>>: Serialize,
24 OVector<T, DimMinimum<R, C>>: Serialize"))
25)]
26#[cfg_attr(
27 feature = "serde-serialize-no-std",
28 serde(bound(deserialize = "DefaultAllocator: Allocator<T, R, C> +
29 Allocator<T, DimMinimum<R, C>>,
30 OMatrix<T, R, C>: Deserialize<'de>,
31 PermutationSequence<DimMinimum<R, C>>: Deserialize<'de>,
32 OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
33)]
34#[derive(Clone, Debug)]
35pub struct ColPivQR<T: ComplexField, R: DimMin<C>, C: Dim>
36where
37 DefaultAllocator: Allocator<T, R, C>
38 + Allocator<T, DimMinimum<R, C>>
39 + Allocator<(usize, usize), DimMinimum<R, C>>,
40{
41 col_piv_qr: OMatrix<T, R, C>,
42 p: PermutationSequence<DimMinimum<R, C>>,
43 diag: OVector<T, DimMinimum<R, C>>,
44}
45
46impl<T: ComplexField, R: DimMin<C>, C: Dim> Copy for ColPivQR<T, R, C>
47where
48 DefaultAllocator: Allocator<T, R, C>
49 + Allocator<T, DimMinimum<R, C>>
50 + Allocator<(usize, usize), DimMinimum<R, C>>,
51 OMatrix<T, R, C>: Copy,
52 PermutationSequence<DimMinimum<R, C>>: Copy,
53 OVector<T, DimMinimum<R, C>>: Copy,
54{
55}
56
57impl<T: ComplexField, R: DimMin<C>, C: Dim> ColPivQR<T, R, C>
58where
59 DefaultAllocator: Allocator<T, R, C>
60 + Allocator<T, R>
61 + Allocator<T, DimMinimum<R, C>>
62 + Allocator<(usize, usize), DimMinimum<R, C>>,
63{
64 pub fn new(mut matrix: OMatrix<T, R, C>) -> Self {
66 let (nrows, ncols) = matrix.shape_generic();
67 let min_nrows_ncols = nrows.min(ncols);
68 let mut p = PermutationSequence::identity_generic(min_nrows_ncols);
69
70 if min_nrows_ncols.value() == 0 {
71 return ColPivQR {
72 col_piv_qr: matrix,
73 p,
74 diag: Matrix::zeros_generic(min_nrows_ncols, Const::<1>),
75 };
76 }
77
78 let mut diag = Matrix::uninit(min_nrows_ncols, Const::<1>);
79
80 for i in 0..min_nrows_ncols.value() {
81 let piv = matrix.slice_range(i.., i..).icamax_full();
82 let col_piv = piv.1 + i;
83 matrix.swap_columns(i, col_piv);
84 p.append_permutation(i, col_piv);
85
86 diag[i] =
87 MaybeUninit::new(householder::clear_column_unchecked(&mut matrix, i, 0, None));
88 }
89
90 let diag = unsafe { diag.assume_init() };
92
93 ColPivQR {
94 col_piv_qr: matrix,
95 p,
96 diag,
97 }
98 }
99
100 #[inline]
102 #[must_use]
103 pub fn r(&self) -> OMatrix<T, DimMinimum<R, C>, C>
104 where
105 DefaultAllocator: Allocator<T, DimMinimum<R, C>, C>,
106 {
107 let (nrows, ncols) = self.col_piv_qr.shape_generic();
108 let mut res = self
109 .col_piv_qr
110 .rows_generic(0, nrows.min(ncols))
111 .upper_triangle();
112 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
113 res
114 }
115
116 #[inline]
120 pub fn unpack_r(self) -> OMatrix<T, DimMinimum<R, C>, C>
121 where
122 DefaultAllocator: Reallocator<T, R, C, DimMinimum<R, C>, C>,
123 {
124 let (nrows, ncols) = self.col_piv_qr.shape_generic();
125 let mut res = self
126 .col_piv_qr
127 .resize_generic(nrows.min(ncols), ncols, T::zero());
128 res.fill_lower_triangle(T::zero(), 1);
129 res.set_partial_diagonal(self.diag.iter().map(|e| T::from_real(e.clone().modulus())));
130 res
131 }
132
133 #[must_use]
135 pub fn q(&self) -> OMatrix<T, R, DimMinimum<R, C>>
136 where
137 DefaultAllocator: Allocator<T, R, DimMinimum<R, C>>,
138 {
139 let (nrows, ncols) = self.col_piv_qr.shape_generic();
140
141 let mut res = Matrix::identity_generic(nrows, nrows.min(ncols));
144 let dim = self.diag.len();
145
146 for i in (0..dim).rev() {
147 let axis = self.col_piv_qr.slice_range(i.., i);
148 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
150
151 let mut res_rows = res.slice_range_mut(i.., i..);
152 refl.reflect_with_sign(&mut res_rows, self.diag[i].clone().signum());
153 }
154
155 res
156 }
157 #[inline]
159 #[must_use]
160 pub fn p(&self) -> &PermutationSequence<DimMinimum<R, C>> {
161 &self.p
162 }
163
164 pub fn unpack(
166 self,
167 ) -> (
168 OMatrix<T, R, DimMinimum<R, C>>,
169 OMatrix<T, DimMinimum<R, C>, C>,
170 PermutationSequence<DimMinimum<R, C>>,
171 )
172 where
173 DimMinimum<R, C>: DimMin<C, Output = DimMinimum<R, C>>,
174 DefaultAllocator: Allocator<T, R, DimMinimum<R, C>>
175 + Reallocator<T, R, C, DimMinimum<R, C>, C>
176 + Allocator<(usize, usize), DimMinimum<R, C>>,
177 {
178 (self.q(), self.r(), self.p)
179 }
180
181 #[doc(hidden)]
182 pub fn col_piv_qr_internal(&self) -> &OMatrix<T, R, C> {
183 &self.col_piv_qr
184 }
185
186 pub fn q_tr_mul<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<T, R2, C2, S2>)
188 where
189 S2: StorageMut<T, R2, C2>,
190 {
191 let dim = self.diag.len();
192
193 for i in 0..dim {
194 let axis = self.col_piv_qr.slice_range(i.., i);
195 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
196
197 let mut rhs_rows = rhs.rows_range_mut(i..);
198 refl.reflect_with_sign(&mut rhs_rows, self.diag[i].clone().signum().conjugate());
199 }
200 }
201}
202
203impl<T: ComplexField, D: DimMin<D, Output = D>> ColPivQR<T, D, D>
204where
205 DefaultAllocator:
206 Allocator<T, D, D> + Allocator<T, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
207{
208 #[must_use = "Did you mean to use solve_mut()?"]
212 pub fn solve<R2: Dim, C2: Dim, S2>(
213 &self,
214 b: &Matrix<T, R2, C2, S2>,
215 ) -> Option<OMatrix<T, R2, C2>>
216 where
217 S2: StorageMut<T, R2, C2>,
218 ShapeConstraint: SameNumberOfRows<R2, D>,
219 DefaultAllocator: Allocator<T, R2, C2>,
220 {
221 let mut res = b.clone_owned();
222
223 if self.solve_mut(&mut res) {
224 Some(res)
225 } else {
226 None
227 }
228 }
229
230 pub fn solve_mut<R2: Dim, C2: Dim, S2>(&self, b: &mut Matrix<T, R2, C2, S2>) -> bool
235 where
236 S2: StorageMut<T, R2, C2>,
237 ShapeConstraint: SameNumberOfRows<R2, D>,
238 {
239 assert_eq!(
240 self.col_piv_qr.nrows(),
241 b.nrows(),
242 "ColPivQR solve matrix dimension mismatch."
243 );
244 assert!(
245 self.col_piv_qr.is_square(),
246 "ColPivQR solve: unable to solve a non-square system."
247 );
248
249 self.q_tr_mul(b);
250 let solved = self.solve_upper_triangular_mut(b);
251 self.p.inv_permute_rows(b);
252
253 solved
254 }
255
256 fn solve_upper_triangular_mut<R2: Dim, C2: Dim, S2>(
258 &self,
259 b: &mut Matrix<T, R2, C2, S2>,
260 ) -> bool
261 where
262 S2: StorageMut<T, R2, C2>,
263 ShapeConstraint: SameNumberOfRows<R2, D>,
264 {
265 let dim = self.col_piv_qr.nrows();
266
267 for k in 0..b.ncols() {
268 let mut b = b.column_mut(k);
269 for i in (0..dim).rev() {
270 let coeff;
271
272 unsafe {
273 let diag = self.diag.vget_unchecked(i).clone().modulus();
274
275 if diag.is_zero() {
276 return false;
277 }
278
279 coeff = b.vget_unchecked(i).clone().unscale(diag);
280 *b.vget_unchecked_mut(i) = coeff.clone();
281 }
282
283 b.rows_range_mut(..i)
284 .axpy(-coeff, &self.col_piv_qr.slice_range(..i, i), T::one());
285 }
286 }
287
288 true
289 }
290
291 #[must_use]
295 pub fn try_inverse(&self) -> Option<OMatrix<T, D, D>> {
296 assert!(
297 self.col_piv_qr.is_square(),
298 "ColPivQR inverse: unable to compute the inverse of a non-square matrix."
299 );
300
301 let (nrows, ncols) = self.col_piv_qr.shape_generic();
303 let mut res = OMatrix::identity_generic(nrows, ncols);
304
305 if self.solve_mut(&mut res) {
306 Some(res)
307 } else {
308 None
309 }
310 }
311
312 #[must_use]
314 pub fn is_invertible(&self) -> bool {
315 assert!(
316 self.col_piv_qr.is_square(),
317 "ColPivQR: unable to test the invertibility of a non-square matrix."
318 );
319
320 for i in 0..self.diag.len() {
321 if self.diag[i].is_zero() {
322 return false;
323 }
324 }
325
326 true
327 }
328
329 #[must_use]
331 pub fn determinant(&self) -> T {
332 let dim = self.col_piv_qr.nrows();
333 assert!(
334 self.col_piv_qr.is_square(),
335 "ColPivQR determinant: unable to compute the determinant of a non-square matrix."
336 );
337
338 let mut res = T::one();
339 for i in 0..dim {
340 res *= unsafe { self.diag.vget_unchecked(i).clone() };
341 }
342
343 res * self.p.determinant()
344 }
345}