Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ScoreELBO objective #72

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions bench/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function Bijectors.bijector(model::NormalLogNormal)
[1:1, 2:1+length(μ_y)])
end

function normallognormal(; fptype, adtype, family, objective, kwargs...)
function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
n_dims = 10
μ_x = fptype(5.0)
σ_x = fptype(0.3)
Expand All @@ -43,8 +43,7 @@ function normallognormal(; fptype, adtype, family, objective, kwargs...)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)

max_iter = 10^3
AdvancedVI.optimize(
return AdvancedVI.optimize(
model,
obj,
q_transformed,
Expand Down
4 changes: 4 additions & 0 deletions bench/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ function variational_objective(objective::Symbol; kwargs...)
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo])
elseif objective == :RepGradELBOSTL
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo], entropy=StickingTheLandingEntropy())
elseif objective == :ScoreGradELBO
AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo])
elseif objective == :ScoreGradELBOSTL
AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo], entropy=StickingTheLandingEntropy())
end
end
3 changes: 2 additions & 1 deletion ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ function AdvancedVI.reparam_with_entropy(
q_stop::Bijectors.TransformedDistribution,
n_samples::Int,
ent_est::AdvancedVI.AbstractEntropyEstimator,
obj::AdvancedVI.AbstractVariationalObjective,
)
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the unconstrained distribution
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
rng, q_unconst, q_unconst_stop, n_samples, ent_est
rng, q_unconst, q_unconst_stop, n_samples, ent_est, obj
)

# Apply bijector to samples while estimating its jacobian
Expand Down
43 changes: 42 additions & 1 deletion src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,51 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
"""
    sample_from_q(obj, rng, q, q_stop, n_samples)

Draw `n_samples` from `q`.

# Arguments
- `obj::AbstractVariationalObjective`: Variational objective.
- `rng::Random.AbstractRNG`: Random number generator.
- `q`: Variational approximation.
- `q_stop`: Same as `q`, but held constant during differentiation.
- `n_samples::Int`: Number of Monte Carlo samples.

# Returns
- `samples`: Monte Carlo samples. How they are drawn — and whether gradients may
  propagate through them — depends on the objective `obj` (e.g. `RepGradELBO`
  draws reparameterized samples from `q`, while `ScoreGradELBO` draws
  gradient-stopped samples from `q_stop`). Their support matches that of the
  target distribution.
"""
function sample_from_q end

"""
    reparam_with_entropy(rng, q, q_stop, n_samples, ent_est, obj)

Draw `n_samples` from `q` and compute its entropy.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q`: Variational approximation.
- `q_stop`: Same as `q`, but held constant during differentiation. Should only be used for computing the entropy.
- `n_samples::Int`: Number of Monte Carlo samples.
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
- `obj`: The variational objective. (Determines how samples are drawn; see `sample_from_q`.)

# Returns
- `samples`: Monte Carlo samples drawn via `sample_from_q(obj, ...)`. Their support matches that of the target distribution.
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
"""
function reparam_with_entropy end

