From fc5ed85c6e510cccf4dc4dde29fbadcf036628de Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 26 Jul 2023 18:58:49 +0200 Subject: [PATCH 1/6] Handle Statistics and SparseArrays as extensions --- Project.toml | 14 +++++- ext/FillArraysSparseArraysExt.jl | 57 +++++++++++++++++++++ ext/FillArraysStatisticsExt.jl | 36 ++++++++++++++ src/FillArrays.jl | 85 +++----------------------------- src/fillalgebra.jl | 1 - 5 files changed, 111 insertions(+), 82 deletions(-) create mode 100644 ext/FillArraysSparseArraysExt.jl create mode 100644 ext/FillArraysStatisticsExt.jl diff --git a/Project.toml b/Project.toml index ca1a18bd..05ec2c17 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.5.0" +version = "1.6.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -8,6 +8,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[extensions] +FillArraysSparseArraysExt = "SparseArrays" +FillArraysStatisticsExt = "Statistics" + [compat] Aqua = "0.5, 0.6" julia = "1.6" @@ -16,8 +24,10 @@ julia = "1.6" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "Base64", "ReverseDiff", "StaticArrays"] +test = ["Aqua", "Test", "Base64", "ReverseDiff", "SparseArrays", "StaticArrays", "Statistics"] diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl new file mode 100644 index 00000000..6804fd02 --- /dev/null +++ b/ext/FillArraysSparseArraysExt.jl @@ -0,0 +1,57 @@ +module FillArraysSparseArraysExt + +using SparseArrays +import Base: convert, kron +using FillArrays +using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix + +################## +## Sparse arrays +################## +SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z)) +SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z)) + +convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z)) +convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z)) + +SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...) +SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...) + +convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...) +convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...) + +convert(::Type{AbstractSparseArray}, Z::Zeros{T}) where T = spzeros(T, size(Z)...) +convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv, size(Z)...) +convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...) +convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...) + +SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) +# works around missing `speye`: +SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} = + convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...)) + +convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...) +convert(::Type{AbstractSparseMatrix{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) + +convert(::Type{AbstractSparseArray}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...) +convert(::Type{AbstractSparseArray{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) + + +convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} = + convert(SparseMatrixCSC{Tv,Ti}, Z) +convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} = + convert(SparseMatrixCSC{Tv,Ti}, Z) + +function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv} + SparseMatrixCSC{Tv,eltype(axes(R,1))}(R) +end +function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} + Base.require_one_based_indexing(R) + v = parent(R) + J = getindex_value(v)*I + SparseMatrixCSC{Tv,Ti}(J, size(R)) +end + +kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2)) + +end # module diff --git a/ext/FillArraysStatisticsExt.jl b/ext/FillArraysStatisticsExt.jl new file mode 100644 index 00000000..14249d75 --- /dev/null +++ b/ext/FillArraysStatisticsExt.jl @@ -0,0 +1,36 @@ +module FillArraysStatisticsExt + +import Statistics: mean, std, var, cov, cor + +using FillArrays +using FillArrays: AbstractFill, AbstractFillVector, AbstractFillMatrix + +######### +# mean, std +######### + +mean(A::AbstractFill; dims=(:)) = mean(identity, A; dims=dims) +function mean(f::Union{Function, Type}, A::AbstractFill; dims=(:)) + val = float(f(getindex_value(A))) + dims isa Colon ? val : + Fill(val, ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...) +end + + +function var(A::AbstractFill{T}; corrected::Bool=true, mean=nothing, dims=(:)) where {T<:Number} + dims isa Colon ? zero(float(T)) : + Zeros{float(T)}(ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...) +end + +cov(::AbstractFillVector{T}; corrected::Bool=true) where {T<:Number} = zero(float(T)) +cov(A::AbstractFillMatrix{T}; corrected::Bool=true, dims::Integer=1) where {T<:Number} = + Zeros{float(T)}(size(A, 3-dims), size(A, 3-dims)) + +cor(::AbstractFillVector{T}) where {T<:Number} = one(float(T)) +function cor(A::AbstractFillMatrix{T}; dims::Integer=1) where {T<:Number} + out = fill(float(T)(NaN), size(A, 3-dims), size(A, 3-dims)) + out[LinearAlgebra.diagind(out)] .= 1 + out +end + +end # module diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 0aff05d6..c7558cf2 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -1,7 +1,7 @@ """ `FillArrays` module to lazily represent matrices with a single value """ module FillArrays -using LinearAlgebra, SparseArrays, Statistics +using LinearAlgebra import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, @@ -16,9 +16,6 @@ import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted -import Statistics: mean, std, var, cov, cor - - export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement import Base: oneto @@ -542,52 +539,6 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal) end end -################## -## Sparse arrays -################## -SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z)) -SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z)) - -convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z)) -convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z)) - -SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...) -SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...) - -convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...) -convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...) - -convert(::Type{AbstractSparseArray}, Z::Zeros{T}) where T = spzeros(T, size(Z)...) -convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv, size(Z)...) -convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...) -convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...) - -SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) -# works around missing `speye`: -SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} = - convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...)) - -convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...) -convert(::Type{AbstractSparseMatrix{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) - -convert(::Type{AbstractSparseArray}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...) -convert(::Type{AbstractSparseArray{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...) - - -convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} = - convert(SparseMatrixCSC{Tv,Ti}, Z) -convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} = - convert(SparseMatrixCSC{Tv,Ti}, Z) - -function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv} - SparseMatrixCSC{Tv,eltype(axes(R,1))}(R) -end -function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} - Base.require_one_based_indexing(R) - v = parent(R) - J = getindex_value(v)*I - SparseMatrixCSC{Tv,Ti}(J, size(R)) -end ######### # maximum/minimum @@ -698,35 +649,6 @@ function in(x, A::RectDiagonal{<:Number}) x == zero(eltype(A)) || x in A.diag end -######### -# mean, std -######### - -mean(A::AbstractFill; dims=(:)) = mean(identity, A; dims=dims) -function mean(f::Union{Function, Type}, A::AbstractFill; dims=(:)) - val = float(f(getindex_value(A))) - dims isa Colon ? val : - Fill(val, ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...) -end - - -function var(A::AbstractFill{T}; corrected::Bool=true, mean=nothing, dims=(:)) where {T<:Number} - dims isa Colon ? zero(float(T)) : - Zeros{float(T)}(ntuple(d -> d in dims ? 1 : size(A,d), ndims(A))...) -end - -cov(A::AbstractFillVector{T}; corrected::Bool=true) where {T<:Number} = zero(float(T)) -cov(A::AbstractFillMatrix{T}; corrected::Bool=true, dims::Integer=1) where {T<:Number} = - Zeros{float(T)}(size(A, 3-dims), size(A, 3-dims)) - -cor(A::AbstractFillVector{T}) where {T<:Number} = one(float(T)) -function cor(A::AbstractFillMatrix{T}; dims::Integer=1) where {T<:Number} - out = fill(float(T)(NaN), size(A, 3-dims), size(A, 3-dims)) - out[LinearAlgebra.diagind(out)] .= 1 - out -end - - ######### # include ######### @@ -735,6 +657,11 @@ include("fillalgebra.jl") include("fillbroadcast.jl") include("trues.jl") +@static if !isdefined(Base, :get_extension) + include("../ext/FillArraysSparseArraysExt.jl") + include("../ext/FillArraysStatisticsExt.jl") +end + ## # print ## diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index f0f7d044..f28992f8 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -467,4 +467,3 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) sz = _kronsize(f, g) _kron(f, g, sz) end -kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2)) From 8d61af56e06f72dc08f528c9f94a86c9250c394d Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 26 Jul 2023 22:06:33 +0200 Subject: [PATCH 2/6] review comments --- ext/FillArraysSparseArraysExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl index 6804fd02..aa3d4ac6 100644 --- a/ext/FillArraysSparseArraysExt.jl +++ b/ext/FillArraysSparseArraysExt.jl @@ -3,7 +3,8 @@ module FillArraysSparseArraysExt using SparseArrays import Base: convert, kron using FillArrays -using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix +using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value +using LinearAlgebra ################## ## Sparse arrays From 209a1ba6b759e6d70672602381a4847765ead300 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 26 Jul 2023 22:24:17 +0200 Subject: [PATCH 3/6] clean-ip --- ext/FillArraysStatisticsExt.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ext/FillArraysStatisticsExt.jl b/ext/FillArraysStatisticsExt.jl index 14249d75..f39cad69 100644 --- a/ext/FillArraysStatisticsExt.jl +++ b/ext/FillArraysStatisticsExt.jl @@ -1,13 +1,10 @@ module FillArraysStatisticsExt -import Statistics: mean, std, var, cov, cor +import Statistics: mean, var, cov, cor +using LinearAlgebra: diagind using FillArrays -using FillArrays: AbstractFill, AbstractFillVector, AbstractFillMatrix - -######### -# mean, std -######### +using FillArrays: AbstractFill, AbstractFillVector, AbstractFillMatrix, getindex_value mean(A::AbstractFill; dims=(:)) = mean(identity, A; dims=dims) function mean(f::Union{Function, Type}, A::AbstractFill; dims=(:)) @@ -29,7 +26,7 @@ cov(A::AbstractFillMatrix{T}; corrected::Bool=true, dims::Integer=1) where {T<:N cor(::AbstractFillVector{T}) where {T<:Number} = one(float(T)) function cor(A::AbstractFillMatrix{T}; dims::Integer=1) where {T<:Number} out = fill(float(T)(NaN), size(A, 3-dims), size(A, 3-dims)) - out[LinearAlgebra.diagind(out)] .= 1 + out[diagind(out)] .= 1 out end From 3cc7063c2a0383a06c5a6ba41bd3593f71b53e43 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 26 Jul 2023 22:41:06 +0200 Subject: [PATCH 4/6] add deprecation --- ext/FillArraysSparseArraysExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl index aa3d4ac6..6a82f246 100644 --- a/ext/FillArraysSparseArraysExt.jl +++ b/ext/FillArraysSparseArraysExt.jl @@ -53,6 +53,7 @@ function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} SparseMatrixCSC{Tv,Ti}(J, size(R)) end -kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2)) +# TODO: remove in v2.0 +@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2)) end # module From 7c7a333d21eac8b95fbd959f3d5134abd5e2f327 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 31 Jul 2023 10:34:38 +0200 Subject: [PATCH 5/6] error or load --- ext/FillArraysSparseArraysExt.jl | 3 +-- src/fillalgebra.jl | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl index 6a82f246..9afd9204 100644 --- a/ext/FillArraysSparseArraysExt.jl +++ b/ext/FillArraysSparseArraysExt.jl @@ -53,7 +53,6 @@ function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} SparseMatrixCSC{Tv,Ti}(J, size(R)) end -# TODO: remove in v2.0 -@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2)) +kron_fill(A::RectDiagonalFill, B::RectDiagonalFill) = kron(sparse(A), sparse(B)) end # module diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index f28992f8..58b58e2e 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -467,3 +467,5 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) sz = _kronsize(f, g) _kron(f, g, sz) end +kron(A::RectDiagonalFill, B::RectDiagonalFill) = kron_fill(A, B) +kron_fill(A, B) = error("Please load SparseArrays.jl") From 6147c9ed49f207fbc534c9b4340a22ee34882da0 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 31 Jul 2023 12:52:28 +0200 Subject: [PATCH 6/6] Revert "error or load" This reverts commit 7c7a333d21eac8b95fbd959f3d5134abd5e2f327. --- ext/FillArraysSparseArraysExt.jl | 3 ++- src/fillalgebra.jl | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl index 9afd9204..6a82f246 100644 --- a/ext/FillArraysSparseArraysExt.jl +++ b/ext/FillArraysSparseArraysExt.jl @@ -53,6 +53,7 @@ function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} SparseMatrixCSC{Tv,Ti}(J, size(R)) end -kron_fill(A::RectDiagonalFill, B::RectDiagonalFill) = kron(sparse(A), sparse(B)) +# TODO: remove in v2.0 +@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2)) end # module diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 58b58e2e..f28992f8 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -467,5 +467,3 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) sz = _kronsize(f, g) _kron(f, g, sz) end -kron(A::RectDiagonalFill, B::RectDiagonalFill) = kron_fill(A, B) -kron_fill(A, B) = error("Please load SparseArrays.jl")