Compatibility with Zygote AD #1516

Open
Uroc327 opened this issue Mar 11, 2022 · 3 comments

Uroc327 commented Mar 11, 2022

It seems that only ForwardDiff is supported for differentiating through sampling. Is it possible to implement rules for Zygote as well?

julia> ForwardDiff.derivative(x -> sum(rand(Normal(x, 10), 10)), 0.)
10.0

julia> Zygote.gradient(x -> sum(rand(Normal(x, 10), 10)), 0.)
(nothing,)
@trahflow

> Is it possible to implement rules for Zygote as well?

It is possible, e.g. (just a quick hack):

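using ChainRulesCore, Distributions

# Reverse rule for rand(d::Normal, n), based on the reparameterization vals = μ .+ σ .* z.
function ChainRulesCore.rrule(::typeof(rand), d::Normal{T}, n::Integer) where {T}
    vals = rand(d, n)
    function rand_pullback(rand_bar)
        Δ = unthunk(rand_bar)
        z = (vals .- d.μ) ./ d.σ  # recover the standard-normal draws
        # Weight by the incoming cotangent Δ; for `sum`, Δ is all ones and the μ entry is n.
        d_bar = Tangent{Normal{T}}(; μ=sum(Δ), σ=sum(Δ .* z))
        return NoTangent(), d_bar, NoTangent()
    end
    return vals, rand_pullback
end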

But I wonder why a dedicated rule is necessary at all. Looking at the definition of rand for Normal, this should be easily differentiable. I don't really see why Zygote returns nothing here.
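
As a sanity check on what such a rule should return, here is a minimal sketch using only the standard library. It assumes the samples are generated as a location-scale transform `μ .+ σ .* z` of standard-normal noise `z`. The helpers `sample_normal` and `pullback_normal` are hypothetical names introduced for illustration and are not part of Distributions or ChainRules. The sketch compares the analytic cotangents with central finite differences for an arbitrary output cotangent `Δ`.

```julia
using Random

# Hypothetical helper: location-scale sampler with the noise z passed in explicitly,
# so the same draws can be reused while perturbing μ and σ.
sample_normal(μ, σ, z) = μ .+ σ .* z

# Cotangents with respect to (μ, σ) for an output cotangent Δ.
pullback_normal(Δ, z) = (sum(Δ), sum(Δ .* z))

rng = MersenneTwister(1)
z = randn(rng, 10)   # fixed noise
Δ = randn(rng, 10)   # arbitrary output cotangent
μ, σ, h = 0.0, 10.0, 1e-3

# Scalar loss whose gradient with respect to the samples is exactly Δ.
loss(μ, σ) = sum(Δ .* sample_normal(μ, σ, z))

μ_bar, σ_bar = pullback_normal(Δ, z)
fd_μ = (loss(μ + h, σ) - loss(μ - h, σ)) / (2 * h)
fd_σ = (loss(μ, σ + h) - loss(μ, σ - h)) / (2 * h)

@show isapprox(μ_bar, fd_μ; rtol=1e-6, atol=1e-8)
@show isapprox(σ_bar, fd_σ; rtol=1e-6, atol=1e-8)
```

For the `sum` example from the issue, `Δ` is a vector of ones, so the `μ` cotangent is `n = 10`, which agrees with the ForwardDiff result.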

Member

devmotion commented Mar 12, 2023

It is due to https://github.com/JuliaDiff/ChainRules.jl/blob/158ca756ef99ccf3f1dde2e66b5855e8e68e0363/src/rulesets/Random/random.jl#L23-L25. It is a deliberate design decision to mark rand methods as non-differentiable in ChainRules (which is used by Zygote). They were made more restrictive to explicitly not cover e.g. rand(Normal()) (see e.g. the discussion in JuliaDiff/ChainRules.jl#262). However, I assume the problem with the example above is that sampling multiple samples in Distributions is done via pre-allocating the output and sampling with rand! - and that one is marked non-differentiable, regardless of the type of the arguments. See also JuliaDiff/ChainRules.jl#603.
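
Until a rule exists, one possible workaround is to write the reparameterization out by hand so that the call never goes through `rand!`. The following is only a sketch under that assumption: `reparam_sample` is a hypothetical helper, not a Distributions function, and it relies on `randn` being treated as a constant by the AD system.

```julia
using Zygote

# Hypothetical helper: draw n samples from Normal(μ, σ) by shifting and scaling
# standard-normal noise. The noise itself carries no gradient, so the derivative
# flows only through μ and σ.
reparam_sample(μ, σ, n) = μ .+ σ .* randn(n)

grad = Zygote.gradient(x -> sum(reparam_sample(x, 10, 10)), 0.0)
@show grad
```

Analytically, the derivative of this sum with respect to `x` is the number of samples, 10, which is the value ForwardDiff reports in the original example.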


trahflow commented Mar 12, 2023

> sampling multiple samples in Distributions is done via pre-allocating the output and sampling with rand!

Ah, I missed that. Then everything makes sense.

Sort of a meta question: where would one put AD rules for such things?
Here, in the ext/DistributionsChainRulesCoreExt module?
I also found TuringLang/DistributionsAD.jl#123, but I am not sure whether that package is still meant to be used.
