Skip to content

Commit

Permalink
Merge pull request #89 from JuliaTrustworthyAI/88-fix-docstrings
Browse files Browse the repository at this point in the history
88 fix docstrings
  • Loading branch information
pasq-cat authored May 29, 2024
2 parents 82fe1d8 + 00c6932 commit 00bdfc2
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

*Note*: We try to adhere to these practices as of version [v0.2.1].

## Version [0.2.1] - 2024-05-29

### Changed

- Improved the docstring for the `predict` and `glm_predictive_distribution` methods. [#88]

### Added

- Added `probit` helper function to compute probit approximation for classification. [#88]
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LaplaceRedux"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
authors = ["Patrick Altmeyer"]
version = "0.2.0"
version = "0.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
59 changes: 44 additions & 15 deletions src/baselaplace/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,32 @@ end
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
Computes the linearized GLM predictive.
# Arguments
- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
# Returns
- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.
# Examples
```julia-repl
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
"""
function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X)
= reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
return fμ, fvar
Expand All @@ -24,14 +47,27 @@ end
Computes predictions from Bayesian neural network.
# Arguments
- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks.
# Returns
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.
# Examples
```julia-repl
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn)
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
predict(la, hcat(x...))
```
Expand All @@ -51,8 +87,7 @@ function predict(

# Probit approximation
if link_approx == :probit
κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar)
z = κ .*
z = probit(fμ, fvar)
end

if link_approx == :plugin
Expand All @@ -75,20 +110,14 @@ function predict(
end

"""
predict(la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true)
Compute predictive posteriors for a batch of inputs.
probit(fμ::AbstractArray, fvar::AbstractArray)
Predicts on a matrix of inputs. Note, input is assumed to be batched only if it is a matrix.
If the input dimensionality of the model is 1 (a vector), one should still prepare a 1×B matrix batch as input.
Compute the probit approximation of the predictive distribution.
"""
function predict(
la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true
)
return stack([
predict(la, X[:, i]; link_approx=link_approx, predict_proba=predict_proba) for
i in 1:size(X, 2)
])
function probit(fμ::AbstractArray, fvar::AbstractArray)
κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar)
z = κ .*
return z
end

"""
Expand Down

0 comments on commit 00bdfc2

Please sign in to comment.