diff --git a/Project.toml b/Project.toml index 46b2bc046..253736676 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,7 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.3" + +version = "0.24.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 494fb0e47..2b28b44a9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,6 +14,15 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false +# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. +function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(context, vi, logp) +end + +function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(context, vi, logp) +end + # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -115,7 +124,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp!!(context, vi, logp) + return value, acclogp_assume!!(context, vi, logp) end # observe @@ -181,7 +190,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp!!(context, vi, logp) + return left, acclogp_observe!!(context, vi, logp) end function assume(rng, spl::Sampler, dist) @@ -383,7 +392,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp!!(context, vi, logp), vi + return value, acclogp_assume!!(context, vi, logp), vi end # `dot_assume` @@ -539,7 +548,8 @@ function get_and_set_val!( if istrans(vi) push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. - acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) + # FIXME: This is not great. + acclogp_assume!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else @@ -634,7 +644,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp!!(context, vi, logp) + return left, acclogp_observe!!(context, vi, logp) end # Falls back to non-sampler definition.