diff --git a/src/distributions/wishart.jl b/src/distributions/wishart.jl
index 32ad5aad7..090490713 100644
--- a/src/distributions/wishart.jl
+++ b/src/distributions/wishart.jl
@@ -49,14 +49,16 @@ end
 
 function Distributions.mean(::typeof(logdet), distribution::WishartMessage)
     d = size(distribution, 1)
-    ν, invS = (distribution.ν, distributions.invS)
-    return mapreduce(i -> digamma((ν + 1 - i) / 2), +, 1:d) + d * log(2) - logdet(invS)
+    ν, invS = (distribution.ν, distribution.invS)
+    T = promote_type(typeof(ν), eltype(invS))
+    return mapreduce(i -> digamma((ν + 1 - i) / 2), +, 1:d) + d * log(convert(T, 2)) - logdet(invS)
 end
 
 function Distributions.mean(::typeof(logdet), distribution::Wishart)
-    d = size(distribution, 1) 
+    d = size(distribution, 1)
     ν, S = params(distribution)
-    return mapreduce(i -> digamma((ν + 1 - i) / 2), +, 1:d) + d * log(2) + logdet(S)
+    T = promote_type(typeof(ν), eltype(S))
+    return mapreduce(i -> digamma((ν + 1 - i) / 2), +, 1:d) + d * log(convert(T, 2)) + logdet(S)
 end
 
 function Distributions.mean(::typeof(inv), distribution::WishartDistributionsFamily)
diff --git a/src/nodes/mv_normal_mean_precision.jl b/src/nodes/mv_normal_mean_precision.jl
index b04a8198f..7303dd2a9 100644
--- a/src/nodes/mv_normal_mean_precision.jl
+++ b/src/nodes/mv_normal_mean_precision.jl
@@ -34,14 +34,16 @@ end
     m_out, v_out = mean_cov(q_out)
     df_Λ, S_Λ = params(q_Λ)
 
     # prevent allocation of mean matrix
-    result = zero(promote_type(eltype(m_mean), eltype(m_out), eltype(S_Λ)))
+    T = promote_type(eltype(m_mean), eltype(m_out), eltype(S_Λ))
+    result = zero(T)
     @inbounds for k1 in 1:dim, k2 in 1:dim
         # optimize trace operation (indices can be interchanged because of symmetry)
         result += S_Λ[k1, k2] * (v_out[k1, k2] + v_mean[k1, k2] + (m_out[k2] - m_mean[k2]) * (m_out[k1] - m_mean[k1]))
     end
+
     result *= df_Λ
-    result += dim * log2π
+    result += dim * convert(T, log2π)
     result -= mean(logdet, q_Λ)
     result /= 2
 
@@ -56,8 +58,10 @@ end
     m, V = mean_cov(q_out_μ)
     m_Λ = mean(q_Λ)
 
-    result = zero(promote_type(eltype(m), eltype(m_Λ)))
-    result += dim * log2π
+    T = promote_type(eltype(m), eltype(m_Λ))
+
+    result = zero(T)
+    result += dim * convert(T, log2π)
     result -= mean(logdet, q_Λ)
     @inbounds for k1 in 1:dim, k2 in 1:dim
         # optimize trace operation (indices can be interchanged because of symmetry)
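Note on the promotion changes above: a bare `log(2)` or `log2π` is a Float64, so it would silently promote Float32 or dual-number parameters to Float64; converting through the promoted element type `T` keeps the energy in the caller's precision. A minimal sketch of the same idea (`mean_logdet_sketch` is a hypothetical stand-in, not the package function):

    using LinearAlgebra, SpecialFunctions

    # E[logdet W] for W ~ Wishart(ν, S), keeping the caller's element type
    function mean_logdet_sketch(ν::Real, S::AbstractMatrix)
        d = size(S, 1)
        T = promote_type(typeof(ν), eltype(S))
        return mapreduce(i -> digamma((ν + 1 - i) / 2), +, 1:d) + d * log(convert(T, 2)) + logdet(S)
    end

    mean_logdet_sketch(3.0, [2.0 0.0; 0.0 1.0])   # Float64
    mean_logdet_sketch(3.0f0, Float32[2 0; 0 1])  # stays Float32 thanks to convert(T, 2)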
diff --git a/src/nodes/normal_mixture.jl b/src/nodes/normal_mixture.jl
index 6a86dc3f5..48bbdf291 100644
--- a/src/nodes/normal_mixture.jl
+++ b/src/nodes/normal_mixture.jl
@@ -139,48 +139,19 @@ end
 
 # FreeEnergy related functions
 
-@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::ManyOf{N, UnivariateGaussianDistributionsFamily}, q_p::ManyOf{N, GammaDistributionsFamily}) where {N} = begin
+@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::ManyOf{N, Any}, q_p::ManyOf{N, Any}) where {N} = begin
     z_bar = probvec(q_switch)
     return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
+        return avg_energy_nm(variate_form(q_out), q_out, q_m, q_p, z_bar, i)
     end
 end
 
-@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::NTuple{N, MultivariateGaussianDistributionsFamily}, q_p::NTuple{N, Wishart}) where {N} = begin
-    z_bar = probvec(q_switch)
-    return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
-    end
+function avg_energy_nm(::Type{Univariate}, q_out, q_m, q_p, z_bar, i)
+    return z_bar[i] * score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
 end
 
-@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::NTuple{N, PointMass{T} where T <: Real}, q_p::NTuple{N, PointMass{T} where T <: Real}) where {N} = begin
-    z_bar = probvec(q_switch)
-    return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
-    end
-end
-
-@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::ManyOf{N, MultivariateGaussianDistributionsFamily}, q_p::ManyOf{N, Wishart}) where {N} = begin
-    z_bar = probvec(q_switch)
-    return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
-    end
-end
-
-@average_energy NormalMixture (q_out::Any, q_switch::Any, q_m::ManyOf{N, PointMass{T} where T <: Real}, q_p::ManyOf{N, PointMass{T} where T <: Real}) where {N} = begin
-    z_bar = probvec(q_switch)
-    return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
-    end
-end
-
-@average_energy NormalMixture (
-    q_out::Any, q_switch::Any, q_m::ManyOf{N, PointMass{T} where T <: AbstractVector}, q_p::ManyOf{N, PointMass{T} where T <: AbstractMatrix}
-) where {N} = begin
-    z_bar = probvec(q_switch)
-    return mapreduce(+, 1:N; init = 0.0) do i
-        return z_bar[i] * score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
-    end
+function avg_energy_nm(::Type{Multivariate}, q_out, q_m, q_p, z_bar, i)
+    return z_bar[i] * score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing)
 end
 
 function score(::Type{T}, ::FactorBoundFreeEnergy, ::Stochastic, node::NormalMixtureNode{N, MeanField}, skip_strategy, scheduler) where {T <: CountingReal, N}
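The refactor above collapses six nearly identical `@average_energy` methods into one generic method that dispatches on the variate form of the marginal. The pattern in isolation (`toy_variate_form` and `branch` are illustrative, not package functions):

    using Distributions: Univariate, Multivariate

    toy_variate_form(::Real) = Univariate
    toy_variate_form(::AbstractVector) = Multivariate

    branch(::Type{Univariate}, x) = "univariate energy for $x"
    branch(::Type{Multivariate}, x) = "multivariate energy for $x"

    branch(toy_variate_form(1.0), 1.0)      # -> univariate branch
    branch(toy_variate_form([1.0]), [1.0])  # -> multivariate branch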
diff --git a/src/rules/normal_mixture/m.jl b/src/rules/normal_mixture/m.jl
index aa7e47342..557a92b57 100644
--- a/src/rules/normal_mixture/m.jl
+++ b/src/rules/normal_mixture/m.jl
@@ -1,15 +1,11 @@
 export rule
 
-@rule NormalMixture((:m, k), Marginalisation) (q_out::Any, q_switch::Any, q_p::GammaDistributionsFamily) = begin
+@rule NormalMixture((:m, k), Marginalisation) (q_out::Any, q_switch::Any, q_p::Any) = begin
     pv = probvec(q_switch)
     T = eltype(pv)
     z_bar = clamp.(pv, tiny, one(T) - tiny)
-    return NormalMeanVariance(mean(q_out), inv(z_bar[k] * mean(q_p)))
-end
 
-@rule NormalMixture((:m, k), Marginalisation) (q_out::Any, q_switch::Any, q_p::Wishart) = begin
-    pv = probvec(q_switch)
-    T = eltype(pv)
-    z_bar = clamp.(pv, tiny, one(T) - tiny)
-    return MvNormalMeanCovariance(mean(q_out), cholinv(z_bar[k] * mean(q_p)))
+    F = variate_form(q_out)
+
+    return convert(promote_variate_type(F, NormalMeanPrecision), mean(q_out), z_bar[k] * mean(q_p))
 end
diff --git a/src/rules/normal_mixture/p.jl b/src/rules/normal_mixture/p.jl
index 9f8166f92..a0fdd7718 100644
--- a/src/rules/normal_mixture/p.jl
+++ b/src/rules/normal_mixture/p.jl
@@ -1,16 +1,17 @@
 export rule
 
-@rule NormalMixture((:p, k), Marginalisation) (q_out::Any, q_switch::Any, q_m::UnivariateNormalDistributionsFamily) = begin
+@rule NormalMixture((:p, k), Marginalisation) (q_out::Any, q_switch::Any, q_m::Any) = begin
     m_mean_k, v_mean_k = mean_cov(q_m)
     m_out, v_out = mean_cov(q_out)
     z_bar = probvec(q_switch)
+
+    return rule_nm_p_k(variate_form(q_out), m_mean_k, v_mean_k, m_out, v_out, z_bar, k)
+end
+
+function rule_nm_p_k(::Type{Univariate}, m_mean_k, v_mean_k, m_out, v_out, z_bar, k)
     return GammaShapeRate(one(eltype(z_bar)) + z_bar[k] / 2, z_bar[k] * (v_out + v_mean_k + abs2(m_out - m_mean_k)) / 2)
 end
 
-@rule NormalMixture((:p, k), Marginalisation) (q_out::Any, q_switch::Any, q_m::MultivariateNormalDistributionsFamily) = begin
-    m_mean_k, v_mean_k = mean_cov(q_m)
-    m_out, v_out = mean_cov(q_out)
-    z_bar = probvec(q_switch)
-    d = length(m_mean_k)
-    return WishartMessage(one(eltype(z_bar)) + z_bar[k] + d, z_bar[k] * (v_out + v_mean_k + (m_out - m_mean_k) * (m_out - m_mean_k)'))
+function rule_nm_p_k(::Type{Multivariate}, m_mean_k, v_mean_k, m_out, v_out, z_bar, k)
+    return WishartMessage(one(eltype(z_bar)) + z_bar[k] + length(m_mean_k), z_bar[k] * (v_out + v_mean_k + (m_out - m_mean_k) * (m_out - m_mean_k)'))
 end
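For the `Univariate` branch of the `:p` rule, the Gamma parameters follow directly from the formula above. Worked by hand for the first case of test/rules/normal_mixture/test_p.jl below (a sketch, not part of the patch; 0.2 is the component-1 weight implied by that test):

    z = 0.2                   # component-1 weight for q_switch = Bernoulli(0.2)
    m_out, v_out = 8.5, 0.5   # q_out = NormalMeanVariance(8.5, 0.5)
    m_m, v_m = 5.0, 2.0       # q_m = NormalMeanVariance(5.0, 2.0)

    shape = 1 + z / 2                                  # 1.1
    rate  = z * (v_out + v_m + abs2(m_out - m_m)) / 2  # 1.475
    # matches the expected GammaShapeRate(1.1, 1.475)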
diff --git a/src/rules/normal_mixture/switch.jl b/src/rules/normal_mixture/switch.jl
index 4179e5f25..9c61580f2 100644
--- a/src/rules/normal_mixture/switch.jl
+++ b/src/rules/normal_mixture/switch.jl
@@ -8,46 +8,17 @@ export rule
 #     return Bernoulli(clamp(softmax((U1, U2))[1], tiny, 1.0 - tiny))
 # end
 
-@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::ManyOf{N, UnivariateNormalDistributionsFamily}, q_p::ManyOf{N, GammaDistributionsFamily}) where {N} = begin
+@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::ManyOf{N, Any}, q_p::ManyOf{N, Any}) where {N} = begin
     U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
+        return rule_nm_switch_k(variate_form(m), q_out, m, p)
     end
     return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
 end
 
-@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::NTuple{N, MultivariateNormalDistributionsFamily}, q_p::NTuple{N, Wishart}) where {N} = begin
-    U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
-    end
-    return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
+function rule_nm_switch_k(::Type{Univariate}, q_out, m, p)
+    return -score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
 end
 
-@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::NTuple{N, PointMass{T} where T <: Real}, q_p::NTuple{N, PointMass{T} where T <: Real}) where {N} = begin
-    U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
-    end
-    return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
-end
-
-@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::ManyOf{N, MultivariateNormalDistributionsFamily}, q_p::ManyOf{N, Wishart}) where {N} = begin
-    U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
-    end
-    return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
-end
-
-@rule NormalMixture{N}(:switch, Marginalisation) (q_out::Any, q_m::ManyOf{N, PointMass{T} where T <: Real}, q_p::ManyOf{N, PointMass{T} where T <: Real}) where {N} = begin
-    U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
-    end
-    return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
-end
-
-@rule NormalMixture{N}(:switch, Marginalisation) (
-    q_out::Any, q_m::ManyOf{N, PointMass{T} where T <: AbstractVector}, q_p::ManyOf{N, PointMass{T} where T <: AbstractMatrix}
-) where {N} = begin
-    U = map(zip(q_m, q_p)) do (m, p)
-        return -score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
-    end
-    return Categorical(clamp!(softmax!(U), tiny, one(eltype(U)) - tiny))
+function rule_nm_switch_k(::Type{Multivariate}, q_out, m, p)
+    return -score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing)
 end
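The `:switch` posterior above is a softmax over per-component negative average energies, clamped away from 0 and 1 so downstream logarithms stay finite. The same computation in plain Julia (`eps()` stands in for ReactiveMP's `tiny`; the energies are illustrative):

    U = [-1.2, -3.4]              # -AverageEnergy per mixture component
    p = exp.(U .- maximum(U))     # numerically stable softmax
    p ./= sum(p)
    clamp!(p, eps(), 1 - eps())   # keep the Categorical away from 0/1
    # p ≈ [0.9002, 0.0998]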
diff --git a/test/nodes/test_normal_mixture.jl b/test/nodes/test_normal_mixture.jl
new file mode 100644
index 000000000..447ee4805
--- /dev/null
+++ b/test/nodes/test_normal_mixture.jl
@@ -0,0 +1,91 @@
+module NodesNormalMixtureTest
+
+using Test
+using ReactiveMP
+using Random
+using Distributions
+
+import ReactiveMP: @test_rules
+import ReactiveMP: WishartMessage, ManyOf
+
+@testset "NormalMixtureNode" begin
+    @testset "AverageEnergy" begin
+        begin
+            q_out = NormalMeanVariance(0.0, 1.0)
+            q_switch = Bernoulli(0.2)
+            q_m = (NormalMeanVariance(1.0, 2.0), NormalMeanVariance(3.0, 4.0))
+            q_p = (GammaShapeRate(2.0, 3.0), GammaShapeRate(4.0, 5.0))
+
+            marginals = (
+                Marginal(q_out, false, false),
+                Marginal(q_switch, false, false),
+                ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)),
+                ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p))
+            )
+
+            ref_val =
+                0.2 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1])), nothing)) +
+                0.8 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2])), nothing))
+            @test score(AverageEnergy(), NormalMixture, Val{(:out, :switch, :m, :p)}, marginals, nothing) ≈ ref_val
+        end
+
+        begin
+            q_out = NormalMeanVariance(1.0, 1.0)
+            q_switch = Bernoulli(0.4)
+            q_m = (NormalMeanVariance(3.0, 2.0), NormalMeanVariance(3.0, 4.0))
+            q_p = (GammaShapeRate(2.0, 3.0), GammaShapeRate(1.0, 5.0))
+
+            marginals = (
+                Marginal(q_out, false, false),
+                Marginal(q_switch, false, false),
+                ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)),
+                ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p))
+            )
+
+            ref_val =
+                0.4 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1])), nothing)) +
+                0.6 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2])), nothing))
+            @test score(AverageEnergy(), NormalMixture, Val{(:out, :switch, :m, :p)}, marginals, nothing) ≈ ref_val
+        end
+
+        begin
+            q_out = NormalMeanVariance(0.0, 1.0)
+            q_switch = Categorical([0.5, 0.5])
+            q_m = (NormalMeanPrecision(1.0, 2.0), NormalMeanPrecision(3.0, 4.0))
+            q_p = (GammaShapeRate(3.0, 3.0), GammaShapeRate(4.0, 5.0))
+
+            marginals = (
+                Marginal(q_out, false, false),
+                Marginal(q_switch, false, false),
+                ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)),
+                ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p))
+            )
+
+            ref_val =
+                0.5 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1])), nothing)) +
+                0.5 * (score(AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2])), nothing))
+            @test score(AverageEnergy(), NormalMixture, Val{(:out, :switch, :m, :p)}, marginals, nothing) ≈ ref_val
+        end
+
+        begin
+            q_out = MvNormalMeanCovariance([0.0], [1.0])
+            q_switch = Categorical([0.5, 0.5])
+            q_m = (MvNormalMeanPrecision([1.0], [2.0]), MvNormalMeanPrecision([3.0], [4.0]))
+            q_p = (WishartMessage(3.0, fill(3.0, 1, 1)), WishartMessage(4.0, fill(5.0, 1, 1)))
+
+            marginals = (
+                Marginal(q_out, false, false),
+                Marginal(q_switch, false, false),
+                ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)),
+                ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p))
+            )
+
+            ref_val =
+                0.5 * (score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1])), nothing)) +
+                0.5 * (score(AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}, map((q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2])), nothing))
+            @test score(AverageEnergy(), NormalMixture, Val{(:out, :switch, :m, :p)}, marginals, nothing) ≈ ref_val
+        end
+    end
+end
+
+end
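Each `ref_val` above spells out, for N = 2, the identity the node computes via `mapreduce`: the mixture average energy is the z̄-weighted sum of per-component energies. A one-line equivalent (illustrative):

    mixture_energy(z_bar, U) = sum(z_bar .* U)  # = z̄₁U₁ + z̄₂U₂ + …
    mixture_energy([0.2, 0.8], [1.5, 2.0])      # -> 1.9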
k=1" begin + @test_rules [with_float_conversions = true, atol = 1e-4] NormalMixture{2}((:m, k = 1), Marginalisation) [ + ( + input = ( + q_out = MvNormalWeightedMeanPrecision([6.75, 12.0], [4.5 -0.75; -0.75 4.5]), q_switch = Categorical([0.5, 0.5]), q_p = Wishart(3.0, [2.0 -0.25; -0.25 1.0]) + ), + output = MvNormalMeanPrecision([2.0, 3.0], [3.0 -0.375; -0.375 1.5]) + ), + ( + input = ( + q_out = MvNormalMeanPrecision([3.75, 10.3125], [5.25 -0.75; -0.75 3.75]), q_switch = Categorical([0.75, 0.25]), q_p = Wishart(3.0, [2.0 -0.25; -0.25 1.0]) + ), + output = MvNormalMeanPrecision([3.75, 10.3125], [4.5 -0.5625; -0.5625 2.25]) + ), + ( + input = (q_out = MvNormalMeanPrecision([0.75, 17.25], [3.0 -0.75; -0.75 6.0]), q_switch = Categorical([1.0, 0.0]), q_p = Wishart(3.0, [2.0 -0.25; -0.25 1.0])), + output = MvNormalMeanPrecision([0.75, 17.25], [6.0 -0.75; -0.75 3.0]) + ) + ] + end + + @testset "Variational : (m_out::UnivariateNormalDistributionsFamily..., m_p::GammaDistributionsFamily...) k=1" begin + @test_rules [with_float_conversions = true] NormalMixture{2}((:m, k = 1), Marginalisation) [ + (input = (q_out = PointMass(8.5), q_switch = Bernoulli(0.2), q_p = GammaShapeRate(1.0, 2.0)), output = NormalMeanPrecision(8.5, 0.1)), + (input = (q_out = NormalWeightedMeanPrecision(3 / 10, 6 / 10), q_switch = Categorical([0.5, 0.5]), q_p = PointMass(1.0)), output = NormalMeanPrecision(0.5, 0.5)) + ] + end +end + +end diff --git a/test/rules/normal_mixture/test_p.jl b/test/rules/normal_mixture/test_p.jl new file mode 100644 index 000000000..90ba2a080 --- /dev/null +++ b/test/rules/normal_mixture/test_p.jl @@ -0,0 +1,43 @@ +module RulesNormalMixturePTest + +using Test +using ReactiveMP +using Random +using Distributions + +import ReactiveMP: @test_rules +import ReactiveMP: WishartMessage + +@testset "rules:NormalMixture:p" begin + @testset "Variational : (m_out::UnivariateNormalDistributionsFamily..., m_μ::UnivariateNormalDistributionsFamily...) k=1" begin + @test_rules [with_float_conversions = true] NormalMixture{2}((:p, k = 1), Marginalisation) [ + (input = (q_out = NormalMeanVariance(8.5, 0.5), q_switch = Bernoulli(0.2), q_m = NormalMeanVariance(5.0, 2.0)), output = GammaShapeRate(1.1, 1.475)), + (input = (q_out = NormalMeanVariance(-3, 2.0), q_switch = Bernoulli(0.5), q_m = NormalMeanVariance(5.0, 2.0)), output = GammaShapeRate(1.25, 17.0)) + ] + end + + @testset "Variational : (m_out::MultivariateNormalDistributionsFamily..., m_μ::MultivariateNormalDistributionsFamily...) 
k=1" begin + @test_rules [with_float_conversions = true, atol = 1e-4] NormalMixture{2}((:p, k = 1), Marginalisation) [ + ( + input = (q_out = MvNormalMeanPrecision([8.5], [0.5]), q_switch = Bernoulli(0.2), q_m = MvNormalMeanPrecision([3.0], [0.1])), + output = WishartMessage(2.2, fill(8.45, 1, 1)) + ), + ( + input = (q_out = MvNormalMeanPrecision([8.5, 5.1], [0.5 0.1; 0.1 4]), q_switch = Bernoulli(0.2), q_m = MvNormalMeanPrecision([3.0, 10], [0.1 0.2; 0.2 -0.3])), + output = WishartMessage(3.2, [9.59487 -5.97148; -5.97148 5.13797]) + ), + ( + input = ( + q_out = MvNormalMeanPrecision([5.0, 8.0], [3 0.5; 0.5 -6]), q_switch = Categorical([0.25, 0.75]), q_m = MvNormalMeanPrecision([2.0, -3.0], [2.1 -1.0; -1.0 3.0]) + ), + output = WishartMessage(3.25, [2.47598 8.29032; 8.29032 30.3902]) + ), + ( + input = (q_out = MvNormalMeanCovariance([-3], [2.0]), q_switch = Bernoulli(0.5), q_m = MvNormalMeanCovariance([5.0], [2.0])), + output = WishartMessage(2.5, fill(34.0, 1, 1)) + ) + ] + end +end + +end diff --git a/test/rules/normal_mixture/test_switch.jl b/test/rules/normal_mixture/test_switch.jl new file mode 100644 index 000000000..0190a466f --- /dev/null +++ b/test/rules/normal_mixture/test_switch.jl @@ -0,0 +1,35 @@ +module RulesNormalMixtureSwitchTest + +using Test +using ReactiveMP +using Random +using Distributions + +import ReactiveMP: @test_rules +import ReactiveMP: WishartMessage + +@testset "rules:NormalMixture:switch" begin + @testset "Variational : (m_out::UnivariateNormalDistributionsFamily..., m_μ::UnivariateNormalDistributionsFamily...) k=1" begin + @test_rules [with_float_conversions = true] NormalMixture{2}(:switch, Marginalisation) [( + input = ( + q_out = NormalMeanVariance(8.5, 0.5), + q_m = ManyOf(NormalMeanVariance(5.0, 2.0), NormalMeanVariance(10.0, 3.0)), + q_p = ManyOf(GammaShapeRate(1.0, 2.0), GammaShapeRate(2.0, 1.0)) + ), + output = Categorical([0.7713458788198754, 0.22865412118012463]) + )] + end + + @testset "Variational : (m_out::MultivariateNormalDistributionsFamily..., m_μ::MultivariateNormalDistributionsFamily...) k=1" begin + @test_rules [with_float_conversions = true, atol = 1e-4] NormalMixture{2}(:switch, Marginalisation) [( + input = ( + q_out = MvNormalMeanCovariance([8.5], [0.5]), + q_m = ManyOf(MvNormalMeanCovariance([5.0], [2.0]), MvNormalMeanCovariance([10.0], [3.0])), + q_p = ManyOf(Wishart(2.0, fill(0.25, 1, 1)), Wishart(4.0, fill(0.5, 1, 1))) + ), + output = Categorical([0.7713458788198754, 0.22865412118012463]) + )] + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 76087e7cd..56a3ab6bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -268,6 +268,7 @@ end addtests(testrunner, "nodes/test_and.jl") addtests(testrunner, "nodes/test_implication.jl") addtests(testrunner, "nodes/test_uniform.jl") + addtests(testrunner, "nodes/test_normal_mixture.jl") addtests(testrunner, "rules/uniform/test_out.jl") @@ -290,6 +291,9 @@ end addtests(testrunner, "rules/bifm_helper/test_out.jl") addtests(testrunner, "rules/normal_mixture/test_out.jl") + addtests(testrunner, "rules/normal_mixture/test_m.jl") + addtests(testrunner, "rules/normal_mixture/test_p.jl") + addtests(testrunner, "rules/normal_mixture/test_switch.jl") addtests(testrunner, "rules/subtraction/test_marginals.jl") addtests(testrunner, "rules/subtraction/test_in1.jl")