Clarify API for GP approximations #361

Merged · 17 commits · May 8, 2023
1 change: 1 addition & 0 deletions src/AbstractGPs.jl
@@ -27,6 +27,7 @@ export rand!,
mean_vector,
marginals,
logpdf,
approx_log_evidence,
elbo,
dtc,
posterior,
27 changes: 27 additions & 0 deletions src/abstract_gp.jl
@@ -85,3 +85,30 @@ for (m, f) in [
)
end
end

"""
approx_log_evidence(approx::<Approximation>, lfx::LatentFiniteGP, ys)

Compute an approximation to the log of the marginal likelihood (also known as
"evidence") under the given `approx`imation to the posterior. The return value
of `approx_log_evidence` can be used to optimise the hyperparameters of `lfx`.
"""
function approx_log_evidence end
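To make the intended use concrete, here is a hypothetical sketch (not part of this diff) of treating the negative approximate log-evidence as a hyperparameter-tuning objective. The data, the lengthscale parameterisation via `with_lengthscale` (re-exported from KernelFunctions.jl), and the `VFE` approximation from `src/sparse_approximations.jl` are all assumptions:

```julia
# Hypothetical usage sketch; assumes AbstractGPs.jl is installed and loaded.
using AbstractGPs

x = randn(50)
y = sin.(x) .+ 0.1 .* randn(50)       # toy observations (assumed)
z = range(-2.0, 2.0; length=8)        # inducing-point locations (assumed)

# Negative approximate log-evidence as a function of the log-lengthscale;
# any optimiser (e.g. from Optim.jl) could minimise this.
function objective(log_ell)
    k = with_lengthscale(SqExponentialKernel(), exp(log_ell))
    f = GP(k)
    return -approx_log_evidence(VFE(f(z)), f(x, 0.1), y)
end

objective(0.0)  # objective value at unit lengthscale
```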

"""
posterior(fx::FiniteGP, y::AbstractVector{<:Real})
posterior(approx::<Approximation>, fx::FiniteGP, y::AbstractVector{<:Real})
posterior(approx::<Approximation>, lfx::LatentFiniteGP, y::AbstractVector)

Construct the posterior distribution over the latent Gaussian process (`fx.f`
or `lfx.fx.f`), given the observations `y` corresponding to the process's
finite projection (`fx` or `lfx`).

In the two-argument form, this describes exact GP regression with `y` observed
under a Gaussian likelihood, and returns a `PosteriorGP`.

In the three-argument form, the first argument specifies the approximation to
be used (e.g. `VFE` or defined in other packages such as ApproximateGPs.jl),
and returns an `ApproxPosteriorGP`.
"""
function posterior end
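A brief usage sketch of the two forms (hypothetical data; assumes AbstractGPs.jl and the `VFE` approximation defined in this package):

```julia
using AbstractGPs

f = GP(Matern52Kernel())       # prior GP
x = randn(50)
fx = f(x, 0.1)                 # finite projection with observation noise 0.1
y = rand(fx)                   # simulated observations

# Two-argument form: exact GP regression, returns a PosteriorGP.
p_exact = posterior(fx, y)

# Three-argument form: pass the approximation, returns an ApproxPosteriorGP.
z = range(-3.0, 3.0; length=10)
p_vfe = posterior(VFE(f(z)), fx, y)
```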
3 changes: 3 additions & 0 deletions src/deprecations.jl
@@ -5,3 +5,6 @@
@deprecate sampleplot!(plt::RecipesBase.AbstractPlot, gp::FiniteGP, n::Int; kwargs...) sampleplot!(
plt, gp; samples=n, kwargs...
)

@deprecate elbo(dtc::DTC, fx, y) approx_log_evidence(dtc, fx, y)
@deprecate dtc(vfe::Union{VFE,DTC}, fx, y) approx_log_evidence(vfe, fx, y)
8 changes: 7 additions & 1 deletion src/exact_gpr_posterior.jl
@@ -3,10 +3,16 @@ struct PosteriorGP{Tprior,Tdata} <: AbstractGP
data::Tdata
end

struct ExactGP end

posterior(::ExactGP, fx::FiniteGP, y::AbstractVector{<:Real}) = posterior(fx, y)
Member: I would have gone the other way around? But maybe that's too complicated...

Member: I find this order of arguments quite intuitive.

Member: Sorry, I meant that the implemented definition would be posterior(::ExactPosterior, gp, y) and that posterior(gp, y) would default to it.

Member Author: Does it make any difference in practice? (I kinda like it as it is, but that might just be status-quo bias too...)

Member: In practice, I think the answer is no. Within the code, there might be something to be said for consistency: every posterior is defined with the 3-argument form, but the exact one gets a special alias.

Member: To be fair, there is only one posterior. Shouldn't VFE and DTC dispatch on approx_posterior?

Member Author: @theogf, what do you mean? There is no approx_posterior method...

Member: Haha, somehow in my mind there was an approx_posterior method. So yeah, then it's back to:

  • Should we have a unique 3-args posterior API (where the 2-args form defaults to ExactInference)?
  • Should we just have 2-args methods and have the GP wrapped (like in ApproximateGPs)?

Member Author: ApproximateGPs has the 3-args posterior though, no wrapping?


approx_log_evidence(::ExactGP, fx::FiniteGP, y::AbstractVector{<:Real}) = logpdf(fx, y)

"""
posterior(fx::FiniteGP, y::AbstractVector{<:Real})

- Construct the posterior distribution over `fx.f` given observations `y` at `x` made under
+ Construct the posterior distribution over `fx.f` given observations `y` at `fx.x` made under
noise `fx.Σy`. This is another `AbstractGP` object. See chapter 2 of [1] for a recap on
exact inference in GPs. This posterior process has mean function
```julia
Expand Down
69 changes: 44 additions & 25 deletions src/sparse_approximations.jl
@@ -13,7 +13,14 @@ struct VFE{Tfz<:FiniteGP}
fz::Tfz
end

- const DTC = VFE
"""
DTC(fz::FiniteGP)

Similar to `VFE`, but uses a different objective for `approx_log_evidence`.
Member Author: Could maybe do with a better docstring, but then it needs to be sorted out more thoroughly anyway (see #309) and I can't think of what it should be right now, so would leave that for some other PR/person/time...
"""
struct DTC{Tfz<:FiniteGP}
fz::Tfz
end
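To illustrate the intended split (a sketch assuming the definitions introduced in this diff): `VFE` and `DTC` share the same posterior computation, but `approx_log_evidence` dispatches to different objectives — the Titsias ELBO versus the DTC objective.

```julia
using AbstractGPs

f = GP(SqExponentialKernel())
x = randn(100)
fx = f(x, 0.1)
y = rand(fx)
fz = f(range(-3.0, 3.0; length=16))  # process at the pseudo-points

# The same posterior is constructed for either approximation...
p_vfe = posterior(VFE(fz), fx, y)
p_dtc = posterior(DTC(fz), fx, y)

# ...but the training objectives differ:
approx_log_evidence(VFE(fz), fx, y)  # Titsias ELBO
approx_log_evidence(DTC(fz), fx, y)  # DTC objective
```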

struct ApproxPosteriorGP{Tapprox,Tprior,Tdata} <: AbstractGP
approx::Tapprox
@@ -48,7 +55,7 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
- function posterior(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
+ function posterior(vfe::Union{VFE,DTC}, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f

U_y = _cholesky(_symmetric(fx.Σy)).U
@@ -69,7 +76,7 @@ end

"""
function update_posterior(
- f_post_approx::ApproxPosteriorGP{<:VFE},
+ f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
fx::FiniteGP,
y::AbstractVector{<:Real}
)
@@ -78,7 +85,9 @@ Update the `ApproxPosteriorGP` given a new set of observations. Here, we retain
set of pseudo-points.
"""
function update_posterior(
- f_post_approx::ApproxPosteriorGP{<:VFE}, fx::FiniteGP, y::AbstractVector{<:Real}
+ f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
+ fx::FiniteGP,
+ y::AbstractVector{<:Real},
)
@assert f_post_approx.prior === fx.f

@@ -111,14 +120,14 @@ end

"""
function update_posterior(
- f_post_approx::ApproxPosteriorGP{<:VFE},
+ f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
z::FiniteGP,
)

Update the `ApproxPosteriorGP` given a new set of pseudo-points to append to the existing
set of pseudo-points.
"""
- function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
+ function update_posterior(f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fz::FiniteGP)
@assert f_post_approx.prior === fz.f

z_old = inducing_points(f_post_approx)
@@ -161,48 +170,56 @@ function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
x=f_post_approx.data.x,
Σy=f_post_approx.data.Σy,
)
- return ApproxPosteriorGP(VFE(fz_new), f_post_approx.prior, cache)
+ return ApproxPosteriorGP(
+     _update_approx(f_post_approx.approx, fz_new), f_post_approx.prior, cache
+ )
end

_update_approx(vfe::VFE, fz_new::FiniteGP) = VFE(fz_new)
_update_approx(dtc::DTC, fz_new::FiniteGP) = DTC(fz_new)
Member Author (on lines +174 to +179): Is this the right way to handle this? If anyone can think of a better approach, please say :)


# AbstractGP interface implementation.

- function Statistics.mean(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
+ function Statistics.mean(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α
end

- function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
+ function Statistics.cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
return cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
end

- function Statistics.var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
+ function Statistics.var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
return var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
end

- function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector, y::AbstractVector)
+ function Statistics.cov(
+     f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector, y::AbstractVector
+ )
A_zx = f.data.U' \ cov(f.prior, inducing_points(f), x)
A_zy = f.data.U' \ cov(f.prior, inducing_points(f), y)
return cov(f.prior, x, y) - A_zx'A_zy + Xt_invA_Y(A_zx, f.data.Λ_ε, A_zy)
end

- function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
+ function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
m_post = mean(f.prior, x) + A' * f.data.m_ε
C_post = cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
return m_post, C_post
end

- function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
+ function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
m_post = mean(f.prior, x) + A' * f.data.m_ε
c_post = var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
return m_post, c_post
end

- inducing_points(f::ApproxPosteriorGP{<:VFE}) = f.approx.fz.x
+ inducing_points(f::ApproxPosteriorGP{<:Union{VFE,DTC}}) = f.approx.fz.x

"""
approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})

The Titsias Evidence Lower BOund (ELBO) [1]. `y` are observations of `fx`, and `v.z`
@@ -228,14 +245,16 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
- function elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
+ function approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f
- _dtc, A = _compute_intermediates(fx, y, vfe.fz)
- return _dtc - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
+ dtc_objective, A = _compute_intermediates(fx, y, vfe.fz)
+ return dtc_objective - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
end

elbo(vfe::VFE, fx, y) = approx_log_evidence(vfe, fx, y)

"""
- dtc(v::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
+ approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})

The Deterministic Training Conditional (DTC) [1]. `y` are observations of `fx`, and `dtc.fz.x`
are the inducing points.
@@ -248,25 +267,25 @@ julia> x = randn(1000);

julia> z = range(-5.0, 5.0; length=256);

- julia> v = VFE(f(z));
+ julia> d = DTC(f(z));

julia> y = rand(f(x, 0.1));

- julia> isapprox(dtc(v, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
+ julia> isapprox(approx_log_evidence(d, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
true
```

[1] - M. Seeger, C. K. I. Williams and N. D. Lawrence. "Fast Forward Selection to Speed Up
Sparse Gaussian Process Regression". In: Proceedings of the Ninth International Workshop on
Artificial Intelligence and Statistics. 2003
"""
- function dtc(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
- @assert vfe.fz.f === fx.f
- _dtc, _ = _compute_intermediates(fx, y, vfe.fz)
- return _dtc
+ function approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
+ @assert dtc.fz.f === fx.f
+ dtc_objective, _ = _compute_intermediates(fx, y, dtc.fz)
+ return dtc_objective
end

- # Factor out computations common to the `elbo` and `dtc`.
+ # Factor out computations of `approx_log_evidence` common to `VFE` and `DTC`
function _compute_intermediates(fx::FiniteGP, y::AbstractVector{<:Real}, fz::FiniteGP)
length(fx) == length(y) || throw(
DimensionMismatch(