Circular References #204

Open
willtebbutt opened this issue Jul 29, 2024 · 4 comments
Labels
enhancement New feature or request high priority

Comments

@willtebbutt
Member

willtebbutt commented Jul 29, 2024

This write-up is motivated by ongoing problems with circular references, highlighted in a PR linked from #197. These circular references appear in the testing infrastructure for Turing.jl; while they could in principle be removed, it's inconvenient to do so, and Tapir.jl ought to be able to handle them.

The Problem

You can construct circular references via seemingly straightforward types as follows:

julia> mutable struct Foo
           x
           y::Float64
       end

julia> foo = Foo(nothing, 5.0)
Foo(nothing, 5.0)

julia> foo.x = foo
Foo(Foo(#= circular reference @-1 =#), 5.0)

Tapir.jl can handle these if they appear inside functions that are being differentiated. For example,

julia> function f(x::Float64)
           foo = Foo(nothing, x)
           @noinline foo.x = foo # no-inline to avoid optimising everything away
           return foo.y
       end

f (generic function with 1 method)

julia> using Tapir

julia> rule = Tapir.build_rrule(f, 5.0);

julia> Tapir.value_and_gradient!!(rule, f, 5.0)
(5.0, (NoTangent(), 1.0))

However, the same is not true if they occur in an argument to, or a value returned from, a function being differentiated:

julia> foo = Foo(nothing, 5.0)
Foo(nothing, 5.0)

julia> foo.x = foo
Foo(Foo(#= circular reference @-1 =#), 5.0)

julia> g(foo::Foo) = foo.y
g (generic function with 1 method)

julia> rule = Tapir.build_rrule(g, foo);

julia> Tapir.value_and_gradient!!(rule, g, foo)
ERROR: Found a StackOverFlow error when trying to wrap inputs. This often means that Tapir.jl has encountered a self-referential type. Tapir.jl is not presently able to handle self-referential types, so if you are indeed using a self-referential type somewhere, you will need to refactor to avoid it if you wish to use Tapir.jl.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] __create_coduals(args::Tuple{typeof(g), Foo})
   @ Tapir ~/.julia/packages/Tapir/mkYSZ/src/interface.jl:134
 [3] value_and_gradient!!(::Tapir.DerivedRule{…}, ::Function, ::Foo)
   @ Tapir ~/.julia/packages/Tapir/mkYSZ/src/interface.jl:126
 [4] top-level scope
   @ REPL[18]:1

caused by: StackOverflowError:
Stacktrace:
     [1] zero_tangent(x::Foo)
       @ Tapir ~/.julia/packages/Tapir/mkYSZ/src/tangents.jl:422
     [2] macro expansion
       @ ~/.julia/packages/Tapir/mkYSZ/src/tangents.jl:0 [inlined]
--- the last 2 lines are repeated 26660 more times ---
Some type information was truncated. Use `show(err)` to see complete types.

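The problem is that neither zero_tangent nor randn_tangent accounts for the possibility of circular references. As the stack trace above shows, zero_tangent recurses into the fields of foo, reaches foo again, and keeps going until the stack overflows.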
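The failure mode can be reproduced without Tapir.jl at all. The sketch below uses a hypothetical `naive_zero` function (not part of Tapir.jl) which, in the same spirit as the failing zero_tangent call above, builds a zero-valued copy by recursing into fields without recording which objects it has already visited. On the circular `foo` it never bottoms out.

```julia
mutable struct Foo
    x
    y::Float64
end

# Hypothetical `naive_zero`: recurses blindly into fields, with no record of
# which objects have already been visited.
naive_zero(::Float64) = 0.0
naive_zero(::Nothing) = nothing
naive_zero(x::Foo) = Foo(naive_zero(x.x), naive_zero(x.y))

foo = Foo(nothing, 5.0)
foo.x = foo  # foo now refers to itself

# naive_zero(foo) calls naive_zero(foo.x), which is naive_zero(foo) again, forever.
overflowed = try
    naive_zero(foo)
    false
catch err
    err isa StackOverflowError
end
```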

A Solution

Modify zero_tangent and randn_tangent to know about, and correctly deal with, circular references.

The simplest solution is to keep track of every object with a fixed memory address (i.e. every mutable object) encountered during the current call to zero_tangent or randn_tangent, together with the tangent generated for it, and to reuse that tangent instead of generating a new one whenever the same primal is encountered again.
For example, a sketch implementation of zero_tangent:

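function zero_tangent_internal(x::P, d::IdDict) where {P}
    # If this mutable object has already been visited, reuse its tangent.
    # (Use `haskey`: `in(x, d)` on a dict only accepts `key => value` pairs.)
    if ismutabletype(P) && haskey(d, x)
        return d[x]
    end

    # The tangent must be registered in `d` *before* recursing into x's fields,
    # so that a cycle leading back to x finds it (as Base.deepcopy_internal does).
    t = ... # allocate the (not yet filled-in) tangent for x
    d[x] = t
    ... # fill in t's fields via zero_tangent_internal(getfield(x, n), d)
    return t
end

function zero_tangent(x::P) where {P}
    return zero_tangent_internal(x, IdDict())
end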
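Because the sketch above depends on Tapir.jl internals, here is a self-contained toy version of the same idea. `FooTangent`, `toy_zero_tangent_internal` and `toy_zero_tangent` are hypothetical names for illustration only, not Tapir.jl's tangent types or API. The important detail is that the tangent for a mutable object is registered in the IdDict before its fields are visited, so the recursion terminates and the tangent reproduces the primal's cyclic structure.

```julia
mutable struct Foo
    x
    y::Float64
end

# Hypothetical toy tangent for Foo (illustrative only).
mutable struct FooTangent
    x::Any
    y::Float64
    FooTangent() = new(nothing, 0.0)  # allocate first, fill in fields later
end

# Leaves carry no identity, so nothing needs to be tracked for them.
toy_zero_tangent_internal(::Float64, ::IdDict) = 0.0
toy_zero_tangent_internal(::Nothing, ::IdDict) = nothing

function toy_zero_tangent_internal(x::Foo, d::IdDict)
    # Already visited: return the existing tangent, which closes the cycle.
    haskey(d, x) && return d[x]::FooTangent
    # Register the (still empty) tangent *before* recursing into fields.
    t = FooTangent()
    d[x] = t
    t.x = toy_zero_tangent_internal(x.x, d)
    t.y = toy_zero_tangent_internal(x.y, d)
    return t
end

toy_zero_tangent(x) = toy_zero_tangent_internal(x, IdDict())

foo = Foo(nothing, 5.0)
foo.x = foo
t = toy_zero_tangent(foo)
t.x === t  # true: the tangent is circular in the same way as the primal
```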

This strategy mirrors exactly the strategy of deepcopy; see its docstring and Base.deepcopy_internal. A similar strategy is (I believe) used by Enzyme.jl.
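As a quick check of the deepcopy behaviour being mirrored (Base only), copying the circular `foo` produces a fresh object whose `x` field points back to the copy itself rather than to the original:

```julia
mutable struct Foo
    x
    y::Float64
end

foo = Foo(nothing, 5.0)
foo.x = foo

# deepcopy passes an IdDict through deepcopy_internal and records each mutable
# object's copy before copying its fields, so the cycle is reproduced, not followed forever.
c = deepcopy(foo)
c !== foo   # true: a fresh object
c.x === c   # true: the cycle points at the copy, not at the original
```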

A Side-Benefit: Aliasing

This will also ensure that if two arguments to a function alias each other, such as

julia> x = randn(10);

julia> h(x, y) = sum(x .* y)
h (generic function with 1 method)

julia> h(x, x)
12.11240397077079

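then, provided that we call zero_tangent or randn_tangent once for all of the arguments (so that they share a single IdDict), both arguments will be given the same tangent and the correct result will emerge.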
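A toy illustration of the aliasing point, assuming for simplicity that the tangent of a `Vector{Float64}` is just a zeroed `Vector{Float64}`. `toy_zero_tangent_internal` and `toy_zero_tangents` are hypothetical helpers, not Tapir.jl functions. With one shared IdDict, passing the same array twice yields a single tangent object, so gradient contributions through `x` and `y` would accumulate into the same buffer.

```julia
# Hypothetical helper: the zero tangent of an array, memoised on object identity.
toy_zero_tangent_internal(x::Vector{Float64}, d::IdDict) =
    get!(() -> zero(x), d, x)::Vector{Float64}

# Hypothetical helper: generate tangents for all arguments with ONE shared IdDict.
function toy_zero_tangents(args...)
    d = IdDict()
    return map(a -> toy_zero_tangent_internal(a, d), args)
end

x = randn(10)
tx, ty = toy_zero_tangents(x, x)          # aliased arguments
tx === ty                                 # true: one tangent shared by both

ta, tb = toy_zero_tangents(x, randn(10))
ta === tb                                 # false: distinct arrays, distinct tangents
```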

Performance Concerns

This additional work necessarily incurs some overhead in general. However, bits types contain no references, so they cannot take part in a circular reference. For such types the checks are unnecessary and the IdDict allocation can be avoided, so performance will be identical to what it is currently.
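The decision about whether tracking is needed can be a pure function of the primal type. The sketch below uses a hypothetical `needs_cycle_tracking` predicate built on Base's `isbitstype`; because it depends only on the type, Julia can typically resolve it at compile time, so the bits-type path should not pay for the check. Note that it is conservative: any non-bits type is tracked.

```julia
mutable struct Foo
    x
    y::Float64
end

# Hypothetical predicate: a bits type holds no references, so it cannot be part
# of a cycle or alias another argument, and no IdDict is needed for it.
needs_cycle_tracking(::Type{P}) where {P} = !isbitstype(P)

needs_cycle_tracking(Float64)               # false
needs_cycle_tracking(Tuple{Float64, Int})   # false
needs_cycle_tracking(Foo)                   # true (mutable)
needs_cycle_tracking(Vector{Float64})       # true (heap-allocated array)
```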

Moreover, the above IdDict implementation is naive. We should just use pointer_from_objref to obtain a Ptr{Nothing} pointing to the address of a given (mutable) object. We could then use a Dict{Ptr{Nothing}, Any} to store the tangents.
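A minimal sketch of keying tangents on addresses with Base's `pointer_from_objref`. Two caveats are worth keeping in mind: it throws for immutable values, so it only applies to mutable objects, and the raw pointer does not keep the object alive, so this relies on the primal staying reachable while the Dict is in use (which holds during a single call, since the argument references it). `:placeholder_tangent` is a hypothetical stand-in for a real tangent.

```julia
mutable struct Foo
    x
    y::Float64
end

foo = Foo(nothing, 5.0)
foo.x = foo

tangents = Dict{Ptr{Nothing}, Any}()
tangents[pointer_from_objref(foo)] = :placeholder_tangent  # hypothetical tangent

# foo.x is the very same object, so it maps to the same key.
pointer_from_objref(foo.x) == pointer_from_objref(foo)   # true
haskey(tangents, pointer_from_objref(foo.x))             # true

# Immutable values have no stable identity; pointer_from_objref throws for them.
threw = try
    pointer_from_objref(5.0)
    false
catch
    true
end
```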

Furthermore, when querying an element from such a Dict, we can assert that the returned value is of type tangent_type(P), where P is the primal type, thus preventing type instabilities from propagating.
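A small sketch of the type-assertion idea using a hypothetical `lookup_tangent` helper. The Dict stores its values as `Any`, but asserting the expected type at the lookup site gives a concretely inferred return type, which `Test.@inferred` confirms. In Tapir.jl the asserted type would be tangent_type(P); here, purely as an assumption for illustration, a `Vector{Float64}` primal has a `Vector{Float64}` tangent.

```julia
using Test  # provides @inferred

# Hypothetical lookup: values are stored as Any, but the `::T` assertion lets the
# compiler treat the result as a T, so the instability does not propagate.
lookup_tangent(d::Dict{Ptr{Nothing}, Any}, x, ::Type{T}) where {T} =
    d[pointer_from_objref(x)]::T

x = randn(3)
d = Dict{Ptr{Nothing}, Any}(pointer_from_objref(x) => zero(x))

# @inferred throws if the inferred return type differs from the actual one.
t = @inferred lookup_tangent(d, x, Vector{Float64})
```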

Types which refer to themselves are a different problem

This problem is distinct from the one associated with types such as

julia> mutable struct Bar
           x::Union{Nothing, Bar}
       end

julia> tangent_type(Bar)
ERROR: StackOverflowError:
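The distinction can be made concrete by looking at types rather than values. `Foo` is not self-referential as a type (its fields are `Any` and `Float64`); only its instances can be circular. `Bar`'s field type mentions `Bar` itself, so any recursion over field types, of the kind tangent_type presumably performs given the stack overflow above, never terminates unless it remembers which types it has already visited. The hypothetical `reaches` / `is_self_referential` helpers below do that bookkeeping; they only look inside concrete types and `Union`s and are an illustration, not a general tool.

```julia
mutable struct Foo
    x
    y::Float64
end

mutable struct Bar
    x::Union{Nothing, Bar}
end

# Hypothetical helper: can `target` be reached from type T by following field types?
function reaches(target::Type, T, seen::Set{Any})
    T === target && return true      # found a path back to the starting type
    T in seen && return false        # already explored: stop instead of looping
    push!(seen, T)
    if T isa Union
        return reaches(target, T.a, seen) || reaches(target, T.b, seen)
    end
    # Toy restriction: do not look inside abstract types such as Any.
    (T isa DataType && isconcretetype(T)) || return false
    return any(ft -> reaches(target, ft, seen), fieldtypes(T))
end

is_self_referential(T::DataType) = any(ft -> reaches(T, ft, Set{Any}()), fieldtypes(T))

fieldtype(Bar, :x) == Union{Nothing, Bar}   # true
is_self_referential(Bar)                    # true: the type refers to itself
is_self_referential(Foo)                    # false: only instances can be circular
```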
@yebai
Contributor

yebai commented Aug 1, 2024

IIRC, we improved the error message for self-referential zero tangent types in #144. Is that fundamentally different from circular references? I am asking because the error handling mechanism didn't work for circular references.

@willtebbutt
Member Author

You are correct that we improved the error messages: if you take a look at the example in this issue, you'll see that the error message is the improved one. The stack overflow in the linked example seems to have circumvented it, though; I think this must be because that code path doesn't go via value_and_gradient!!.

Calling them self-referential rather than circular references was a mistake on my part: I should have been calling them circular references the whole time.

@yebai
Contributor

yebai commented Aug 1, 2024

Thanks for the clarification. The solution in #144 indeed does not seem to be general enough and deserves improvements.

@willtebbutt
Member Author

We would just have to ensure that the error that gets generated via value_and_gradient!! also gets generated in more places. Turing.jl still goes via the LogDensityProblemsAD interface, rather than DI.jl, right?


3 participants