Skip to content

Commit

Permalink
Implement SparseLU factorization
Browse files Browse the repository at this point in the history
Add `solve_upper_triangular` to `CsrMatrix`

This allows a sparse matrix to be used for efficient solving with a dense LU decomposition.
  • Loading branch information
JulianKnodt committed Sep 11, 2023
1 parent f404bcb commit 0db2532
Show file tree
Hide file tree
Showing 9 changed files with 534 additions and 2 deletions.
151 changes: 150 additions & 1 deletion nalgebra-sparse/src/csc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::csr::CsrMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};

use nalgebra::Scalar;
use nalgebra::{RealField, Scalar};
use num_traits::One;
use std::slice::{Iter, IterMut};

Expand Down Expand Up @@ -553,6 +553,155 @@ impl<T> CscMatrix<T> {
self.filter(|i, j, _| i >= j)
}

/// Solves a lower triangular system, `self` is a matrix of NxN, and `b` is a column vector of size N
/// Assuming that b is dense.
// TODO add an option here for assuming diagonal is one.
pub fn dense_lower_triangular_solve(&self, b: &[T], out: &mut [T], unit_diagonal: bool)
where
T: RealField + Copy,
{
assert_eq!(self.nrows(), self.ncols());
assert_eq!(self.ncols(), b.len());
assert_eq!(out.len(), b.len());
out.copy_from_slice(b);
let n = b.len();

for i in 0..n {
let mul = out[i];
let col = self.col(i);
for (&ri, &v) in col.row_indices().iter().zip(col.values().iter()) {
// ensure that only using the lower part
if ri == i {
if !unit_diagonal {
out[ri] /= v;
}
} else if ri > i {
out[ri] -= v * mul;
}
}
}
}

/// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector
/// sparse.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle.
/// Assumes that the diagonal of the sparse matrix is all 1.
pub fn sparse_lower_triangular_solve(
&self,
b_idxs: &[usize],
b: &[T],
// idx -> row
// for now, is permitted to be unsorted
// TODO maybe would be better to enforce sorted, but would have to sort internally.
out_sparsity_pattern: &[usize],
out: &mut [T],
assume_unit: bool,
) where
T: RealField + Copy,
{
assert_eq!(self.nrows(), self.ncols());
assert_eq!(b.len(), b_idxs.len());
assert!(b_idxs.iter().all(|&bi| bi < self.ncols()));

assert_eq!(out_sparsity_pattern.len(), out.len());
assert!(out_sparsity_pattern.iter().all(|&i| i < self.ncols()));

let is_sorted = (0..out_sparsity_pattern.len() - 1)
.all(|i| out_sparsity_pattern[i] < out_sparsity_pattern[i + 1]);
if is_sorted {
return self.sparse_lower_triangular_solve_sorted(
b_idxs,
b,
out_sparsity_pattern,
out,
assume_unit,
);
}

// initialize out with b
for (&bv, &bi) in b.iter().zip(b_idxs.iter()) {
let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap();
out[out_pos] = bv;
}

for (i, &row) in out_sparsity_pattern.iter().enumerate() {
let col = self.col(row);
if !assume_unit {
if let Some(l_val) = col.get_entry(row) {
out[i] /= l_val.into_value();
} else {
// diagonal is 0, non-invertible
out[i] /= T::zero();
}
}
let mul = out[i];
for (ni, &nrow) in out_sparsity_pattern.iter().enumerate() {
if nrow <= row {
continue;
}
// TODO in a sorted version may be able to iterate without
// having the cost of binary search at each iteration
let l_val = if let Some(l_val) = col.get_entry(nrow) {
l_val.into_value()
} else {
continue;
};
out[ni] -= l_val * mul;
}
}
}
/// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector
/// sparse.
/// sparsity_idxs should be precomputed using the sparse_lower_triangle.
/// Assumes that the diagonal of the sparse matrix is all 1.
pub fn sparse_lower_triangular_solve_sorted(
&self,
b_idxs: &[usize],
b: &[T],
// idx -> row
// for now, is permitted to be unsorted
// TODO maybe would be better to enforce sorted, but would have to sort internally.
out_sparsity_pattern: &[usize],
out: &mut [T],
assume_unit: bool,
) where
T: RealField + Copy,
{
assert_eq!(self.nrows(), self.ncols());
assert_eq!(b.len(), b_idxs.len());
assert!(b_idxs.iter().all(|&bi| bi < self.ncols()));

assert_eq!(out_sparsity_pattern.len(), out.len());
assert!(out_sparsity_pattern.iter().all(|&i| i < self.ncols()));

// initialize out with b
// TODO can make this more efficient by keeping two iterators in sorted order
for (&bv, &bi) in b.iter().zip(b_idxs.iter()) {
let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap();
out[out_pos] = bv;
}

for (i, &row) in out_sparsity_pattern.iter().enumerate() {
let col = self.col(row);
let mut iter = col.row_indices().iter().zip(col.values().iter()).peekable();
if !assume_unit && let Some(l_val) = iter.find(|v| *v.0 >= row) && *l_val.0 == row{
out[i] /= *l_val.1;
}
let mul = out[i];
for (offset, &nrow) in out_sparsity_pattern[i..].iter().enumerate() {
if nrow <= row {
continue;
}
let l_val = if let Some(l_val) = iter.find(|v| *v.0 >= nrow) && *l_val.0 == nrow {
*l_val.1
} else {
break;
};
out[i + offset] -= l_val * mul;
}
}
}

