Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_map -> Base.map #453

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.37"
version = "0.10.38"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -19,7 +19,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "1"
Expand All @@ -34,5 +33,4 @@ Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
2 changes: 0 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
const Distances_pairwise = Distances.pairwise
Expand Down Expand Up @@ -122,7 +121,6 @@ include("mokernels/intrinsiccoregion.jl")
include("mokernels/lmm.jl")

include("chainrules.jl")
include("zygoterules.jl")

include("test_utils.jl")

Expand Down
16 changes: 8 additions & 8 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,37 +82,37 @@ end
# Kernel matrix operations

function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x))
end

function kernelmatrix_diag!(
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x))
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix!(K, κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x))
end

function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix_diag(κ.kernel, map(κ.transform, x), map(κ.transform, y))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x))
return kernelmatrix(κ.kernel, map(κ.transform, x))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
return kernelmatrix(κ.kernel, map(κ.transform, x), map(κ.transform, y))
end
6 changes: 3 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ dim(t::ARDTransform) = length(t.v)
(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

Expand Down
2 changes: 1 addition & 1 deletion src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function _map(t::ChainTransform, x::AbstractVector)
function Base.map(t::ChainTransform, x::AbstractVector)
return foldl((x, t) -> map(t, x), t.transforms; init=x)
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ end

(t::FunctionTransform)(x) = t.f(x)

_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)

function _map(t::FunctionTransform, x::ColVecs)
function Base.map(t::FunctionTransform, x::ColVecs)
vals = map(axes(x.X, 2)) do i
t.f(view(x.X, :, i))
end
return ColVecs(reduce(hcat, vals))
end

function _map(t::FunctionTransform, x::RowVecs)
function Base.map(t::FunctionTransform, x::RowVecs)
vals = map(axes(x.X, 1)) do i
t.f(view(x.X, i, :))
end
Expand Down
6 changes: 3 additions & 3 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ end
(t::LinearTransform)(x::Real) = vec(t.A * x)
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x

_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')

function Base.show(io::IO, t::LinearTransform)
return print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
Expand Down
2 changes: 1 addition & 1 deletion src/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dim(t::PeriodicTransform) = 2

(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]

function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
function Base.map(t::PeriodicTransform, x::AbstractVector{<:Real})
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]

(t::ScaleTransform)(x) = only(t.s) * x

_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)

Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))

Expand Down
4 changes: 2 additions & 2 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ duplicate(t::SelectTransform, θ) = t
_maybe_unwrap(x) = x
_maybe_unwrap(x::AbstractArray{<:Any,0}) = x[]

_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
Base.map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
Base.map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)

_wrap(x::AbstractVector{<:Real}, ::Any) = x
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)
Expand Down
5 changes: 2 additions & 3 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ Abstract type defining a transformation of the input.
"""
abstract type Transform end

Base.map(t::Transform, x::AbstractVector) = _map(t, x)
_map(t::Transform, x::AbstractVector) = t.(x)
Base.map(t::Transform, x::AbstractVector) = t.(x)

"""
IdentityTransform()
Expand All @@ -16,7 +15,7 @@ Transformation that returns exactly the input.
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x
_map(::IdentityTransform, x::AbstractVector) = x
Base.map(::IdentityTransform, x::AbstractVector) = x

### TODO Maybe defining adjoints could help but so far it's not working

Expand Down
13 changes: 0 additions & 13 deletions src/zygoterules.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ include("test_utils.jl")

include("generic.jl")
include("chainrules.jl")
include("zygoterules.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
Expand Down
2 changes: 1 addition & 1 deletion test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
("ColVecs", ColVecs(randn(5, 10))),
("RowVecs", RowVecs(randn(11, 4))),
]
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
@test map(t, x) isa AbstractVector{Float64}
end
end
end
1 change: 0 additions & 1 deletion test/zygoterules.jl

This file was deleted.