Merge pull request #186 from sebapersson/AddPigeonExt
Add PigeonExt
sebapersson committed Mar 13, 2024
2 parents 14602f2 + a650d52 commit 9ce8c4c
Showing 13 changed files with 185 additions and 11 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
@@ -52,6 +53,7 @@ PEtabLogDensityProblemsExtension = ["LogDensityProblems", "LogDensityProblemsAD"
PEtabMCMCChainsExtension = ["MCMCChains"]
PEtabOptimExtension = ["Optim"]
PEtabOptimizationExtension = ["Optimization"]
PEtabPigeonsExtension = ["LogDensityProblems", "LogDensityProblemsAD", "Bijectors", "MCMCChains", "Pigeons"]
PEtabPlotsExtension = ["Plots"]
PEtabPyCallExtension = ["PyCall"]
PEtabSciMLSensitivityExtension = ["SciMLSensitivity", "Zygote"]
@@ -82,6 +84,7 @@ Optimization = "3"
OptimizationOptimJL = "0.2"
OrdinaryDiffEq = "6"
Plots = "1"
Pigeons = "0.4"
PreallocationTools = "0.4"
PrecompileTools = "1"
Printf = "1"
@@ -123,4 +126,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Bijectors", "Test", "SafeTestsets", "SciMLSensitivity", "Zygote", "Ipopt", "Optim", "Plots", "Optimization", "OptimizationOptimJL", "FiniteDifferences", "PyCall", "LogDensityProblems", "LogDensityProblemsAD", "AdaptiveMCMC", "AdvancedHMC", "MCMCChains"]
test = ["Aqua", "Bijectors", "Test", "SafeTestsets", "SciMLSensitivity",
"Zygote", "Ipopt", "Optim", "Plots", "Optimization", "OptimizationOptimJL",
"FiniteDifferences", "PyCall", "LogDensityProblems", "LogDensityProblemsAD",
"AdaptiveMCMC", "AdvancedHMC", "MCMCChains", "Pigeons"]
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -13,6 +13,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PEtab = "48d54b35-e43e-4a66-a5a1-dde6b987cf69"
Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

@@ -31,5 +32,6 @@ ModelingToolkit = "8"
Optim = "1"
OrdinaryDiffEq = "6"
PEtab = "2"
Pigeons = "0.4"
Plots = "1"
StatsPlots = "0.15"
1 change: 1 addition & 0 deletions docs/src/API_choosen.md
@@ -25,6 +25,7 @@ get_odesol
solve_all_conditions
compute_runtime_accuracy
PEtabLogDensity
PEtabPigeonReference
to_prior_scale
to_chains
```
40 changes: 36 additions & 4 deletions docs/src/HMC.md
@@ -2,15 +2,16 @@

When fitting a model with PEtab.jl, the unknown model parameters are estimated within a frequentist framework, and the goal is to find the maximum likelihood estimate. When prior knowledge about the parameters is available, Bayesian inference is an alternative approach to fitting the model to data. The aim of Bayesian inference is to infer the posterior distribution of the unknown parameters given the data, ``p(\theta | y)``, by running a Markov chain Monte Carlo (MCMC) algorithm that samples from the posterior.

PEtab.jl supports Bayesian inference via two packages:
PEtab.jl supports Bayesian inference via three packages:

- **Adaptive Metropolis Hastings Samplers**: available in [AdaptiveMCMC.jl](https://github.com/mvihola/AdaptiveMCMC.jl)
- **Hamiltonian Monte Carlo (HMC) Samplers**: available in [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl). Here the default choice is the NUTS sampler, which is used by [Turing.jl](https://github.com/TuringLang/Turing.jl), and is also the default in Stan. HMC samplers are often more efficient than other methods.
- **Adaptive Parallel Tempering Samplers**: available in [Pigeons.jl](https://github.com/Julia-Tempering/Pigeons.jl). Parallel tempering is most suitable for multi-modal (it can jump modes) or non-identifiable posteriors.

This document covers how to create a `PEtabODEProblem` with priors, and how to use both [AdaptiveMCMC.jl](https://github.com/mvihola/AdaptiveMCMC.jl) and [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl) for Bayesian inference.
This document covers how to create a `PEtabODEProblem` with priors, and how to use [AdaptiveMCMC.jl](https://github.com/mvihola/AdaptiveMCMC.jl), [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl), and [Pigeons.jl](https://github.com/Julia-Tempering/Pigeons.jl) for Bayesian inference.

!!! note
To use the Bayesian inference functionality in PEtab.jl, the Bijectors, LogDensityProblems, and LogDensityProblemsAD packages must be loaded.
To use the Bayesian inference functionality in PEtab.jl, the Bijectors, LogDensityProblems, and LogDensityProblemsAD packages must be loaded. For parallel tempering, Pigeons and MCMCChains must also be loaded.

## Setting up a Bayesian inference problem

@@ -147,7 +148,7 @@ plot(chain_hmc)
!!! note
When converting the output to a `MCMCChains` the parameters are transformed to the prior-scale (inference scale).

## Bayesian inference with AdaptiveMCMC.jl (NUTS)
## Bayesian inference with AdaptiveMCMC.jl

Given a starting point we can run the robust adaptive MCMC sampler with $200 \, 000$ samples by:

@@ -167,3 +168,34 @@ plot(chain_adapt)
```

Any other algorithm found in AdaptiveMCMC.jl [documentation](https://github.com/mvihola/AdaptiveMCMC.jl) can also be used.

## Bayesian inference with Pigeons.jl

A parallel tempering algorithm needs a reference distribution that is easy to sample from and that covers a wide range of parameter space, so that the sampler can, for example, jump between modes. For a `PEtabODEProblem`, this reference is the prior distribution, which is set up with:

```@example 1; ansicolor=false
reference = PEtabPigeonReference(petab_problem)
nothing #hide
```
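
To make concrete what such a reference evaluates, here is a minimal, package-free sketch. It assumes a single parameter with an exponential prior and a log transform to an unconstrained inference scale. All function names are hypothetical and are not part of PEtab.jl or Pigeons.jl. The point it illustrates is that, on the inference scale, the reference log-density is the log-prior plus a log-Jacobian correction.

```julia
# Hypothetical names throughout; this is an illustration, not PEtab.jl's implementation.

# Log-density of an Exponential(rate) prior on the prior scale (theta >= 0).
exponential_logpdf(theta, rate) = theta < 0 ? -Inf : log(rate) - rate * theta

# Map from the unconstrained inference scale back to the prior scale.
to_prior(x) = exp(x)

# log|d theta / d x| for theta = exp(x), i.e. the Jacobian correction.
log_abs_det_jac(x) = x

# Reference log-density on the inference scale: log-prior plus correction.
log_reference(x, rate) = exponential_logpdf(to_prior(x), rate) + log_abs_det_jac(x)

# Sanity check: with the correction, the density on the inference scale
# integrates to one (approximated here by a simple Riemann sum).
function integrate_density(rate; lo = -30.0, hi = 10.0, dx = 1e-3)
    total = 0.0
    for x in lo:dx:hi
        total += exp(log_reference(x, rate)) * dx
    end
    return total
end

println(integrate_density(2.0))
```

Without the `log_abs_det_jac` term, the density on the inference scale would no longer be normalized, which is why a correction of this kind has to accompany the prior whenever sampling happens on a transformed scale.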

Given a starting point we can now take $2^{10}$ adaptive parallel tempering samples by:

```@example 1; ansicolor=false
using Pigeons
Random.seed!(123)
target.initial_value .= xinference
pt = pigeons(target = target,
reference = reference,
n_rounds=10, # 2^10 samples
record = [traces; record_default()])
nothing #hide
```

and we can convert the output to an `MCMCChains` object:

```@example 1; ansicolor=false
pt_chain = to_chains(pt, target)
plot(pt_chain)
```

Here we used the default `SliceSampler` as the local explorer. Alternative explorers, such as an adaptive MALA, can be used instead by setting `explorer=AutoMALA()` in the `pigeons` function call. More information can be found in the Pigeons.jl [documentation](https://pigeons.run/dev/).
1 change: 1 addition & 0 deletions ext/PEtabLogDensityProblemsExtension.jl
@@ -12,6 +12,7 @@ include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "Common.jl"))
include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "Init_structs.jl"))
include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "Likelihood.jl"))
include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "LogDensityProblem.jl"))
include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "Pigeons.jl"))
include(joinpath(@__DIR__, "PEtabLogDensityProblemsExtension", "Prior.jl"))

end
2 changes: 1 addition & 1 deletion ext/PEtabLogDensityProblemsExtension/Init_structs.jl
@@ -4,7 +4,7 @@ function PEtab.PEtabLogDensity(petab_problem::PEtabODEProblem)::PEtab.PEtabLogDe

# Compute the gradient of the prior and Jacobian correction via autodiff
_prior_correction = (x_inference) -> let inference_info = inference_info
prior = compute_prior(x_inference, inference_info)
prior = PEtab.compute_prior(x_inference, inference_info)
correction = Bijectors.logabsdetjac(inference_info.inv_bijectors, x_inference)
return prior + correction
end
4 changes: 2 additions & 2 deletions ext/PEtabLogDensityProblemsExtension/LogDensityProblem.jl
@@ -10,7 +10,7 @@ function _logtarget(x_inference::AbstractVector{T}, compute_nllh::Function,
inference_info::PEtab.InferenceInfo)::T where {T <: Real}
# Logposterior with Jacobian correction for transformed parameters
logtarget = PEtab.compute_llh(x_inference, compute_nllh, inference_info)
logtarget += compute_prior(x_inference, inference_info)
logtarget += PEtab.compute_prior(x_inference, inference_info)
logtarget += Bijectors.logabsdetjac(inference_info.inv_bijectors, x_inference)
return logtarget
end
@@ -24,7 +24,7 @@ function _logtarget_gradient(x_inference::AbstractVector{T}, _nllh_gradient::Fun
nllh, logtarget_grad = _nllh_gradient(x_nllh)

# Logposterior with Jacobian correction for transformed parameters
logtarget = nllh * -1 + compute_prior(x_inference, inference_info)
logtarget = nllh * -1 + PEtab.compute_prior(x_inference, inference_info)
logtarget += Bijectors.logabsdetjac(inference_info.inv_bijectors, x_inference)

# Gradient with transformation correction
7 changes: 7 additions & 0 deletions ext/PEtabLogDensityProblemsExtension/Pigeons.jl
@@ -0,0 +1,7 @@
LogDensityProblems.dimension(lp::PEtab.PEtabPigeonReference) = lp.dim

LogDensityProblems.logdensity(lp::PEtab.PEtabPigeonReference, x) = lp(x)

function PEtab.get_correction(logreference::PEtab.PEtabPigeonReference, x)
return Bijectors.logabsdetjac(logreference.inference_info.inv_bijectors, x)
end
4 changes: 2 additions & 2 deletions ext/PEtabLogDensityProblemsExtension/Prior.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function compute_prior(x_inference::AbstractVector{T},
inference_info::PEtab.InferenceInfo)::T where {T <: Real}
function PEtab.compute_prior(x_inference,
inference_info::PEtab.InferenceInfo)
logpdf_prior = 0.0
x_prior = inference_info.inv_bijectors(x_inference)
for (i, prior) in pairs(inference_info.priors)
58 changes: 58 additions & 0 deletions ext/PEtabPigeonsExtension.jl
@@ -0,0 +1,58 @@
module PEtabPigeonsExtension

using ModelingToolkit
using Distributions
using PEtab
using Pigeons
using Bijectors
using LogDensityProblems
using LogDensityProblemsAD
using MCMCChains

function Pigeons.initialization(log_potential::PEtab.PEtabLogDensity, rng, ::Int)
return deepcopy(log_potential.initial_value)
end

function LogDensityProblemsAD.ADgradient(::Symbol, log_potential::PEtab.PEtabLogDensity,
buffers::Pigeons.Augmentation)
Pigeons.BufferedAD(log_potential, buffers)
end

function LogDensityProblems.logdensity_and_gradient(logpotential::Pigeons.BufferedAD{<:PEtabLogDensity},
x)
logdens, grad = logpotential.enclosed.logtarget_gradient(x)
logpotential.buffer .= grad
return logdens, logpotential.buffer
end

function Pigeons.sample_iid!(log_prior::PEtab.PEtabPigeonReference, replica, shared)
@unpack state, rng = replica
sample_iid!(state, rng, log_prior.inference_info)
end

function sample_iid!(state::AbstractVector, rng,
inference_info::PEtab.InferenceInfo)::Nothing
for i in eachindex(state)
state[i] = rand(rng, inference_info.priors[i])
end
state .= inference_info.bijectors(state)
return nothing
end

function PEtab.to_chains(res::Pigeons.PT, target::PEtab.PEtabLogDensity;
start_time = nothing, end_time = nothing)
# Dependent on method
inference_info = target.inference_info
out = sample_array(res)[:, 1:(end - 1), :]
for i in 1:size(out)[1]
out[i, :, :] .= inference_info.inv_bijectors(out[i, :, 1])
end
if isnothing(start_time) || isnothing(end_time)
return MCMCChains.Chains(out, inference_info.parameters_id)
else
_chain = MCMCChains.Chains(out, inference_info.parameters_id)
return MCMCChains.setinfo(_chain, (start_time = start_time, stop_time = end_time))
end
end

end
6 changes: 5 additions & 1 deletion src/PEtab.jl
@@ -91,7 +91,7 @@ export PEtabModel, PEtabODEProblem, ODESolver, SteadyStateSolver, PEtabModel,
PEtabODEProblem, remake_PEtab_problem, Fides, PEtabOptimisationResult, IpoptOptions,
IpoptOptimiser, PEtabParameter, PEtabObservable, PEtabMultistartOptimisationResult,
generate_startguesses, get_ps, get_u0, get_odeproblem, get_odesol, PEtabEvent,
PEtabLogDensity, solve_all_conditions, compute_runtime_accuracy
PEtabLogDensity, solve_all_conditions, compute_runtime_accuracy, PEtabPigeonReference

# These are given as extensions, but their docstrings are available in the
# general documentation
@@ -100,7 +100,10 @@ export calibrate_model, calibrate_model_multistart, run_PEtab_select
function get_obs_comparison_plots end
export get_obs_comparison_plots

# Functions that only appear in extension
function compute_llh end
function compute_prior end
function get_correction end
function correct_gradient! end

"""
@@ -142,6 +145,7 @@ if !isdefined(Base, :get_extension)
include(joinpath(@__DIR__, "..", "ext", "PEtabSciMLSensitivityExtension.jl"))
include(joinpath(@__DIR__, "..", "ext", "PEtabLogDensityProblemsExtension.jl"))
include(joinpath(@__DIR__, "..", "ext", "PEtabPlotsExtension.jl"))
include(joinpath(@__DIR__, "..", "ext", "PEtabPigeonsExtension.jl"))
end

export to_chains, to_prior_scale
30 changes: 30 additions & 0 deletions src/Structs.jl
@@ -927,3 +927,33 @@ struct PEtabLogDensity{T <: InferenceInfo,
initial_value::Vector{T2}
dim::I
end
function (logpotential::PEtab.PEtabLogDensity)(x)
return logpotential.logtarget(x)
end

"""
PEtabPigeonReference(prob::PEtabODEProblem)
Construct a reference distribution (prior) for parallel tempering with Pigeons.jl.
This LogDensityProblem method defines everything needed to perform Bayesian inference
with libraries such as `AdvancedHMC.jl` (which includes algorithms like NUTS, used by `Turing.jl`),
`AdaptiveMCMC.jl` for adaptive Markov Chain Monte Carlo methods, and `Pigeons.jl` for parallel tempering
methods. For examples on how to perform inference, see the documentation.
"""
struct PEtabPigeonReference{T <: InferenceInfo,
I <: Integer}
inference_info::T
dim::I
end
function PEtabPigeonReference(petab_problem::PEtabODEProblem)
inference_info = InferenceInfo(petab_problem)
return PEtabPigeonReference(inference_info, petab_problem.n_parameters_esimtate)
end
function (logreference::PEtab.PEtabPigeonReference)(x)
# The correction must be applied in the prior, as in the prior/reference the value is not
# scaled by a temperature
logprior = PEtab.compute_prior(x, logreference.inference_info)
correction = get_correction(logreference, x)
return logprior + correction
end
33 changes: 33 additions & 0 deletions test/Inference.jl
@@ -1,6 +1,7 @@
using PEtab, OrdinaryDiffEq, ModelingToolkit, Distributions, Random, DataFrames, CSV
using Bijectors, LogDensityProblems, LogDensityProblemsAD, MCMCChains
using AdaptiveMCMC, AdvancedHMC
using Pigeons

function get_reference_stats(path_data)
# Reference chain 10000 samples via Turing of HMC
@@ -97,6 +98,38 @@ end
@test reference_stats.nt.std[2] ≈ adaptive_stats.nt.std[2] atol=1e-2
@test reference_stats.nt.std[3] ≈ adaptive_stats.nt.std[3] atol=1e-2
end

# Parallel tempering (with AutoMALA the slowest)
# Setup with Pigeons.jl
Random.seed!(123)
log_potential = PEtabLogDensity(petab_problem)
log_potential.initial_value .= xinference
reference_potential = PEtabPigeonReference(petab_problem)
pt = pigeons(target = log_potential,
reference = reference_potential,
n_rounds=10,
record = [traces; record_default()])
pt_chain = to_chains(pt, log_potential)
pt_stats = summarystats(pt_chain)
@testset "Parallel tempering" begin
@test reference_stats.nt.mean[1] ≈ pt_stats.nt.mean[1] atol=2e-1
@test reference_stats.nt.mean[2] ≈ pt_stats.nt.mean[2] atol=2e-1
@test reference_stats.nt.mean[3] ≈ pt_stats.nt.mean[3] atol=1e-2
@test reference_stats.nt.std[1] ≈ pt_stats.nt.std[1] atol=5e-1
@test reference_stats.nt.std[2] ≈ pt_stats.nt.std[2] atol=1e-2
@test reference_stats.nt.std[3] ≈ pt_stats.nt.std[3] atol=1e-2
end

# Check AutoMALA runs
Random.seed!(123)
pt = pigeons(target = log_potential,
reference = reference_potential,
explorer=AutoMALA(),
n_rounds=6,
record = [traces; record_default()])
pt_chain = to_chains(pt, log_potential)
pt_stats = summarystats(pt_chain)
@test reference_stats.nt.mean[3] ≈ pt_stats.nt.mean[3] atol=1e-2
end

@testset "Check inference transformed parameters" begin
