Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify API for GP approximations #361

Merged
merged 17 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export rand!,
mean_vector,
marginals,
logpdf,
approx_log_evidence,
elbo,
dtc,
posterior,
Expand Down
27 changes: 27 additions & 0 deletions src/abstract_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,30 @@ for (m, f) in [
)
end
end

"""
    approx_log_evidence(approx::<Approximation>, lfx::LatentFiniteGP, ys)

Compute an approximation to the log of the marginal likelihood (also known as
"evidence") under the given `approx`imation to the posterior. The return value
of `approx_log_evidence` can be used to optimise the hyperparameters of `lfx`.

Concrete methods are provided per approximation type (e.g. `VFE`, `DTC`), here
or in packages implementing additional approximations.
"""
function approx_log_evidence end

"""
    posterior(fx::FiniteGP, y::AbstractVector{<:Real})
    posterior(approx::<Approximation>, fx::FiniteGP, y::AbstractVector{<:Real})
    posterior(approx::<Approximation>, lfx::LatentFiniteGP, y::AbstractVector)

Construct the posterior distribution over the latent Gaussian process (`fx.f`
or `lfx.fx.f`), given the observations `y` corresponding to the process's
finite projection (`fx` or `lfx`).

In the two-argument form, this describes exact GP regression with `y` observed
under a Gaussian likelihood, and returns a `PosteriorGP`.

In the three-argument form, the first argument specifies the approximation to
be used (e.g. `VFE`, or approximations defined in other packages such as
ApproximateGPs.jl), and returns an `ApproxPosteriorGP`.
"""
function posterior end
2 changes: 1 addition & 1 deletion src/exact_gpr_posterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
"""
posterior(fx::FiniteGP, y::AbstractVector{<:Real})

Construct the posterior distribution over `fx.f` given observations `y` at `fx.x` made under
noise `fx.Σy`. This is another `AbstractGP` object. See chapter 2 of [1] for a recap on
exact inference in GPs. This posterior process has mean function
```julia
Expand Down
63 changes: 40 additions & 23 deletions src/sparse_approximations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ struct VFE{Tfz<:FiniteGP}
fz::Tfz
end

const DTC = VFE
"""
    DTC(fz::FiniteGP)

The Deterministic Training Conditional (DTC) sparse approximation, with
pseudo-points determined by the finite projection `fz`.

Similar to `VFE` — it shares the same `posterior` construction — but uses a
different objective for `approx_log_evidence`.
"""
struct DTC{Tfz<:FiniteGP}
    fz::Tfz
end

struct ApproxPosteriorGP{Tapprox,Tprior,Tdata} <: AbstractGP
approx::Tapprox
Expand Down Expand Up @@ -48,7 +55,7 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
function posterior(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
function posterior(vfe::Union{VFE,DTC}, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f

U_y = _cholesky(_symmetric(fx.Σy)).U
Expand All @@ -69,7 +76,7 @@ end

"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE},
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
fx::FiniteGP,
y::AbstractVector{<:Real}
)
Expand All @@ -78,7 +85,7 @@ Update the `ApproxPosteriorGP` given a new set of observations. Here, we retain
set of pseudo-points.
"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE}, fx::FiniteGP, y::AbstractVector{<:Real}
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fx::FiniteGP, y::AbstractVector{<:Real}
st-- marked this conversation as resolved.
Show resolved Hide resolved
)
@assert f_post_approx.prior === fx.f

Expand Down Expand Up @@ -111,14 +118,14 @@ end

"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE},
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
z::FiniteGP,
)

Update the `ApproxPosteriorGP` given a new set of pseudo-points to append to the existing
set of pseudo-points.
"""
function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
function update_posterior(f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fz::FiniteGP)
@assert f_post_approx.prior === fz.f

z_old = inducing_points(f_post_approx)
Expand Down Expand Up @@ -161,48 +168,52 @@ function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
x=f_post_approx.data.x,
Σy=f_post_approx.data.Σy,
)
return ApproxPosteriorGP(VFE(fz_new), f_post_approx.prior, cache)
return ApproxPosteriorGP(_update_approx(f_post_approx.approx, fz_new), f_post_approx.prior, cache)
st-- marked this conversation as resolved.
Show resolved Hide resolved
end

# Rebuild an approximation object of the same kind (VFE or DTC) around a new
# set of pseudo-points `fz_new`; dispatch selects the matching constructor.
function _update_approx(::VFE, fz_new::FiniteGP)
    return VFE(fz_new)
end
function _update_approx(::DTC, fz_new::FiniteGP)
    return DTC(fz_new)
end

# AbstractGP interface implementation.