export
RepGradELBO,
ScoreGradELBO,
ClosedFormEntropy,
StickingTheLandingEntropy,
MonteCarloEntropy

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")
include("objectives/elbo/scoregradelbo.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian
Expand Down
14 changes: 14 additions & 0 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ function estimate_entropy(
-logpdf(q, mc_sample)
end
end

# Estimate the entropy of `q`, swapping in the gradient-stopped `q_stop` when
# the estimator requests it (e.g. sticking-the-landing).
function estimate_entropy_maybe_stl(ent_est::AbstractEntropyEstimator, samples, q, q_stop)
    q_for_entropy = maybe_stop_entropy_score(ent_est, q, q_stop)
    return estimate_entropy(ent_est, samples, q_for_entropy)
end
arnauqb marked this conversation as resolved.
Show resolved Hide resolved

# Generic implementation: delegate sampling to the objective, then estimate the
# entropy of `q` on those draws.
function reparam_with_entropy(
    rng::Random.AbstractRNG,
    q,
    q_stop,
    n_samples::Int,
    ent_est::AbstractEntropyEstimator,
    obj::AbstractVariationalObjective,
)
    draws = sample_from_q(obj, rng, q, q_stop, n_samples)
    ent = estimate_entropy_maybe_stl(ent_est, draws, q, q_stop)
    return (draws, ent)
end

35 changes: 4 additions & 31 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,45 +45,18 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

# Monte Carlo estimate of the energy term E_q[log π(z)] over the given samples.
function estimate_energy_with_samples(prob, samples)
    logdensities = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    return mean(logdensities)
end

"""
reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)

Draw `n_samples` from `q` and compute its entropy.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q`: Variational approximation.
- `q_stop`: Same as `q`, but held constant during differentiation. Should only be used for computing the entropy.
- `n_samples::Int`: Number of Monte Carlo samples
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)

# Returns
- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution.
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
"""
function reparam_with_entropy(
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
return samples, entropy
# Reparameterization gradient: draw samples directly from `q`, so gradients can
# flow through the sampling path into the variational parameters.
sample_from_q(::RepGradELBO, rng, q, q_stop, n_samples) = rand(rng, q, n_samples)

# Unbiased ELBO estimate for `RepGradELBO`: energy + entropy over `n_samples`
# reparameterized draws. (Removed a stale duplicate `reparam_with_entropy` call
# that lacked the trailing `obj` argument required by the new 6-arg interface.)
function estimate_objective(
    rng::Random.AbstractRNG, obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples
)
    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy, obj)
    energy = estimate_energy_with_samples(prob, samples)
    return energy + entropy
end
Expand All @@ -95,7 +68,7 @@ end
function estimate_repgradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, adtype, restructure, q_stop = aux
q = restructure_ad_forward(adtype, restructure, params′)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy, obj)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
return -elbo
Expand Down
141 changes: 141 additions & 0 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
    ScoreGradELBO(n_samples; kwargs...)

Evidence lower-bound objective computed with score function gradients.
```math
\\begin{aligned}
\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
&=
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
    \\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
\\right]
+ \\mathbb{H}\\left(q_{\\lambda}\\right),
\\end{aligned}
```

To reduce the variance of the gradient estimator, we use a baseline ``\\beta`` computed
from a running average of the previous ELBO values, and subtract it from the
log-density before taking the score-weighted expectation:

```math
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
    \\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\log \\pi\\left(z\\right) - \\beta\\right)
\\right]
```

# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.

# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)

# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
- The target distribution and the variational approximation have the same support.

Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
struct ScoreGradELBO{EntropyEst <: AbstractEntropyEstimator} <:
       AdvancedVI.AbstractVariationalObjective
    entropy::EntropyEst            # estimator for the entropy term of the ELBO
    n_samples::Int                 # Monte Carlo samples per gradient estimate
    baseline_window_size::Int      # window of history entries used for the baseline
    baseline_history::Vector{Float64}  # running record of past ELBO estimates
end

# Keyword-argument convenience constructor; maps to the inner field order
# (entropy, n_samples, baseline_window_size, baseline_history).
function ScoreGradELBO(
    n_samples::Int;
    entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
    baseline_window_size::Int=10,
    baseline_history::Vector{Float64}=Float64[],
)
    return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
end


# Compact single-line display of the objective's configuration.
function Base.show(io::IO, obj::ScoreGradELBO)
    return print(
        io,
        "ScoreGradELBO(entropy=",
        obj.entropy,
        ", n_samples=",
        obj.n_samples,
        ", baseline_window_size=",
        obj.baseline_window_size,
        ")",
    )
end

# Score-function gradient: sample from the gradient-stopped `q_stop` so that no
# gradient flows through the sampling path; gradients enter only via `logpdf(q, ·)`.
sample_from_q(::ScoreGradELBO, rng, q, q_stop, n_samples) = rand(rng, q_stop, n_samples)

# Control-variate baseline: the mean of (at most) the last `window_size`
# recorded ELBO values, or a neutral 1.0 when no history exists yet.
function compute_control_variate_baseline(history, window_size)
    isempty(history) && return 1.0
    # Fix: the original used `length(history) - window_size`, which averaged
    # `window_size + 1` entries — an off-by-one against the documented window.
    first_idx = max(1, length(history) - window_size + 1)
    return mean(history[first_idx:end])
end
Comment on lines +71 to +77
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally think that we should make a whole new set of interface for control variates so that people could easily mix and match various control variates. But for now, I think this could be included in the PR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed!


# Score-gradient surrogate for the energy term. The returned VALUE equals the
# Monte Carlo mean of log π(z); the extra `(score_grad - score_grad_stop)` pair
# is numerically zero but injects the score-function gradient, since only
# `samples_logprob` (not `samples_logprob_stop`) is differentiable.
function estimate_energy_with_samples(
    prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
    logdens = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
    centered = logdens .- baseline
    score_grad = mean(samples_logprob .* centered)
    score_grad_stop = mean(samples_logprob_stop .* centered)
    return mean(logdens) + (score_grad - score_grad_stop)
end

# Assemble the surrogate ELBO: score-gradient energy term plus entropy. The
# `q` log-densities carry gradients; the `q_stop` ones are detached.
function compute_elbo(q, q_stop, samples_stop, entropy, problem, baseline)
    draws = AdvancedVI.eachsample(samples_stop)
    lp = logpdf.(Ref(q), draws)
    lp_stop = logpdf.(Ref(q_stop), draws)
    energy = estimate_energy_with_samples(problem, samples_stop, lp, lp_stop, baseline)
    return energy + entropy
end

# Unbiased ELBO estimate (no gradients): mean log-density plus entropy.
function estimate_objective(
    rng::Random.AbstractRNG,
    obj::ScoreGradELBO,
    q,
    prob;
    n_samples::Int=obj.n_samples,
)
    # Fix: honor the `n_samples` keyword; previously `obj.n_samples` was
    # hard-coded here, silently ignoring the caller's override.
    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy, obj)
    energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    return mean(energy) + entropy
end

# Convenience overload using the task-default RNG.
function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
    return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

# AD forward pass for the score-gradient ELBO: returns the NEGATIVE surrogate
# ELBO so the optimizer minimizes. `q_stop` (in `aux`) is held constant during
# differentiation; gradients reach the parameters only through `logpdf(q, ·)`.
function estimate_scoregradelbo_ad_forward(params′, aux)
    @unpack rng, obj, problem, restructure, q_stop = aux
    # Baseline from recent ELBO history (control variate for variance reduction).
    baseline = compute_control_variate_baseline(obj.baseline_history, obj.baseline_window_size)
    q = restructure(params′)
    # NOTE(review): the RepGradELBO forward pass reconstructs `q` via
    # `restructure_ad_forward(adtype, restructure, params′)`; this one calls
    # `restructure` directly — confirm this works for all AD backends.
    samples_stop, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy, obj)
    elbo = compute_elbo(q, q_stop, samples_stop, entropy, problem, baseline)
    return -elbo
