Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proof of concept for mul_tr #1387

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion src/base/alias_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,14 @@ pub type MatrixViewXx6<'a, T, RStride = U1, CStride = Dyn> =
pub type VectorView<'a, T, D, RStride = U1, CStride = D> =
Matrix<T, D, U1, ViewStorage<'a, T, D, U1, RStride, CStride>>;

/// An immutable row vector view with dimensions known at compile-time.
///
///
///
/// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.**
pub type RowVectorView<'a, T, D, RStride = D, CStride = U1> =
Matrix<T, U1, D, ViewStorage<'a, T, U1, D, RStride, CStride>>;

/// An immutable column vector view with dimensions known at compile-time.
///
/// See [`SVectorViewMut`] for a mutable version of this type.
Expand Down Expand Up @@ -806,7 +814,6 @@ pub type MatrixViewMutXx5<'a, T, RStride = U1, CStride = Dyn> =
/// **Because this is an alias, not all its methods are listed here. See the [`Matrix`](crate::base::Matrix) type too.**
pub type MatrixViewMutXx6<'a, T, RStride = U1, CStride = Dyn> =
Matrix<T, Dyn, U6, ViewStorageMut<'a, T, Dyn, U6, RStride, CStride>>;

/// A mutable column vector view with dimensions known at compile-time.
///
/// See [`VectorView`] for an immutable version of this type.
Expand Down
67 changes: 66 additions & 1 deletion src/base/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use crate::base::constraint::{
use crate::base::dimension::{Dim, DimMul, DimName, DimProd, Dyn};
use crate::base::storage::{Storage, StorageMut};
use crate::base::uninit::Uninit;
use crate::base::{DefaultAllocator, Matrix, MatrixSum, OMatrix, Scalar, VectorView};
use crate::base::{
DefaultAllocator, Matrix, MatrixSum, OMatrix, RowVectorView, Scalar, VectorView,
};
use crate::storage::IsContiguous;
use crate::uninit::{Init, InitStatus};
use crate::{RawStorage, RawStorageMut, SimdComplexField};
Expand Down Expand Up @@ -680,6 +682,21 @@ where
unsafe { res.assume_init() }
}

#[inline]
#[must_use]
/// Equivalent to `self * rhs.transpose()`.
pub fn mul_tr<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> OMatrix<T, R1, R2>
where
SB: Storage<T, R2, C2>,
DefaultAllocator: Allocator<T, R1, R2>,
ShapeConstraint: SameNumberOfColumns<C1, C2>,
{
let mut res = Matrix::uninit(self.shape_generic().0, rhs.shape_generic().0);
self.yy_mul_to_uninit(Uninit, rhs, &mut res, |a, b| a.dot(b));
// SAFETY: this is OK because the result is now initialized.
unsafe { res.assume_init() }
}

/// Equivalent to `self.adjoint() * rhs`.
#[inline]
#[must_use]
Expand Down Expand Up @@ -744,6 +761,54 @@ where
}
}

#[inline(always)]
fn yy_mul_to_uninit<Status, R2: Dim, C2: Dim, SB, R3: Dim, C3: Dim, SC>(
&self,
_status: Status,
rhs: &Matrix<T, R2, C2, SB>,
out: &mut Matrix<Status::Value, R3, C3, SC>,
dot: impl Fn(
&RowVectorView<'_, T, C1, SA::RStride, SA::CStride>,
&RowVectorView<'_, T, C2, SB::RStride, SB::CStride>,
) -> T,
) where
Status: InitStatus<T>,
SB: RawStorage<T, R2, C2>,
SC: RawStorageMut<Status::Value, R3, C3>,
ShapeConstraint: SameNumberOfColumns<C1, C2> + DimEq<R1, R3> + DimEq<R2, C3>,
{
let (nrows1, ncols1) = self.shape();
let (nrows2, ncols2) = rhs.shape();
let (nrows3, ncols3) = out.shape();

assert!(
ncols1 == ncols2,
"Matrix multiplication dimensions mismatch {:?} and {:?}: left cols != right cols.",
self.shape(),
rhs.shape()
);
assert!(
nrows1 == nrows3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right rows.",
self.shape(),
out.shape()
);
assert!(
nrows2 == ncols3,
"Matrix multiplication output dimensions mismatch {:?} and {:?}: left rows != right cols",
rhs.shape(),
out.shape()
);

for i in 0..nrows1 {
for j in 0..nrows2 {
let dot = dot(&self.row(i), &rhs.row(j));
let elt = unsafe { out.get_unchecked_mut((i, j)) };
Status::init(elt, dot)
}
}
}

/// Equivalent to `self.transpose() * rhs` but stores the result into `out` to avoid
/// allocations.
#[inline]
Expand Down
5 changes: 5 additions & 0 deletions tests/core/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,11 @@ mod transposition_tests {
fn tr_mul_is_transpose_then_mul(m in matrix(PROPTEST_F64, Const::<4>, Const::<6>), v in vector4()) {
prop_assert!(relative_eq!(m.transpose() * v, m.tr_mul(&v), epsilon = 1.0e-7))
}

#[test]
fn mul_tr_is_transpose_rhs_then_mul(m in matrix(PROPTEST_F64, Const::<6>, Const::<4>), v in vector4()) {
prop_assert!(relative_eq!(m * v, m.mul_tr(&v.transpose()), epsilon = 1.0e-7))
}
}
}

Expand Down
Loading