# AbstractGP interface: posterior mean at inputs `x` under the VFE/DTC
# approximation. Prior mean plus the cross-covariance with the pseudo-points,
# weighted by the precomputed vector `f.data.α`.
# NOTE: the scraped diff showed both the old `{<:VFE}` and new
# `{<:Union{VFE,DTC}}` signatures; only the updated one is kept.
function Statistics.mean(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
    return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α
end

# AbstractGP interface: posterior covariance at inputs `x` under the VFE/DTC
# approximation.
function Statistics.cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
    # Whitened cross-covariance between the pseudo-points and `x`
    # (U is stored in `f.data` — presumably a Cholesky factor; see posterior).
    A = f.data.U' \ cov(f.prior, inducing_points(f), x)
    # Prior covariance, minus the low-rank correction A'A, plus the
    # contribution of the approximate posterior over the pseudo-points (Λ_ε).
    return cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
end

# AbstractGP interface: marginal posterior variances at inputs `x` — the
# diagonal-only analogue of `cov(f, x)`, avoiding the full covariance matrix.
function Statistics.var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
    A = f.data.U' \ cov(f.prior, inducing_points(f), x)
    return var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
end

# AbstractGP interface: posterior cross-covariance between input sets `x` and
# `y` under the VFE/DTC approximation. (Leaked review-UI text and the stale
# `{<:VFE}` signature from the diff scrape have been removed.)
function Statistics.cov(
    f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector, y::AbstractVector
)
    # Whitened cross-covariances of the pseudo-points with each input set.
    A_zx = f.data.U' \ cov(f.prior, inducing_points(f), x)
    A_zy = f.data.U' \ cov(f.prior, inducing_points(f), y)
    return cov(f.prior, x, y) - A_zx'A_zy + Xt_invA_Y(A_zx, f.data.Λ_ε, A_zy)
end

# AbstractGP interface: jointly compute the posterior mean and covariance at
# `x`, sharing the whitened cross-covariance `A` between the two.
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
    A = f.data.U' \ cov(f.prior, inducing_points(f), x)
    m_post = mean(f.prior, x) + A' * f.data.m_ε
    C_post = cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
    return m_post, C_post
end

# AbstractGP interface: jointly compute the posterior mean and marginal
# variances at `x`, sharing the whitened cross-covariance `A` between the two.
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
    A = f.data.U' \ cov(f.prior, inducing_points(f), x)
    m_post = mean(f.prior, x) + A' * f.data.m_ε
    c_post = var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
    return m_post, c_post
end

# Inputs of the pseudo-points (inducing inputs) used by the approximation.
inducing_points(f::ApproxPosteriorGP{<:Union{VFE,DTC}}) = f.approx.fz.x

"""
approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})

The Titsias Evidence Lower BOund (ELBO) [1]. `y` are observations of `fx`, and `v.z`
Expand All @@ -228,14 +239,18 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
# VFE objective (Titsias ELBO): the DTC objective minus the trace correction
# term, which penalises the mismatch between the prior and its low-rank
# approximation at the training inputs.
function approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
    # The pseudo-points must belong to the same latent process as `fx`.
    @assert vfe.fz.f === fx.f
    dtc_objective, A = _compute_intermediates(fx, y, vfe.fz)
    return dtc_objective - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
end

# Backwards-compatible aliases: `elbo` now forwards to `approx_log_evidence`.
# TODO: emit a deprecation warning (e.g. via `Base.depwarn`) before removal.
elbo(vfe::VFE, fx, y) = approx_log_evidence(vfe, fx, y)
elbo(dtc::DTC, fx, y) = approx_log_evidence(dtc, fx, y)

"""
dtc(v::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
dtc(d::DTC, fx::FiniteGP, y::AbstractVector{<:Real}) [deprecated]

The Deterministic Training Conditional (DTC) [1]. `y` are observations of `fx`, and `v.z`
are inducing points.
Expand All @@ -260,13 +275,15 @@ true
Sparse Gaussian Process Regression". In: Proceedings of the Ninth International Workshop on
Artificial Intelligence and Statistics. 2003
"""
# DTC objective: the `_compute_intermediates` value alone, without the VFE
# trace correction. (The stale pre-PR `dtc(vfe::VFE, …)` definition shown by
# the diff scrape has been removed; only the updated method is kept.)
function approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
    # The pseudo-points must belong to the same latent process as `fx`.
    @assert dtc.fz.f === fx.f
    dtc_objective, _ = _compute_intermediates(fx, y, dtc.fz)
    return dtc_objective
end

# Factor out computations common to the `elbo` and `dtc`.
dtc(vfe::Union{VFE,DTC}, fx, y) = approx_log_evidence(vfe, fx, y) # deprecated alias; TODO: emit a deprecation warning before removal

# Factor out computations of `approx_log_evidence` common to `VFE` and `DTC`
function _compute_intermediates(fx::FiniteGP, y::AbstractVector{<:Real}, fz::FiniteGP)
length(fx) == length(y) || throw(
DimensionMismatch(
Expand Down