Skip to content

Commit

Permalink
faster matrix-fillmatrix multiplication
Browse files Browse the repository at this point in the history
cleanup
  • Loading branch information
CarloLucibello committed Dec 4, 2020
1 parent 2915481 commit 2766105
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 41 deletions.
64 changes: 23 additions & 41 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ end

*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b)
*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b)
*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
Expand All @@ -95,66 +97,36 @@ end
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)

# Cannot unify following methods for Diagonal
# due to ambiguity with general array mult. with fill
function *(a::Diagonal, b::FillMatrix)
function *(a::Diagonal, b::AbstractFill{T,2}) where T
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
end
function *(a::FillMatrix, b::Diagonal)
function *(a::AbstractFill{T,2}, b::Diagonal) where T
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
end
function *(a::Diagonal, b::OnesMatrix)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
end
function *(a::OnesMatrix, b::Diagonal)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
end

*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))

function *(x::AbstractMatrix, f::FillMatrix)
function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T
axes(x, 2) axes(f, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 2)
repeat(sum(x, dims=2) * f.value, 1, m)
repeat(sum(x, dims=2) * getindex_value(f), 1, m)
end

function *(f::FillMatrix, x::AbstractMatrix)
function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T
axes(f, 2) axes(x, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 1)
repeat(sum(x, dims=1) * f.value, m, 1)
repeat(sum(x, dims=1) * getindex_value(f), m, 1)
end

function *(x::AbstractMatrix, f::OnesMatrix)
axes(x, 2) axes(f, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 2)
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
end
*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y)
*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y)

function *(f::OnesMatrix, x::AbstractMatrix)
axes(f, 2) axes(x, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 1)
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
end

*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)

### These methods are faster for small n #############
# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
# fB = similar(parent(a), size(b, 1), size(b, 2))
# fill!(fB, b.value)
Expand All @@ -173,6 +145,16 @@ end
# return a*fB
# end

## Matrix-Vector multiplication

*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T =
reshape(sum(a; dims=2) .* b.value, size(a, 1))


function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
la, lb = length(a), length(b)
if la lb
Expand Down
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,19 @@ end
@test E*(1:5) 1.0:5.0
@test (1:5)'E == (1.0:5)'
@test E*E E

# Adjoint / Transpose / Triangular / Symmetric
for x in [transpose(rand(2, 2)),
adjoint(rand(2,2)),
UpperTriangular(rand(2,2)),
Symmetric(rand(2,2))]
@test x * Ones(2, 2) isa Matrix
@test Ones(2, 2) * x isa Matrix
@test x * Zeros(2, 2) isa Zeros
@test Zeros(2, 2) * x isa Zeros
@test x * Fill(1., 2, 2) isa Matrix
@test Fill(1., 2, 2) * x isa Matrix
end
end

@testset "count" begin
Expand Down

0 comments on commit 2766105

Please sign in to comment.