Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "low-rank" variational families #76

Merged
merged 46 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
03563ea
rename location scale source file
Red-Portal Aug 3, 2024
5ab7286
revert renaming of location_scale file
Red-Portal Aug 3, 2024
3e0bf3d
add location-low-rank-scale family (except `entropy` and `logpdf`)
Red-Portal Aug 3, 2024
0bd6e5c
add feature complete `MvLocationScaleLowRank` with tests
Red-Portal Aug 5, 2024
34546e1
fix remove misleading comment
Red-Portal Aug 5, 2024
e030f2d
fix add missing test files
Red-Portal Aug 5, 2024
c7f36d6
fix broadcasting error on Julia 1.6
Red-Portal Aug 5, 2024
1bb3e3e
fix bug in sampling from `LocationScaleLowRank`
Red-Portal Aug 7, 2024
ddd2122
fix missing squared bug in `LocationScaleLowRank`
Red-Portal Aug 7, 2024
b24737f
add documentation for low-rank families
Red-Portal Aug 9, 2024
1d56953
add convenience constructors for `LocationScaleLowRank`
Red-Portal Aug 9, 2024
6752c6b
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 10, 2024
52568b5
fix mhauru's suggestions and run formatter
Red-Portal Aug 10, 2024
96eae86
run formatter
Red-Portal Aug 10, 2024
15556da
run formatter
Red-Portal Aug 10, 2024
f796154
fix bugs and improve comments in `MvLocationScale` and lowrank
Red-Portal Aug 11, 2024
6b1699c
promote families.md into a higher category
Red-Portal Aug 11, 2024
5187d76
add test for `MVLocationScale` with non-Gaussian
Red-Portal Aug 14, 2024
8821908
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 27, 2024
6dfc919
tighten compat bound for `Distributions`
Red-Portal Aug 27, 2024
c3ce393
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 4, 2024
5c04d50
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 5, 2024
ba293e5
fix base distribution standardization bug in `LocationScale`
Red-Portal Sep 5, 2024
426d943
fix base distribution standardization bug in `LocationScaleLowRank`
Red-Portal Sep 5, 2024
3cc9e80
format weird indentation in test `for` loops
Red-Portal Sep 5, 2024
0481dda
update docs add example for `LocationScaleLowRank`
Red-Portal Sep 5, 2024
8449402
fix docs warn about divergence when using `MvLocationScaleLowRank`
Red-Portal Sep 6, 2024
ff14c4c
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 9, 2024
e48f231
Merge branch 'master' into lowrank
yebai Sep 10, 2024
aa8feee
Merge branch 'master' into lowrank
yebai Sep 10, 2024
5149869
Merge branch 'master' into lowrank
yebai Sep 10, 2024
e196da6
Update Benchmark.yml
yebai Sep 10, 2024
e4bff67
disable more features for PRs from forks
yebai Sep 10, 2024
894a849
fix `LocationScale` interfaces to only allow univariate base dist
Red-Portal Sep 11, 2024
f1cabba
Merge branch 'lowrank' of github.com:Red-Portal/AdvancedVI.jl into lo…
Red-Portal Sep 11, 2024
ce6793c
fix test comparison operator for families
Red-Portal Sep 11, 2024
71aeb5a
fix test comparison operator for families
Red-Portal Sep 11, 2024
77ace2b
fix test comparison operator for families
Red-Portal Sep 11, 2024
641de39
fix test comparison operator for families
Red-Portal Sep 11, 2024
a58f209
fix test comparison operator for families
Red-Portal Sep 11, 2024
846b259
fix test comparison operator for families
Red-Portal Sep 11, 2024
1116f68
fix test comparison operator for families
Red-Portal Sep 11, 2024
42d730d
fix formatting
Red-Portal Sep 11, 2024
99d08c5
fix formatting
Red-Portal Sep 11, 2024
4a90c5d
fix scale lower bound to `1e-4`
Red-Portal Sep 12, 2024
c41709b
fix docstring for `LowRankGaussian`
Red-Portal Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
using AdvancedVI
using Documenter

DocMeta.setdocmeta!(
AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true
)
DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true)

makedocs(;
modules = [AdvancedVI],
sitename = "AdvancedVI.jl",
repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
pages = ["AdvancedVI" => "index.md",
"General Usage" => "general.md",
"Examples" => "examples.md",
"ELBO Maximization" => [
"Overview" => "elbo/overview.md",
"Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
"Location-Scale Variational Family" => "locscale.md",
]],
modules=[AdvancedVI],
sitename="AdvancedVI.jl",
repo="https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
pages=[
"AdvancedVI" => "index.md",
"General Usage" => "general.md",
"Examples" => "examples.md",
"ELBO Maximization" => [
"Overview" => "elbo/overview.md",
"Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
],
"Variational Families" => "families.md",
],
)

deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
145 changes: 145 additions & 0 deletions docs/src/families.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# [Reparameterizable Variational Families](@id families)

The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path.
We provide multiple pre-packaged variational families that can be readily used.

## The `LocationScale` Family

The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} C u + m;\quad u \sim \varphi
```

where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*.
``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.

The probability density is given by

```math
q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)),
```

the covariance is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(q_{\lambda}) C^{\top}
```

and the entropy is given as

```math
\mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|,
```

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``.
The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.

!!! note

For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned.
Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.

```@docs
MvLocationScale
```

The following are specialized constructors for convenience:

```@docs
FullRankGaussian
MeanFieldGaussian
```

### Gaussian Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

L = LowerTriangular(diagm(ones(2)));
q = FullRankGaussian(μ, L)

L = Diagonal(ones(2));
q = MeanFieldGaussian(μ, L)
```

### Student-$$t$$ Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);
ν = 3;

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, TDist(ν))

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, TDist(ν))
```

### Laplace Variational families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, Laplace())

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, Laplace())
```

## The `LocationScaleLowRank` Family

In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly as they require a small SGD stepsize.
Low-rank variational families can be an effective alternative[^ONS2018].
`LocationScaleLowRank` generally represent any ``d``-dimensional distribution which its sampling path can be represented as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi
```

where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank matrix for the rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*.
``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``.

The covariance of this distribution is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top}
```

and the entropy is given by the matrix determinant lemma as

```math
\mathbb{H}(q_{\lambda})
= \mathbb{H}(\varphi) + \log |\Sigma|
= \mathbb{H}(\varphi) + 2 \log |D| + \log |I + U^{\top} D^{-2} U|,
```

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.

!!! note

`logpdf` for `LocationScaleLowRank` is unfortunately not computationally efficient and has the same time complexity as `LocationScale` with a full-rank scale.

```@docs
MvLocationScaleLowRank
```

The following is a specialized constructor for convenience:

```@docs
LowRankGaussian
```

[^ONS2018]: Ong, V. M. H., Nott, D. J., & Smith, M. S. (2018). Gaussian variational approximation with a factor covariance structure. Journal of Computational and Graphical Statistics, 27(3), 465-478.
80 changes: 0 additions & 80 deletions docs/src/locscale.md

This file was deleted.

4 changes: 4 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")

export MvLocationScaleLowRank, LowRankGaussian

include("families/location_scale_low_rank.jl")

# Optimization Routine

function optimize end
Expand Down
Loading
Loading