Skip to content

Commit

Permalink
Optimise istrans
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 5, 2024
1 parent bf73fd0 commit 2a0a939
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 4 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.29"
version = "0.29.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -30,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
Expand All @@ -38,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLTapirExt = ["Tapir"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand All @@ -63,14 +65,16 @@ OrderedCollections = "1"
Random = "1.6"
Requires = "1"
ReverseDiff = "1"
Tapir = "0.2.44"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
19 changes: 19 additions & 0 deletions ext/DynamicPPLTapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module DynamicPPLTapirExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using Tapir: Tapir
else
using ..DynamicPPL: DynamicPPL
using ..Tapir: Tapir
end

using Tapir: DefaultCtx, CoDual, NoPullback, primal, zero_fcodual

# This is purely an optimisation.
Tapir.@is_primitive DefaultCtx Tuple{typeof(DynamicPPL.istrans), Vararg}
function Tapir.rrule!!(f::CoDual{typeof(DynamicPPL.istrans)}, x::Vararg{CoDual, N}) where {N}
return zero_fcodual(DynamicPPL.istrans(map(primal, x)...)), NoPullback(f, x...)
end

end # module
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Accessors = "0.1"
ADTypes = "0.2, 1"
AbstractMCMC = "5"
AbstractPPL = "0.8.2"
Accessors = "0.1"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
Expand All @@ -43,6 +44,7 @@ MCMCChains = "6.0.4"
MacroTools = "0.5.5"
ReverseDiff = "1"
StableRNGs = "1"
Tapir = "0.2.44"
Tracker = "0.2.23"
Zygote = "0.6"
julia = "1.6"
9 changes: 9 additions & 0 deletions test/ext/DynamicPPLTapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testset "DynamicPPLTapirExt" begin
Tapir.TestUtils.test_rule(
Xoshiro(123), istrans, VarInfo();
perf_flag=:none,
interface_only=true,
is_primitive=true,
interp=Tapir.TapirInterpreter(),
)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using ForwardDiff
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MCMCChains
using Tapir
using Tracker
using ReverseDiff
using Zygote
Expand Down Expand Up @@ -68,6 +69,7 @@ include("test_util.jl")

@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLTapirExt.jl")
include("ad.jl")
end

Expand Down

0 comments on commit 2a0a939

Please sign in to comment.