/// Returns the diagonal of the matrix as a sparse matrix.
#[must_use]
pub fn diagonal_as_csc(&self) -> Self
Expand Down
51 changes: 50 additions & 1 deletion nalgebra-sparse/src/csr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};

use nalgebra::Scalar;
use nalgebra::{DMatrix, DMatrixView, RealField, Scalar};
use num_traits::One;

use std::slice::{Iter, IterMut};
Expand Down Expand Up @@ -573,6 +573,55 @@ impl<T> CsrMatrix<T> {
{
CscMatrix::from(self).transpose_as_csr()
}

/// Solves the equation `Ax = b`, treating `self` as an upper triangular matrix.
/// If `A` is not upper triangular, elements in the lower triangle below the diagonal
/// will be ignored.
///
/// If `m` has zeros along the diagonal, returns `None`.
/// Panics:
/// Panics if `A` and `b` have incompatible shapes, specifically if `b`
/// has a different number of rows and than `a`.
pub fn solve_upper_triangular<'a>(&self, b: impl Into<DMatrixView<'a, T>>) -> Option<DMatrix<T>>
where
T: RealField + Scalar,
{
// https://www.nicolasboumal.net/papers/MAT321_Lecture_notes_Boumal_2019.pdf
// page 48
let b: DMatrixView<'a, T> = b.into();
assert_eq!(b.nrows(), self.nrows());

let out_cols = b.ncols();
let out_rows = self.nrows();

let mut out = DMatrix::zeros(out_rows, out_cols);
for r in (0..out_rows).rev() {
let row = self.row(r);
// only take upper triangle elements
let mut row_iter = row
.col_indices()
.iter()
.copied()
.zip(row.values().iter())
.filter(|&(c, _)| c >= r);

let (c, div) = row_iter.next()?;
// This implies there is a 0 on the diagonal
if c != r || div.is_zero() {
return None;
}
for c in 0..out_cols {
let numer = b.index((r, c)).clone();
let numer = numer
- row_iter
.clone()
.map(|(a_col, val)| val.clone() * b.index((a_col, c)).clone())
.fold(T::zero(), |acc, n| acc + n);
*out.index_mut((r, c)) = numer / div.clone();
}
}
Some(out)
}
}

/// Convert pattern format errors into more meaningful CSR-specific errors.
Expand Down
16 changes: 16 additions & 0 deletions nalgebra-sparse/src/factorization/lu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::CscMatrix;
use nalgebra::RealField;

pub struct LeftLookingLUFactorization<T> {
/// A single matrix stores both the lower and upper triangular components
l_u: CscMatrix<T>
}

impl<T: RealField> LeftLookingLUFactorization<T> {
/// Construct a new sparse LU factorization
/// from a given CSC matrix.
pub fn new(a: &CscMatrix<T>) -> Self {
assert_eq!(a.nrows(), a.ncols());
todo!();
}
}
4 changes: 4 additions & 0 deletions nalgebra-sparse/src/factorization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
mod cholesky;

pub use cholesky::*;

//mod lu;

//pub use lu::*;
7 changes: 7 additions & 0 deletions nalgebra-sparse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,11 @@ impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
SparseEntryMut::Zero => T::zero(),
}
}
/// If the entry is nonzero, returns `Some(&mut value)`, otherwise returns `None`.
pub fn nonzero(self) -> Option<&'a mut T> {
match self {
SparseEntryMut::NonZero(v) => Some(v),
SparseEntryMut::Zero => None,
}
}
}
Loading

0 comments on commit 0db2532

Please sign in to comment.