fix bug in reparameterization with Bijectors.TransformedDistribution #52

Merged on Jan 3, 2024 (2 commits; the diff below shows the changes from the first commit).
ext/AdvancedVIBijectorsExt.jl (29 additions, 19 deletions)

```diff
@@ -11,35 +11,45 @@ else
     using ..Random
 end
 
-function AdvancedVI.reparam_with_entropy(
-    rng      ::Random.AbstractRNG,
-    q        ::Bijectors.TransformedDistribution,
-    q_stop   ::Bijectors.TransformedDistribution,
-    n_samples::Int,
-    ent_est  ::AdvancedVI.AbstractEntropyEstimator
-)
-    transform    = q.transform
-    q_base       = q.dist
-    q_base_stop  = q_stop.dist
-    base_samples = rand(rng, q_base, n_samples)
-    it           = AdvancedVI.eachsample(base_samples)
-    sample_init  = first(it)
+function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
+    unconst_iter = AdvancedVI.eachsample(unconst_samples)
+    unconst_init = first(unconst_iter)
+
+    samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)
 
     samples_and_logjac = mapreduce(
         AdvancedVI.catsamples_and_acc,
-        Iterators.drop(it, 1);
-        init=with_logabsdet_jacobian(transform, sample_init)
+        Iterators.drop(unconst_iter, 1);
+        init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
     ) do sample
         with_logabsdet_jacobian(transform, sample)
     end
     samples = first(samples_and_logjac)
-    logjac  = last(samples_and_logjac)
+    logjac  = last(samples_and_logjac)/n_samples
+    samples, logjac
+end
 
-    entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
-        ent_est, base_samples, q_base, q_base_stop
-    )
+function AdvancedVI.reparam_with_entropy(
+    rng      ::Random.AbstractRNG,
+    q        ::Bijectors.TransformedDistribution,
+    q_stop   ::Bijectors.TransformedDistribution,
+    n_samples::Int,
+    ent_est  ::AdvancedVI.AbstractEntropyEstimator
+)
+    transform      = q.transform
+    q_unconst      = q.dist
+    q_unconst_stop = q_stop.dist
 
-    entropy = entropy_base + logjac/n_samples
+    # Draw samples and compute entropy of the unconstrained distribution
+    unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
+        rng, q_unconst, q_unconst_stop, n_samples, ent_est
+    )
+
+    # Apply bijector to samples while estimating its jacobian
+    samples, logjac = transform_samples_with_jacobian(
+        unconst_samples, transform, n_samples
+    )
+    entropy = unconst_entropy + logjac
     samples, entropy
 end
 end
```
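The fix is organized around the change-of-variables identity for differential entropy: if z = T(x) with x drawn from the base distribution, then H[q_z] = H[q_base] + E[log |det J_T(x)|]. The rewritten method obtains the base entropy through a recursive `reparam_with_entropy` call and then adds the averaged log-Jacobian from `transform_samples_with_jacobian`. Below is a minimal standalone sketch of that identity on an exp-transformed Normal (my illustration, not part of the PR; it assumes the Distributions.jl and Bijectors.jl exports used elsewhere in the diff):

```julia
using Bijectors, Distributions, Random, Statistics

rng    = Random.MersenneTwister(1)
q_base = Normal(0.5, 2.0)                   # unconstrained base distribution
b⁻¹    = inverse(bijector(LogNormal()))     # exp: ℝ → ℝ₊
q_z    = Bijectors.transformed(q_base, b⁻¹) # same law as LogNormal(0.5, 2.0)

rand(rng, q_z)  # draws live in (0, ∞)

# Monte Carlo estimate of E[log |det J(x)|] over base samples
xs     = rand(rng, q_base, 100_000)
logjac = mean(last(with_logabsdet_jacobian(b⁻¹, x)) for x in xs)

# H[q_z] = H[q_base] + E[log |det J|]; compare against the closed form
isapprox(entropy(q_base) + logjac, entropy(LogNormal(0.5, 2.0)); atol=0.05)  # true
```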
src/utils.jl (4 additions, 0 deletions)

```diff
@@ -34,3 +34,7 @@ function catsamples_and_acc(
     return (x, ∑y)
 end
 
+function samples_expand_dim(x::AbstractVector)
+    reshape(x, (:,1))
+end
+
```
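Why the helper matters: `catsamples_and_acc` (above) appends each new sample to a running block of samples while summing a scalar accumulator, so the `mapreduce` in the extension must seed it with something that already carries a sample axis. `samples_expand_dim` reshapes the first length-`d` draw into a `d×1` matrix for exactly that purpose. A small shape sketch (my illustration; the `hcat`-style growth is an assumption about how `catsamples_and_acc` concatenates):

```julia
samples_expand_dim(x::AbstractVector) = reshape(x, (:, 1))

x1  = [1.0, 2.0, 3.0]              # first sample: a length-3 vector
acc = samples_expand_dim(x1)       # 3×1 matrix: the accumulator gains a sample axis
acc = hcat(acc, [4.0, 5.0, 6.0])   # each later sample appends one column
size(acc)                          # (3, 2)
```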

test/inference/repgradelbo_distributionsad.jl (6 additions, 5 deletions)

```diff
@@ -9,9 +9,10 @@ using Test
     (modelname, modelconstr) ∈ Dict(
         :Normal=> normal_meanfield,
     ),
-    (objname, objective) ∈ Dict(
-        :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-        :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+    n_montecarlo in [1, 10],
+    (objname, objective) in Dict(
+        :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+        :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
     ),
     (adbackname, adbackend) ∈ Dict(
         :ForwarDiff => AutoForwardDiff(),
@@ -33,7 +34,7 @@
     q0 = TuringDiagMvNormal(μ0, diag(L0))
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+        Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -45,7 +46,7 @@
         L = sqrt(cov(q))
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
```
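The new `n_montecarlo in [1, 10]` axis runs every objective and AD backend with both a single Monte Carlo sample and ten, so regressions in the single-sample path (presumably where the reported shape bug surfaced) are caught directly. A toy sketch of the same multi-axis `@testset`-for pattern, with illustrative values in place of the repo's fixtures:

```julia
using Test

@testset "toy matrix n=$(n) $(name)" for
    n in [1, 10],
    (name, obj) in Dict(
        :plain => (n_samples = n,),
        :stl   => (n_samples = n, stl = true),
    )

    # later iterators may reference earlier loop variables, as in the repo's tests
    @test obj.n_samples == n
end
```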
test/inference/repgradelbo_locationscale.jl (9 additions, 8 deletions)

```diff
@@ -5,16 +5,17 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
             :Normal=> normal_meanfield,
             :Normal=> normal_fullrank,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
             :ForwarDiff => AutoForwardDiff(),
             :ReverseDiff => AutoReverseDiff(),
             :Zygote => AutoZygote(),
@@ -37,7 +38,7 @@
     end
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
+        Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -49,7 +50,7 @@
         L = q.scale
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
```
test/inference/repgradelbo_locationscale_bijectors.jl (9 additions, 8 deletions)

```diff
@@ -5,15 +5,16 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale Bijectors" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
             :NormalLogNormalMeanField => normallognormal_meanfield,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
             :ForwarDiff => AutoForwardDiff(),
             :ReverseDiff => AutoReverseDiff(),
             #:Zygote => AutoZygote(),
@@ -42,7 +43,7 @@
     q0_z = Bijectors.transformed(q0_η, b⁻¹)
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+        Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0_z, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -54,7 +55,7 @@
         L = q.dist.scale
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
```
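As in this test file, wrapping an unconstrained family with an inverse bijector produces a `Bijectors.TransformedDistribution`, the exact type the PR's `reparam_with_entropy` method dispatches on; its `rand` and `logpdf` handle the change of variables automatically. A minimal univariate sketch (illustrative only; the test itself derives `b⁻¹` from the model and uses a location-scale family):

```julia
using Bijectors, Distributions

q0_η = Normal(0.0, 1.0)                  # unconstrained base family
b⁻¹  = inverse(bijector(LogNormal()))    # exp: ℝ → ℝ₊
q0_z = Bijectors.transformed(q0_η, b⁻¹)  # constrained family, analogous to q0_z above

z = rand(q0_z)                           # draws live in (0, ∞)
# logpdf includes the log|det J| correction, matching LogNormal exactly:
logpdf(q0_z, z) ≈ logpdf(LogNormal(0.0, 1.0), z)  # true
```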