Merge branch 'master' into torfjelde/namedtuple-initial-params
sunxd3 committed Jul 22, 2024
2 parents 14bb2cf + 36008f9 commit 8b79ac0
Showing 4 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.28.1"
version = "0.28.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
40 changes: 40 additions & 0 deletions src/logdensityfunction.jl
@@ -76,6 +76,46 @@ function getcontext(f::LogDensityFunction)
return f.context === nothing ? leafcontext(f.model.context) : f.context
end

"""
getmodel(f)
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
getmodel(LogDensityProblemsAD.parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
setmodel(f, model[, adtype])
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
    f::LogDensityProblemsAD.ADGradientWrapper,
    model::DynamicPPL.Model,
    adtype::ADTypes.AbstractADType,
)
    # TODO: Should we handle `SciMLBase.NoAD`?
    # For an `ADGradientWrapper` we do the following:
    # 1. Update the `Model` in the underlying `LogDensityFunction`.
    # 2. Re-construct the `ADGradientWrapper` with `ADgradient`, using the provided `adtype`,
    #    to ensure that the recompilation of gradient tapes, etc. also occurs. For example,
    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
    #    replacing the corresponding field with the new model won't be sufficient to obtain
    #    the correct gradients.
    return LogDensityProblemsAD.ADgradient(
        adtype, setmodel(LogDensityProblemsAD.parent(f), model)
    )
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
    return Accessors.@set f.model = model
end

# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the meantime
# we need to define these annoying methods to ensure that we stay compatible with everything.
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
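
For orientation, a minimal usage sketch of the new `getmodel`/`setmodel` pair (not part of the commit): the `demo` model and its conditioning values are hypothetical, while the wrapper calls mirror the tests added below.

using DynamicPPL, Distributions, ADTypes, LogDensityProblemsAD, ReverseDiff

# Hypothetical model, for illustration only.
@model function demo(x)
    μ ~ Normal()
    x ~ Normal(μ, 1)
end

model = demo(1.5)
ℓ = DynamicPPL.LogDensityFunction(model)
DynamicPPL.getmodel(ℓ) == model          # retrieve the model wrapped in `ℓ`
ℓ2 = DynamicPPL.setmodel(ℓ, demo(2.0))   # swap in a new model, keeping the rest of `ℓ`

# For an AD-wrapped log density, pass the `adtype` so the gradient machinery
# (e.g. a compiled ReverseDiff tape) is rebuilt against the new model.
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
∇ℓ2 = DynamicPPL.setmodel(∇ℓ, demo(2.0), AutoReverseDiff())

Passing the `adtype` is what triggers the re-construction via `ADgradient` described in the comment above; the tests below check that this yields a fresh compiled tape rather than reusing the old one.
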
22 changes: 21 additions & 1 deletion test/logdensityfunction.jl
@@ -1,4 +1,24 @@
-using Test, DynamicPPL, LogDensityProblems
+using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff

@testset "`getmodel` and `setmodel`" begin
    @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
        model = DynamicPPL.TestUtils.DEMO_MODELS[1]
        ℓ = DynamicPPL.LogDensityFunction(model)
        @test DynamicPPL.getmodel(ℓ) == model
        @test DynamicPPL.setmodel(ℓ, model).model == model

        # ReverseDiff related
        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
        @test DynamicPPL.getmodel(∇ℓ) == model
        @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
            model
        ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
        new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
        @test DynamicPPL.getmodel(new_∇ℓ) == model
        # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
        @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
    end
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
1 change: 1 addition & 0 deletions test/turing/Project.toml
@@ -1,6 +1,7 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
+HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
