Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic rewrite of the package 2023 edition Part I: ADVI #49

Merged
merged 213 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from 208 commits
Commits
Show all changes
213 commits
Select commit Hold shift + click to select a range
b49cf3e
refactor ADVI, change gradient operation interface
Red-Portal Mar 14, 2023
88e0b79
remove unused file, remove unused dependency
Red-Portal Mar 14, 2023
c2fb3f8
fix ADVI elbo computation more efficiently
Red-Portal Mar 15, 2023
83161fd
fix missing entropy regularization term
Red-Portal Mar 15, 2023
efa8106
add LogDensityProblem interface
Red-Portal Mar 18, 2023
4ae2fbf
refactor use bijectors directly instead of transformed distributions
Red-Portal Mar 18, 2023
2bf2a42
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal Jun 7, 2023
1cadb51
fix type restrictions
Red-Portal Jun 7, 2023
3474e8d
remove unused file
Red-Portal Jun 7, 2023
03a2767
fix use of with_logabsdet_jacobian
Red-Portal Jun 8, 2023
09c44fb
restructure project; move the main VI routine to its own file
Red-Portal Jun 8, 2023
b7407ce
remove redundant import
Red-Portal Jun 8, 2023
4040149
restructure project into more modular objective estimators
Red-Portal Jun 8, 2023
2a4514e
migrate to AbstractDifferentiation
Red-Portal Jun 9, 2023
93a16d8
add location scale pre-packaged variational family, add functors
Red-Portal Jun 9, 2023
2b6e9eb
Revert "migrate to AbstractDifferentiation"
Red-Portal Jun 10, 2023
1bfec36
fix use optimized MvNormal specialization, add logpdf for Loc.Scale.
Red-Portal Jun 10, 2023
1003606
remove dead code
Red-Portal Jun 10, 2023
60a9987
fix location-scale logpdf
Red-Portal Jun 10, 2023
cd84f02
add sticking-the-landing (STL) estimator
Red-Portal Jun 10, 2023
768641b
migrate to Optimisers.jl
Red-Portal Jun 10, 2023
ca02fa3
remove execution time measurement (replace later with somethin else)
Red-Portal Jun 10, 2023
a48377f
fix use multiple dispatch for deciding whether to stop entropy grad.
Red-Portal Jun 12, 2023
0b40ccf
add termination decision, callback arguments
Red-Portal Jun 12, 2023
21db3fb
add Base.show to modules
Red-Portal Jun 12, 2023
25c51b4
add interface calling `restructure`, rename rebuild -> restructure
Red-Portal Jun 12, 2023
fc20046
add estimator state interface, add control variate interface to ADVI
Red-Portal Jun 12, 2023
6faa807
fix `show(advi)` to show control variate
Red-Portal Jun 12, 2023
7095d27
fix simplify `show(advi.control_variate)`
Red-Portal Jun 12, 2023
9169ae2
fix type piracy by wrapping location-scale bijected distribution
Red-Portal Jun 12, 2023
3db7301
remove old AdvancedVI custom optimizers
Red-Portal Jun 26, 2023
e6a082a
fix Location Scale to not depend on Bijectors
Red-Portal Jun 26, 2023
a034ebd
fix RNG namespace
Red-Portal Jul 12, 2023
e19abd3
fix location scale logpdf bug
Red-Portal Jul 13, 2023
680c186
add Accessors dependency
Red-Portal Jul 13, 2023
6c3efa8
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal Jul 13, 2023
4c6cabf
add location scale, autodiff tests
Red-Portal Jul 13, 2023
06db2f0
add Accessors import statement
Red-Portal Jul 13, 2023
12de2bd
remove optimiser tests
Red-Portal Jul 13, 2023
bbb2cc6
refactor slightly generalize the distribution tests for the future
Red-Portal Jul 13, 2023
1974846
migrate to SimpleUnPack, migrate to ADTypes
Red-Portal Jul 13, 2023
19c62c8
rename vi.jl to optimize.jl
Red-Portal Jul 13, 2023
63da51d
fix estimate_gradient to use adtypes
Red-Portal Jul 13, 2023
65ab473
add exact inference tests
Red-Portal Jul 13, 2023
3e5a452
remove Turing dependency in tests
Red-Portal Jul 13, 2023
3117cec
remove unused projection
Red-Portal Jul 14, 2023
b1ca9cf
remove redundant `ADVIEnergy` object (now baked into `ADVI`)
Red-Portal Jul 14, 2023
fcbb729
add more tests, fix rng seed for tests
Red-Portal Jul 14, 2023
0f6f6a4
add more tests, fix seed for tests
Red-Portal Jul 14, 2023
f5f5863
fix non-determinism bug
Red-Portal Jul 14, 2023
ade0d10
fix test hyperparameters so that tests pass, minor cleanups
Red-Portal Jul 14, 2023
0caf7a9
fix minor reorganization
Red-Portal Jul 14, 2023
5658cbf
add missing files
Red-Portal Jul 14, 2023
c712a97
fix add missing file, rename adbackend argument
Red-Portal Jul 14, 2023
bee839d
fix errors
Red-Portal Jul 14, 2023
913911e
rename test suite
Red-Portal Jul 14, 2023
d50cabb
refactor renamed arguments for ADVI to be shorter
Red-Portal Jul 15, 2023
b134f70
fix compile error in advi test
Red-Portal Jul 15, 2023
a6ba379
add initial doc
Red-Portal Jul 15, 2023
619b1c0
remove unused epsilon argument in location scale
Red-Portal Jul 15, 2023
f1c02f0
add project file for documenter
Red-Portal Jul 15, 2023
b0f259a
refactor STL gradient calculation to use multiple dispatch
Red-Portal Jul 16, 2023
b72c258
fix type bugs, relax test threshold for the exact inference tests
Red-Portal Jul 16, 2023
a8df9eb
refactor derivative utils to match NormalizingFlows.jl with extras
Red-Portal Aug 13, 2023
e8db6a7
add documentation, refactor optimize
Red-Portal Aug 13, 2023
65a2b37
fix bug missing extension
Red-Portal Aug 13, 2023
1a02051
remove tracker from tests
Red-Portal Aug 13, 2023
d8b5ea5
remove export for internal derivative utils
Red-Portal Aug 13, 2023
818bc2c
fix test errors, old interface
Red-Portal Aug 13, 2023
215abf3
fix wrong derivative interface, add documentation
Red-Portal Aug 13, 2023
88ad768
update documentation
Red-Portal Aug 13, 2023
e66935b
add doc build CI
Red-Portal Aug 13, 2023
9f1c647
remove convergence criterion for now
Red-Portal Aug 13, 2023
c8b3ee3
remove outdated export
Red-Portal Aug 13, 2023
afda1a1
update documentation
Red-Portal Aug 13, 2023
0d37ace
update documentation
Red-Portal Aug 13, 2023
b8b113d
update documentation
Red-Portal Aug 13, 2023
b78e713
fix type error in test
Red-Portal Aug 16, 2023
a0564b5
remove default ADType argument
Red-Portal Aug 16, 2023
3795d1e
update README
Red-Portal Aug 17, 2023
28a35bc
update make getting started example actually run Julia
Red-Portal Aug 17, 2023
620b38e
fix remove Float32 tests for inference tests
Red-Portal Aug 17, 2023
fa53398
update version
Red-Portal Aug 17, 2023
e909f41
add documentation publishing url
Red-Portal Aug 17, 2023
43f5b75
fix wrong uuid for ForwardDiff
Red-Portal Aug 17, 2023
468d5ca
Update CI.yml
yebai Aug 17, 2023
c07a511
refactor use `sum` and `mean` instead of abusing `mapreduce`
Red-Portal Aug 17, 2023
8256df1
Merge branch 'rewriting_advancedvi' of github.com:Red-Portal/Advanced…
Red-Portal Aug 17, 2023
13a8a44
remove tests for `FullMonteCarlo`
Red-Portal Aug 17, 2023
aadf8d3
add tests for the `optimize` interface
Red-Portal Aug 18, 2023
8c4e13d
fix turn off Zygote tests for now
Red-Portal Aug 18, 2023
0b708e6
remove unused function
Red-Portal Aug 18, 2023
be61acd
refactor change bijector field name, simplify STL estimator
Red-Portal Aug 18, 2023
fb519a5
update documentation
Red-Portal Aug 18, 2023
8682fd9
update STL documentation
Red-Portal Aug 18, 2023
9a16ee1
update STL documentation
Red-Portal Aug 18, 2023
fc74afa
update location scale documentation
Red-Portal Aug 18, 2023
4be30a1
fix README
Red-Portal Aug 19, 2023
c58309d
fix math in README
Red-Portal Aug 19, 2023
5b5bd3e
add gradient to arguments of callback!, remove `gradient_norm` info
Red-Portal Aug 20, 2023
967021d
fix math in README.md
Red-Portal Aug 21, 2023
4dab522
fix type constraint in `ZygoteExt`
Red-Portal Aug 21, 2023
8ab2f19
fix import of `Random`
Red-Portal Aug 21, 2023
83dec9f
refactor `__init__()`
Red-Portal Aug 21, 2023
a3e563c
fix type constraint in definition of `value_and_gradient!`
Red-Portal Aug 21, 2023
5553bb9
refactor `ZygoteExt`; use `only` instead of `first`
Red-Portal Aug 21, 2023
79b4557
refactor type constraint in `ReverseDiffExt`
Red-Portal Aug 21, 2023
656b44b
refactor remove outdated debug mode macro
Red-Portal Aug 21, 2023
c794063
fix remove outdated DEBUG mechanism
Red-Portal Aug 21, 2023
0c5cc1c
fix LaTeX in README: `operatorname` is currently broken
Red-Portal Aug 21, 2023
29d7d27
remove `SimpleUnPack` dependency
Red-Portal Aug 22, 2023
75eef44
fix LaTeX in docs and README
Red-Portal Aug 22, 2023
40574f4
add warning about forward-mode AD when using `LocationScale`
Red-Portal Aug 22, 2023
8738256
fix documentation
Red-Portal Aug 22, 2023
8173744
fix remove reamining use of `@unpack`
Red-Portal Aug 22, 2023
e0548ae
Revert "remove `SimpleUnPack` dependency"
Red-Portal Aug 22, 2023
6ab95a0
Revert "fix remove reamining use of `@unpack`"
Red-Portal Aug 22, 2023
f0ec242
fix documentation for `optimize`
Red-Portal Aug 22, 2023
1d4c1b6
add specializations of `Optimise.destructure` for mean-field
Red-Portal Aug 22, 2023
231835f
add test for `Optimisers.destructure` specializations
Red-Portal Aug 22, 2023
ea2d426
add specialization of `rand` for meanfield resulting in faster AD
Red-Portal Aug 22, 2023
3033d75
add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian`
Red-Portal Aug 22, 2023
0cc36c0
update documentation
Red-Portal Aug 22, 2023
b7d3471
fix type instability, bug in argument check in `LocationScale`
Red-Portal Aug 22, 2023
df50e83
add missing import bug
Red-Portal Aug 22, 2023
ae3e9b0
refactor test, fix type bug in tests for `LocationScale`
Red-Portal Aug 22, 2023
e4002cf
add missing compat entries
Red-Portal Aug 22, 2023
8c82569
fix missing package import in test
Red-Portal Aug 22, 2023
c2e7517
add additional tests for sampling `LocationScale`
Red-Portal Aug 22, 2023
3a6f8bf
fix bug in batch in-place `rand!` for `LocationScale`
Red-Portal Aug 22, 2023
b78ef4b
fix bug in inference test initialization
Red-Portal Aug 22, 2023
a1f7e98
add missing file
Red-Portal Aug 23, 2023
8b783ec
fix remove use of for 1.6
Red-Portal Aug 23, 2023
12cd9f2
refactor adjust inference test hyperparameters to be more robust
Red-Portal Aug 23, 2023
837c729
refactor `optimize` to return `obj_state`, add warm start kwargs
Red-Portal Aug 24, 2023
95629a5
refactor make tests more robust, reduce amount of tests
Red-Portal Aug 24, 2023
0b4b865
fix remove a cholesky in test model
Red-Portal Aug 24, 2023
b49f4eb
fix compat bounds, remove unused package
Red-Portal Aug 24, 2023
947a070
bump compat for ADTypes 0.2
Red-Portal Aug 24, 2023
a9b3f48
fix broken LaTeX in README
Red-Portal Aug 24, 2023
54826eb
remove redundant use of PDMats in docs
Red-Portal Aug 24, 2023
1d1c8ff
fix use `Cholesky` signature supported in 1.6
Red-Portal Aug 24, 2023
7bac95b
revert custom variational families and docs
Red-Portal Aug 24, 2023
d2ae29f
remove doc action for now
Red-Portal Aug 24, 2023
fb84e3d
revert README for now
Red-Portal Aug 24, 2023
0575b23
refactor remove redundant `rng` argument to `ADVI`, improve docs
Red-Portal Aug 25, 2023
ecc5242
fix wrong whitespace in tests
Red-Portal Aug 25, 2023
1cff3df
refactor `estimate_gradient` to `estimate_gradient!`, add docs
Red-Portal Aug 25, 2023
54acd8a
refactor add default `init` impl, update docs
Red-Portal Aug 25, 2023
61a2272
merge (manually) commit ff32ac642d6aa3a08d371ed895aa6b4026b06b92
Red-Portal Aug 26, 2023
c56d29e
fix test for new interface, change interface for `optimize`, `advi`
Red-Portal Aug 26, 2023
913b469
fix integer subtype error in documentation of advi
Red-Portal Sep 1, 2023
385a653
fix remove redundant argument for `advi`
Red-Portal Sep 1, 2023
4716b62
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal Sep 1, 2023
c9df90e
remove manifest
Red-Portal Sep 1, 2023
19d11d1
refactor remove imports and use fully qualified names
Red-Portal Sep 1, 2023
59bd4f8
update documentation for `AbstractVariationalObjective`
Red-Portal Sep 1, 2023
dedc5cf
refactor use StableRNG instead of Random123
Red-Portal Sep 1, 2023
e35dc67
refactor migrate to Test, re-enable x86 tests
Red-Portal Sep 1, 2023
6413183
refactor remove inner constructor for `ADVI`
Red-Portal Sep 5, 2023
1668bae
fix swap `export`s and `include`s
Red-Portal Sep 7, 2023
a8f1254
fix doscs for `ADVI`
Red-Portal Sep 7, 2023
7b368c1
fix use `FillArrays` in the test problems
Red-Portal Sep 7, 2023
f216b37
fix `optimize` docs
Red-Portal Sep 7, 2023
9e0338d
fix improve argument names and docs for `optimize`
Red-Portal Sep 7, 2023
d6fcaf6
fix tests to match new interface of `optimize`
Red-Portal Sep 7, 2023
5799f1e
refactor move utility functions to new file
Red-Portal Sep 7, 2023
2229d61
fix docs for `optimize`
Red-Portal Sep 7, 2023
bc48e14
refactor advi internal objective
Red-Portal Sep 7, 2023
9949a04
refactor move `rng` to be an optional first argument
Red-Portal Sep 7, 2023
81010cd
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal Sep 7, 2023
92cf354
fix docs for optimize
Red-Portal Sep 7, 2023
d75fd3c
add compat bounds to test dependencies
Red-Portal Sep 8, 2023
faa91ce
update compat bound for `Optimisers`
Red-Portal Sep 8, 2023
6dc0bb7
fix test compat
Red-Portal Sep 8, 2023
e941ad4
fix remove `!` in callback
Red-Portal Oct 23, 2023
15e0553
fix rng argument position in `advi`
Red-Portal Oct 23, 2023
a643cf2
fix callback signature in `optimize`
Red-Portal Oct 23, 2023
ffa69a3
refactor reorganize test files and naming
Red-Portal Oct 23, 2023
d5026e1
fix simplify description for `optimize`
Red-Portal Oct 23, 2023
764406b
fix remove redundant `Nothing` type signature for `maybe_init`
Red-Portal Oct 23, 2023
65006cb
fix remove "internal use" warning in documentation
Red-Portal Oct 23, 2023
b23a610
refactor change `estimate_gradient!` signature to be type stable
Red-Portal Oct 23, 2023
6c6634f
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal Oct 23, 2023
9c242a5
add signature for computing `advi` over a fixed set of samples
Red-Portal Oct 23, 2023
e014863
fix change test tolerance
Red-Portal Oct 23, 2023
71184fa
fix update documentation for `estimate_gradient!`
Red-Portal Oct 23, 2023
9f6d663
refactor remove type constraint for variational parameters
Red-Portal Oct 23, 2023
a673520
fix remove dead code
Red-Portal Oct 23, 2023
a3f9886
add compat entry for stdlib
Red-Portal Oct 23, 2023
7a92708
add compat entry for stdlib in `test/`
Red-Portal Oct 23, 2023
5dd434d
fix rng argument position in tests
Red-Portal Oct 24, 2023
a764d9b
refactor change name of inference test
Red-Portal Oct 24, 2023
8af8a5f
fix documentation for `optimize`
Red-Portal Oct 24, 2023
5f1fb52
refactor rewrite the documentation for the global interfaces
Red-Portal Oct 24, 2023
2491c64
fix compat error
Red-Portal Oct 24, 2023
92d1489
fix documentation for `optimize` to be single line
Red-Portal Oct 24, 2023
a03e955
refactor remove begin end for one-liner
Red-Portal Oct 24, 2023
ff83c03
refactor create unified interface for estimating objectives
Red-Portal Nov 10, 2023
aecc655
refactor unify interface for entropy estimator, fix advi docs
Red-Portal Nov 10, 2023
a8d532a
fix STL estimator to use manually stopped gradients instead
Red-Portal Nov 10, 2023
65e9b12
add inference test for a non-bijector model
Red-Portal Nov 10, 2023
3691f16
refactor add indirections to handle STL and bijectors in ADVI
Red-Portal Nov 11, 2023
a063583
refactor split inference tests for advi+distributionsad
Red-Portal Nov 11, 2023
316b629
refactor rename advi to repgradelbo and not use bijectors directly
Red-Portal Nov 21, 2023
13b2088
fix documentation for estimate_objective
Red-Portal Nov 23, 2023
b0e1be1
refactor add indirection in repgradelbo for interacting with `q`
Red-Portal Nov 23, 2023
7361ed4
add TransformedDistribution support as extension
Red-Portal Nov 23, 2023
d2e7614
Update src/objectives/elbo/repgradelbo.jl
Red-Portal Dec 8, 2023
77686b5
fix docstring for entropy estimator
Red-Portal Dec 8, 2023
8461b43
fix `reparam_with_entropy` specialization for bijectors
Red-Portal Dec 8, 2023
8c559e3
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal Dec 8, 2023
bd925cc
enable Zygote for non-bijector tests
Red-Portal Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,37 +1,71 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.2.4"
version = "0.3.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[weakdeps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVIZygoteExt = "Zygote"
AdvancedVIBijectorsExt = "Bijectors"

[compat]
Bijectors = "0.11, 0.12, 0.13"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
ADTypes = "0.1, 0.2"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10.3"
ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
Enzyme = "0.11.7"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
Zygote = "0.6.63"
julia = "1.6"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
37 changes: 37 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

module AdvancedVIBijectorsExt

if isdefined(Base, :get_extension)
using AdvancedVI
using Bijectors
using Random
else
using ..AdvancedVI
using ..Bijectors
using ..Random
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
n_samples::Int,
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
q ::Bijectors.TransformedDistribution,
q_stop ::Bijectors.TransformedDistribution,
ent_est
)
transform = q.transform
q_base = q.dist
q_base_stop = q_stop.dist
∑logabsdetjac = 0.0
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
base_samples = rand(rng, q_base, n_samples)
samples = mapreduce(hcat, eachcol(base_samples)) do base_sample
sample, logabsdetjac = with_logabsdet_jacobian(transform, base_sample)
∑logabsdetjac += logabsdetjac
sample
end
entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
ent_est, base_samples, q_base, q_base_stop
)
entropy = entropy_base + ∑logabsdetjac/n_samples
samples, entropy
end
end
26 changes: 26 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..Enzyme
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
function AdvancedVI.value_and_gradient!(

Check warning on line 15 in ext/AdvancedVIEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AdvancedVIEnzymeExt.jl#L15

Added line #L15 was not covered by tests
ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
) where {T<:Real}
y = f(θ)
DiffResults.value!(out, y)
∇θ = DiffResults.gradient(out)
fill!(∇θ, zero(T))
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
return out

Check warning on line 23 in ext/AdvancedVIEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AdvancedVIEnzymeExt.jl#L18-L23

Added lines #L18 - L23 were not covered by tests
end

end
29 changes: 29 additions & 0 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
using ForwardDiff
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..ForwardDiff
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
) where {T<:Real}
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
ForwardDiff.GradientConfig(f, θ)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))

Check warning on line 23 in ext/AdvancedVIForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AdvancedVIForwardDiffExt.jl#L23

Added line #L23 was not covered by tests
end
ForwardDiff.gradient!(out, f, θ, config)
return out
end

end
23 changes: 23 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ReverseDiff
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ReverseDiff
end

# ReverseDiff without compiled tape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not handle compiled tape?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into it.

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
tp = ReverseDiff.GradientTape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
return out
end

end
24 changes: 24 additions & 0 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using Zygote
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..Zygote
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, θ)
∇θ = back(one(y))
DiffResults.value!(out, y)
DiffResults.gradient!(out, only(∇θ))
return out
end

end
Loading
Loading