Allow ProjectTo on non-differential types? #442

Closed
mzgubic opened this issue Aug 18, 2021 · 11 comments · Fixed by #458
Labels
design: Requires some design before changes are made
ProjectTo: related to the projection functionality

Comments

@mzgubic
Member

mzgubic commented Aug 18, 2021

At the moment we only define ProjectTo for differential types (with Ref being the first exception).

Consider a generic rrule(*, a::Number, b::Number) which uses ProjectTo to ensure that the tangents are in the right subspace, i.e. something like

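julia> using ChainRulesCore

julia> function ChainRulesCore.rrule(::typeof(*), a::Number, b::Number)
           function times_pullback(dy)
               # conjugate so that the rule is also correct for complex inputs
               da = dy * b'
               db = a' * dy
               # project each tangent back onto the subspace of its primal
               return NoTangent(), ProjectTo(a)(da), ProjectTo(b)(db)
           end
           return a*b, times_pullback
       end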

which looks perfectly reasonable.
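
To make "the right subspace" concrete, here is a minimal Base-only sketch of the same pullback. `toy_project` and `toy_times_rule` are hypothetical stand-ins written for this illustration, not the ChainRulesCore API, and the conjugates are an assumption made so that complex inputs are handled too.

```julia
# Hypothetical stand-in for ProjectTo on plain numbers (NOT the ChainRulesCore API).
# A real primal only accepts real tangents; a complex primal accepts complex ones.
toy_project(x::Real) = dx -> convert(typeof(float(x)), real(dx))
toy_project(x::Complex) = dx -> convert(typeof(float(x)), dx)

# Hypothetical rule for a * b that projects each tangent onto its primal's subspace.
function toy_times_rule(a::Number, b::Number)
    function times_pullback(dy)
        da = dy * conj(b)
        db = conj(a) * dy
        return toy_project(a)(da), toy_project(b)(db)
    end
    return a * b, times_pullback
end

y, back = toy_times_rule(2.0, 3.0 + 1.0im)
da, db = back(1.0)
# Without projection da would be complex, although the primal `a` is real.
```

Here `da` comes back as a `Float64` and `db` as a `ComplexF64`, so each tangent matches its primal.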

However, if we create a type like

julia> struct PositiveReal <: Number
           val::Float64
           PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
       end

which is not its own differential type (the natural differential for this is a Float64) we are in trouble.

The problem is that we only promise that ProjectTo projects onto valid differential types, so we can't just define

julia> function ProjectTo(x::PositiveReal)
           return ProjectTo(x.val)
       end

since PositiveReal is not a valid differential type (it does not have a zero). For a similar reason we do not define ProjectTo(::Tuple), which would solve issues like #440.
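
A small Base-only sketch of the situation described above, under the assumption that the natural tangent of a PositiveReal is a Float64. `positive_real_projector` and `can_construct_zero` are hypothetical names used only for this illustration; they do not exist in ChainRulesCore.

```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end

# Hypothetical projector: the tangent lives in Float64, not in PositiveReal.
positive_real_projector(x::PositiveReal) = dx -> convert(Float64, real(dx))

# A zero-valued PositiveReal cannot be constructed, which is why the type
# cannot serve as its own differential type.
function can_construct_zero()
    try
        PositiveReal(0.0)
        return true
    catch
        return false
    end
end

p = PositiveReal(pi)
dp = positive_real_projector(p)(-33)   # a negative tangent is perfectly valid
```

The projected tangent is a plain Float64, even though its value could never be stored in a PositiveReal.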

The question is: should we loosen this requirement to only project onto differential types? By keeping the requirement we are restricting the use of ProjectTo to functions with arguments that are their own differentials. What bad things happen if we scratch this ProjectTo requirement?

@mzgubic added the ProjectTo and design labels on Aug 18, 2021
@mcabbott
Member

mcabbott commented Aug 18, 2021

Maybe worth saying that the reason for this restriction was just that we weren't completely sure, and leaving it an error in version 1.0 meant we could make it do anything later.

I do think it would be nice if it could be applied more widely. One argument for this came up in the discussion of Functors.jl integration. If Zygote applies ProjectTo to everything, even if this is trivial for unknown types, then this lets you command it to apply a particular projection to your type. Whereas if it only applies ProjectTo to known-safe arrays etc., then you cannot override that as easily.

For this weird number type, I think the defaults will work, i.e. allow any Real as its tangent. I think the projector above with ProjectTo(x.val) will just ensure the tangent is Float64: ProjectTo(PositiveReal(pi))(-33) === -33.0.

For non-number, non-array types, the easy class is non-differentiable things like strings & Symbols. The other class is things which are treated as a struct, and thus I think their tangents will always be dx::Tangent. Perhaps the first question is: Is this correct, do these two cases exhaust the possibilities?

The original design of ProjectTo was going to recurse into arbitrary structs. I think that a current version of that idea would look something like this:

function ProjectTo(x)
    projectors = map(ProjectTo, backing(x))
    # filter(p -> !(p isa ProjectTo{<:AbstractZero}), projectors)  # not sure this will work
    return ProjectTo{Tangent}(; projectors...)
end

function (project::ProjectTo{Tangent})(dx::Tangent{T}) where T  # callable-struct method
    projected = map(_maybe_call, backing(project), backing(dx))  # but more careful about missing fields
    Tangent{T}(; projected...)
end

But the question then is whether we need that at all, or should simply trust that anything inside a Tangent has already been appropriately projected upstream. In which case you just need something like ProjectTo(::Any) = identity.
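
The two options can be compared in a Base-only sketch. Everything here is an assumption made for illustration: NamedTuples stand in for the backing of a Tangent, `nothing` stands in for a zero tangent, and `field_projector`, `trusting_projector` and `Layer` are hypothetical names.

```julia
# Leaf projectors (hypothetical stand-ins, not ChainRulesCore).
field_projector(x::Real) = dx -> convert(typeof(float(x)), real(dx))
field_projector(::AbstractString) = dx -> nothing   # non-differentiable: drop the tangent

# Option 1: recurse into every field of an arbitrary struct.
function field_projector(x::T) where {T}
    names = fieldnames(T)
    projectors = NamedTuple{names}(map(n -> field_projector(getfield(x, n)), names))
    # Requires the tangent to carry exactly the same field names as the primal.
    return dx -> map((p, d) -> p(d), projectors, dx)
end

# Option 2: trust upstream rules and pass the tangent through untouched.
trusting_projector(x) = identity

struct Layer
    weight::Float32
    name::String
end

layer = Layer(1.0f0, "dense")
dlayer = (weight = 2.5 + 1.0im, name = "ignored")

recursed = field_projector(layer)(dlayer)
trusted = trusting_projector(layer)(dlayer)
```

Option 1 cleans up each field (here `weight` becomes a Float32 and `name` is dropped), while option 2 returns the tangent exactly as it arrived.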

Are there functions which accept a Tuple, for which we want to define a rule, where we want projection to happen to all elements? Maybe functions which accept a Pair, ditto?

@oxinabox
Member

A special case of this is #440.

@mcabbott
Member

In that case, the gradient of sum(::Tuple) does not currently apply projection:

julia> gradient(abs∘sum, (true, 0f0 + im))
((0.70710677f0 + 0.70710677f0im, 0.70710677f0 + 0.70710677f0im),)

But if it did, then projection on the array of tuples would not be necessary.
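
As a rough sketch of what elementwise projection of that tuple gradient could look like, using Base only. `elem_projector` is a hypothetical stand-in, and treating Bool as non-differentiable (returning `nothing`) is an assumption of this sketch, not confirmed library behaviour.

```julia
# Hypothetical per-element projectors (NOT the ChainRulesCore API).
elem_projector(::Bool) = dx -> nothing   # assumption: Bool carries no tangent
elem_projector(x::Real) = dx -> convert(typeof(float(x)), real(dx))
elem_projector(x::Complex) = dx -> convert(typeof(float(x)), dx)

# A tuple projector applies the matching projector to each element.
elem_projector(xs::Tuple) = dxs -> map((p, d) -> p(d), map(elem_projector, xs), dxs)

x = (true, 0f0 + im)
g = 0.70710677f0 + 0.70710677f0im
dx = (g, g)   # the unprojected gradient from the example above

projected = elem_projector(x)(dx)
```

In this sketch the Bool slot loses its complex gradient, while the ComplexF32 slot keeps it unchanged.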

@oxinabox
Member

oxinabox commented Sep 14, 2021

So what we are seeing right now is that, wherever ProjectTo would be applied to something it does not support, an
x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity guard is being inserted instead.
E.g. #448
and JuliaDiff/ChainRules.jl#526
This is strictly worse than the outcome we would get if we made unexpected types in ProjectTo act as identity,
since now we screw up arrays of arrays etc.

The ideal behaviour would be if we could detect that the code would eventually MethodError, and then act as if no rrule had been defined (which would itself be a MethodError in this ideal world).
We can't live in that world. Maybe in Julia 2.0 or CR 2.0 if we get smart.

So the pragmatic solution seems to be to make it generically fall back to identity.
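
The difference between the inserted type check and a generic fallback method can be sketched in plain Julia. `fallback_project` and `guarded_project` are hypothetical names for this illustration; only the `Union{Number, AbstractArray{<:Number}}` check is taken from the comment above.

```julia
# Hypothetical projectors with a generic identity fallback (NOT ChainRulesCore).
fallback_project(x::Real) = dx -> convert(typeof(float(x)), real(dx))
fallback_project(xs::AbstractArray) =
    dxs -> map((x, dx) -> fallback_project(x)(dx), xs, dxs)
fallback_project(::Any) = identity   # unknown types pass through

# The guard currently being inserted at call sites: anything that is not a
# number or an array of numbers skips projection entirely.
guarded_project(x) =
    x isa Union{Number, AbstractArray{<:Number}} ? fallback_project(x) : identity

x = [[1.0, 2.0], [3.0]]                             # array of arrays
dx = [[1.0 + 2.0im, 3.0 + 0.0im], [4.0 + 0.0im]]    # complex gradient

via_guard = guarded_project(x)(dx)       # guard skips the nested projector
via_fallback = fallback_project(x)(dx)   # dispatch finds the array method
```

With the guard the complex gradient leaks through untouched, whereas dispatch with a fallback still projects the nested arrays and only passes through types that have no method at all.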

@mcabbott
Member

Next question is whether it should be literally identity, or something like ProjectTo{Any} or ProjectTo{Tangent}? Defined as a pass-through, but perhaps allowing T <: project_type(proj) to tell you something?

@oxinabox
Member

A pro of identity is that we can clearly spot it in the type system, stack traces etc. as being "Not handled".

A pro of ProjectTo{Any} is that we could still overload it further if, idk, we wanted to condense Vector{<:AbstractZero} with it.

ProjectTo{Tangent} is, I think, a different thing.
That I expect to handle natural-to-structural conversions.
cf. #449
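
A minimal sketch of that trade-off, assuming a hypothetical callable struct `ToyProject{T}` in place of ProjectTo and `nothing` in place of an AbstractZero. The condensing overload is invented purely to show that a dedicated type can be extended later, which plain `identity` cannot.

```julia
# Hypothetical pass-through projector type (NOT the ChainRulesCore ProjectTo).
struct ToyProject{T} end

# Default behaviour: pass the tangent through unchanged.
(::ToyProject{Any})(dx) = dx

# Later extension: condense a vector made only of zero-like tangents.
(::ToyProject{Any})(dx::AbstractVector{Nothing}) = nothing

# The type parameter stays inspectable, unlike with a bare `identity`.
toy_project_type(::ToyProject{T}) where {T} = T

p = ToyProject{Any}()
passed = p(3)
condensed = p([nothing, nothing])
```

With `identity` the "not handled" case is easy to spot, but there is no type parameter to query and no method table of our own to extend.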

@mcabbott
Member

> if, idk we wanted to condense Vector{<:AbstractZero} with it still.

If the gradient is a vector, then the primal was too, right? So it will have a nontrivial projector. Even for x::Array{Any}.

Right now we do S <: T ? dy : map(project.element, dy) only for arrays of numbers. Which have at widest ProjectTo{Number}() and this won't change.

Sounds like identity is simplest.

If ProjectTo(x) doesn't always return this struct, then we could also consider changing to ProjectTo(::Bool) = Returns(NoTangent()), might be clearer? Not sure.

@mzgubic
Member Author

mzgubic commented Sep 20, 2021

What would ProjectTo{Any} look like? Just a custom identity function? I guess we could use identity for now and then change it to ProjectTo{Any} in case we wanted to compress Vector{<:AbstractZero}?

@oxinabox
Member

> in case we wanted to compress Vector{<:AbstractZero}?

Not a real example to be clear.
I guess a more realistic one might be to compress Tangents with Zero elements.

@willtebbutt
Member

> This is strictly worse than the outcome we would get if we made unexpected types in ProjectTo act as identity.

I'm not sure it's strictly worse. I agree it would make some things work that ought to, but it would presumably also make some things work that ought not to?

Otherwise I agree that the identity option feels like a pragmatic solution for now.

@oxinabox
Member

oxinabox commented Sep 21, 2021

x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
is strictly worse than having identity as the generic fallback when there is no ProjectTo(x) method.
The code above will not run the projection for e.g. ProjectTo(::AbstractArray{<:AbstractArray{<:AbstractArray}}), which we can actually handle well.

(It is inferior to skipping the rule whenever ProjectTo would MethodError, but Julia doesn't make that easy right now.)
