diff --git a/Project.toml b/Project.toml index 47ee0d086..c15432dd7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.37" +version = "0.10.38" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -19,7 +19,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ChainRulesCore = "1" @@ -34,5 +33,4 @@ Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.32, 0.33" TensorCore = "0.1" -ZygoteRules = "0.2" julia = "1.3" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 0211340e6..39a4a1385 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -60,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2 using LogExpFunctions: softplus using StatsBase using TensorCore -using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield # Hack to work around Zygote type inference problems. const Distances_pairwise = Distances.pairwise @@ -122,7 +121,6 @@ include("mokernels/intrinsiccoregion.jl") include("mokernels/lmm.jl") include("chainrules.jl") -include("zygoterules.jl") include("test_utils.jl") diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 88e719ef1..87fab6493 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -82,37 +82,37 @@ end # Kernel matrix operations function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector) - return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x)) + return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x)) end function kernelmatrix_diag!( K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector ) - return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y)) end function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) - return kernelmatrix!(K, κ.kernel, _map(κ.transform, x)) + return kernelmatrix!(K, κ.kernel, map(κ.transform, x)) end function kernelmatrix!( K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector ) - return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y)) end function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector) - return kernelmatrix_diag(κ.kernel, _map(κ.transform, x)) + return kernelmatrix_diag(κ.kernel, map(κ.transform, x)) end function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) - return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix_diag(κ.kernel, map(κ.transform, x), map(κ.transform, y)) end function kernelmatrix(κ::TransformedKernel, x::AbstractVector) - return kernelmatrix(κ.kernel, _map(κ.transform, x)) + return kernelmatrix(κ.kernel, map(κ.transform, x)) end function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) - return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix(κ.kernel, map(κ.transform, x), map(κ.transform, y)) end diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 726d940ad..ff1ac3027 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -35,9 +35,9 @@ dim(t::ARDTransform) = length(t.v) (t::ARDTransform)(x::Real) = only(t.v) * x (t::ARDTransform)(x) = t.v .* x -_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x -_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X) -_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X) +Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x +Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X) +Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X) Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v) diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index bd4627b19..9fa0e4570 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -39,7 +39,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor (t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x) -function _map(t::ChainTransform, x::AbstractVector) +function Base.map(t::ChainTransform, x::AbstractVector) return foldl((x, t) -> map(t, x), t.transforms; init=x) end diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index 53dc0b28d..df53d2991 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -23,16 +23,16 @@ end (t::FunctionTransform)(x) = t.f(x) -_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x) +Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x) -function _map(t::FunctionTransform, x::ColVecs) +function Base.map(t::FunctionTransform, x::ColVecs) vals = map(axes(x.X, 2)) do i t.f(view(x.X, :, i)) end return ColVecs(reduce(hcat, vals)) end -function _map(t::FunctionTransform, x::RowVecs) +function Base.map(t::FunctionTransform, x::RowVecs) vals = map(axes(x.X, 1)) do i t.f(view(x.X, i, :)) end diff --git a/src/transform/lineartransform.jl b/src/transform/lineartransform.jl index b61ba6a94..52ac39740 100644 --- a/src/transform/lineartransform.jl +++ b/src/transform/lineartransform.jl @@ -34,9 +34,9 @@ end (t::LinearTransform)(x::Real) = vec(t.A * x) (t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x -_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x')) -_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X) -_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A') +Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x')) +Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X) +Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A') function Base.show(io::IO, t::LinearTransform) return print(io::IO, "Linear transform (size(A) = ", size(t.A), ")") diff --git a/src/transform/periodic_transform.jl b/src/transform/periodic_transform.jl index 098262309..53dc435bf 100644 --- a/src/transform/periodic_transform.jl +++ b/src/transform/periodic_transform.jl @@ -27,7 +27,7 @@ dim(t::PeriodicTransform) = 2 (t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)] -function _map(t::PeriodicTransform, x::AbstractVector{<:Real}) +function Base.map(t::PeriodicTransform, x::AbstractVector{<:Real}) return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x))) end diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index 18923fcc4..c6c62d982 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -26,9 +26,9 @@ set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ] (t::ScaleTransform)(x) = only(t.s) * x -_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x -_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X) -_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X) +Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x +Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X) +Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X) Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s)) diff --git a/src/transform/selecttransform.jl b/src/transform/selecttransform.jl index 9c83daedc..9ccb6f9b8 100644 --- a/src/transform/selecttransform.jl +++ b/src/transform/selecttransform.jl @@ -25,8 +25,8 @@ duplicate(t::SelectTransform, θ) = t _maybe_unwrap(x) = x _maybe_unwrap(x::AbstractArray{<:Any,0}) = x[] -_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs) -_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs) +Base.map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs) +Base.map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs) _wrap(x::AbstractVector{<:Real}, ::Any) = x _wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X) diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 40ce8c058..53ed7ed3f 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -5,8 +5,7 @@ Abstract type defining a transformation of the input. """ abstract type Transform end -Base.map(t::Transform, x::AbstractVector) = _map(t, x) -_map(t::Transform, x::AbstractVector) = t.(x) +Base.map(t::Transform, x::AbstractVector) = t.(x) """ IdentityTransform() @@ -16,7 +15,7 @@ Transformation that returns exactly the input. struct IdentityTransform <: Transform end (t::IdentityTransform)(x) = x -_map(::IdentityTransform, x::AbstractVector) = x +Base.map(::IdentityTransform, x::AbstractVector) = x ### TODO Maybe defining adjoints could help but so far it's not working diff --git a/src/zygoterules.jl b/src/zygoterules.jl deleted file mode 100644 index e405a4946..000000000 --- a/src/zygoterules.jl +++ /dev/null @@ -1,13 +0,0 @@ -ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) - return ZygoteRules.pullback(_map, t, X) -end - -ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) - return ZygoteRules.pullback(_map, t, X) -end - -function ZygoteRules._pullback( - cx::AContext, ::typeof(literal_getproperty), x::ColVecs, ::Val{f} -) where {f} - return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}()) -end diff --git a/test/runtests.jl b/test/runtests.jl index a1f5c395a..30ee3d71d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -168,7 +168,6 @@ include("test_utils.jl") include("generic.jl") include("chainrules.jl") - include("zygoterules.jl") @testset "doctests" begin DocMeta.setdocmeta!( diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index c9888443c..1f10976ae 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -117,7 +117,7 @@ ("ColVecs", ColVecs(randn(5, 10))), ("RowVecs", RowVecs(randn(11, 4))), ] - @test KernelFunctions._map(t, x) isa AbstractVector{Float64} + @test map(t, x) isa AbstractVector{Float64} end end end diff --git a/test/zygoterules.jl b/test/zygoterules.jl deleted file mode 100644 index dc3bb98fe..000000000 --- a/test/zygoterules.jl +++ /dev/null @@ -1 +0,0 @@ -@testset "zygoterules" begin end