end


# One gradient step for ScoreGradELBO: run AD through the surrogate objective,
# then record the ELBO estimate so later steps can use it as a baseline.
function AdvancedVI.estimate_gradient!(
    rng::Random.AbstractRNG,
    obj::ScoreGradELBO,
    adtype::ADTypes.AbstractADType,
    out::DiffResults.MutableDiffResult,
    prob,
    params,
    restructure,
    state,
)
    # Detached copy of the current approximation for stop-gradient terms.
    q_stop = restructure(params)
    aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
    AdvancedVI.value_and_gradient!(
        adtype, estimate_scoregradelbo_ad_forward, params, aux, out,
    )
    elbo = -DiffResults.value(out)
    # Feed the control-variate baseline; note the history grows unboundedly.
    push!(obj.baseline_history, elbo)
    return out, nothing, (elbo=elbo,)
end
1 change: 0 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ function optimize(

for t in 1:max_iter
stat = (iteration=t,)

grad_buf, obj_st, stat′ = estimate_gradient!(
rng,
objective,
Expand Down
6 changes: 3 additions & 3 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ if @isdefined(Tapir)
AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_distributionsad[:Enzyme] = AutoEnzyme()
end
#if @isdefined(Enzyme)
# AD_distributionsad[:Enzyme] = AutoEnzyme()
#end

@testset "inference RepGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
Expand Down
6 changes: 3 additions & 3 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ if @isdefined(Tapir)
AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_locationscale[:Enzyme] = AutoEnzyme()
end
#if @isdefined(Enzyme)
# AD_locationscale[:Enzyme] = AutoEnzyme()
#end

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
Expand Down
6 changes: 3 additions & 3 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ if @isdefined(Tapir)
AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_locationscale_bijectors[:Enzyme] = AutoEnzyme()
end
#if @isdefined(Enzyme)
# AD_locationscale_bijectors[:Enzyme] = AutoEnzyme()
#end

@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
Expand Down
Loading