From a802272f1913d28fac197e23610c7a9a5a481a5d Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Fri, 3 Feb 2023 14:53:59 +0000 Subject: [PATCH 1/5] CompatHelper: bump compat for Bijectors to 0.11, (keep existing compat) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bb5208bc..26ca4f9d 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] -Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10" +Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" From 9e20c53dfbea938175614d9b6d948b53387ca998 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 14:54:32 +0000 Subject: [PATCH 2/5] Apply suggestions from code review --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 26ca4f9d..edc93b56 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] -Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" +Bijectors = "0.11, 0.12" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" From e3d58d6534ff12a02c9cf213ee24f2942e614578 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 14:59:49 +0000 Subject: [PATCH 3/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index edc93b56..fb1e2025 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ Requires = "0.5, 1.0" StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9, 1" Tracker = "0.2.3" -julia = "1" +julia = "1.6" [extras] Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" From cd2589387bca5419f474d2be8a1e042b57650be1 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 15:02:00 +0000 Subject: [PATCH 4/5] Update CI.yml --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0ac7c326..9731f20c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: - os: macOS-latest arch: x86 include: - - version: '1.0' + - version: '1.6' os: ubuntu-latest arch: x64 - os: ubuntu-latest @@ -60,4 +60,4 @@ jobs: if: matrix.coverage with: github-token: ${{ secrets.GITHUB_TOKEN }} - path-to-lcov: lcov.info \ No newline at end of file + path-to-lcov: lcov.info From c5404a1947ae72d8ee81d6993c106c3bdb8b7a2a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 3 Feb 2023 16:24:22 +0000 Subject: [PATCH 5/5] bump Bijectors.jl compat bounds and replace forward with rand_and_logjac --- Project.toml | 4 ++-- src/AdvancedVI.jl | 1 + src/advi.jl | 6 +++--- src/utils.jl | 16 ++++++++++++++++ 4 files changed, 22 insertions(+), 5 deletions(-) create mode 100644 src/utils.jl diff --git a/Project.toml b/Project.toml index bb5208bc..4a105b6a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.1.6" +version = "0.2.0" [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" @@ -17,7 +17,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] -Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10" +Bijectors = "0.11, 0.12" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 09180388..8c0a2e21 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -19,6 +19,7 @@ end const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") +include("utils.jl") using Requires function __init__() diff --git a/src/advi.jl b/src/advi.jl index a5f880ce..a9318d95 100644 --- a/src/advi.jl +++ b/src/advi.jl @@ -81,8 +81,8 @@ function (elbo::ELBO)( # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - # But our `forward(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - _, z, logjac, _ = forward(rng, q) + # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` + z, logjac = rand_and_logjac(rng, q) res = (logπ(z) + logjac) / num_samples if q isa TransformedDistribution @@ -92,7 +92,7 @@ function (elbo::ELBO)( end for i = 2:num_samples - _, z, logjac, _ = forward(rng, q) + z, logjac = rand_and_logjac(rng, q) res += (logπ(z) + logjac) / num_samples end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..7f593ca8 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,16 @@ +using Distributions + +using Random: Random +using Bijectors: Bijectors + + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) + x = rand(rng, dist) + return x, zero(eltype(x)) +end + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution) + x = rand(rng, dist.dist) + y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) + return y, logjac +end