diff --git a/Project.toml b/Project.toml index 6322bfa7..3219992c 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ ChainRulesCore = "1.16" DiffResults = "1" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" -Enzyme = "0.13" +Enzyme = "0.12.32" FillArrays = "1.3" ForwardDiff = "0.10.36" Functors = "0.4" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 7fa05e56..45b3c547 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -18,10 +18,11 @@ end function AdvancedVI.value_and_gradient!( ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) + Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) + Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) ) DiffResults.value!(out, y) return out @@ -34,10 +35,11 @@ function AdvancedVI.value_and_gradient!( aux, out::DiffResults.MutableDiffResult, ) + Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true) + Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x),