Skip to content

Commit

Permalink
Add ProjectTo(::Any) = identity (#458)
Browse files Browse the repository at this point in the history
* add ProjectTo(::Any) = identity

* Apply 3 suggestions

Co-authored-by: Lyndon White <[email protected]>

Co-authored-by: Lyndon White <[email protected]>
  • Loading branch information
mcabbott and oxinabox committed Sep 21, 2021
1 parent 1893e82 commit 0e560c6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
13 changes: 10 additions & 3 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ _maybe_call(f, x) = f
Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the
relevant tangent space for `x`.
At present this undersands only `x::Number`, `x::AbstractArray` and `x::Ref`.
It should not be called on arguments of an `rrule` method which accepts other types.
Custom `ProjectTo` methods are provided for many subtypes of `Number` (to e.g. ensure precision),
and `AbstractArray` (to e.g. ensure sparsity structure is maintained by tangent).
Called on unknown types it will (as of v1.5.0) simply return `identity`, thus can be safely
applied to arbitrary `rrule` arguments.
# Examples
```jldoctest
Expand Down Expand Up @@ -112,7 +114,7 @@ julia> ProjectTo([1 2; 3 4]') # no special structure, integers are promoted to
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2), Base.OneTo(2)))
```
"""
ProjectTo(::Any) # just to attach docstring
ProjectTo(::Any) = identity

# Generic
(::ProjectTo{T})(dx::AbstractZero) where {T} = dx
Expand Down Expand Up @@ -143,6 +145,11 @@ ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = Project
# Bool
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above

# Other never-differentiable types
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
end

# Numbers
ProjectTo(::Real) = ProjectTo{Real}()
ProjectTo(::Complex) = ProjectTo{Complex}()
Expand Down
21 changes: 17 additions & 4 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Base.real(x::Dual) = x
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))

# Trivial struct
struct NoSuperType end

@testset "projection" begin

#####
Expand All @@ -24,7 +27,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
@test ProjectTo(2.0)(1+1im) === 1.0


# storage
@test ProjectTo(1)(pi) === pi
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
Expand Down Expand Up @@ -94,9 +96,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test y1[1] == [1 2]
@test !(y1 isa Adjoint) && !(y1[1] isa Adjoint)

# arrays of unknown things
@test_throws MethodError ProjectTo([:x, :y])
@test_throws MethodError ProjectTo(Any[:x, :y])
# arrays of other things
@test ProjectTo([:x, :y]) isa ProjectTo{NoTangent}
@test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent}
@test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray}

@test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number.
@test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im)
Expand Down Expand Up @@ -140,6 +143,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
end

@testset "Base: non-diff" begin
@test ProjectTo(:a)(1) == NoTangent()
@test ProjectTo('b')(2) == NoTangent()
@test ProjectTo("cde")(345) == NoTangent()
end

#####
##### `LinearAlgebra`
#####
Expand Down Expand Up @@ -301,6 +310,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
##### `ChainRulesCore`
#####

@testset "pass-through" begin
@test ProjectTo(NoSuperType()) === identity
end

@testset "AbstractZero" begin
pz = ProjectTo(ZeroTangent())
pz(0) == NoTangent()
Expand Down

0 comments on commit 0e560c6

Please sign in to comment.