new glm_predictive_distribution and corresponding changes in LaplaceRedux.predict #99

Merged
merged 13 commits on Jul 18, 2024
13 changes: 8 additions & 5 deletions CHANGELOG.md
@@ -6,24 +6,29 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].

## Version [1.0.0] - 2024-07-01
## Version [1.0.0] - 2024-07-17

### Changed

- Changed the behavior of the `predict` function so that it can now return distributions from the Distributions.jl package as output. [#99]
- Calling a Laplace object on an array, `(la::AbstractLaplace)(X::AbstractArray)` now simply calls the underlying neural network on data. In other words, it returns the generic predictions, not LA predictions. This was implemented to facilitate better interplay with `MLJFlux`. [#39]
- Moving straight to `1.0.0` for the package, because zero major versions cause compat headaches with other packages in the Taija ecosystem. [#39]
- Removed support for `v1.7`, now `v1.9` as lower bound. This is because we are now overloading the `MLJFlux.train` and `MLJFlux.train_epoch` functions, which were added in version `v0.5.0` of that package, which is lower-bounded at `v1.9`. [#39]
- Updated codecov workflow in CI.yml. [#39]
- Fixed test functions. [#39]
- Adapted the `LaplaceClassification` and `LaplaceRegression` structs to use the new `@mlj_model` macro from MLJBase. [#39]
- Changed the `fit!` method arguments. [#39]
- Changed the `predict` functions for both `LaplaceClassification` and `LaplaceRegression`. [#39]

### Removed

- Removed the `shape`, `build` and `clean!` functions. [#39]
- Removed Review dog for code format suggestions. [#39]

### Added

- Added new keyword argument `ret_distr::Bool=false` to `predict` (see the sketch below). [#99]
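
For illustration, a minimal sketch of the new behaviour. The model, data, and fitting call below are assumptions made for this example, not part of the release:

```julia
using Flux, LaplaceRedux

# Hypothetical setup: a small binary classifier wrapped in a Laplace object.
nn = Chain(Dense(2, 16, relu), Dense(16, 1))
la = Laplace(nn; likelihood=:classification)
# fit!(la, data)  # assumed to have been run on some training data

X = rand(Float32, 2, 5)

probs = predict(la, X)                  # pseudo-probabilities, as before
dists = predict(la, X; ret_distr=true)  # Bernoulli distributions (new in this release)
```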

## Version [0.2.3] - 2024-05-31

### Changed
@@ -32,9 +37,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
- Changed `MMI.clean!` to check the value of `link_approx` only when `likelihood` is set to `:classification`.
- The likelihood type in `LaplaceClassification` and `LaplaceRegression` is now set automatically by the inner constructor; the user is no longer required to provide it as a parameter.




## Version [0.2.2] - 2024-05-30

### Changed
@@ -48,6 +50,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
- Added Distributions to the LaplaceRedux dependencies (needed for `MMI.predict(model::LaplaceRegression, fitresult, Xnew)`).



## Version [0.2.1] - 2024-05-29

119 changes: 87 additions & 32 deletions src/baselaplace/predicting.jl
@@ -1,5 +1,28 @@
using Distributions: Distributions
using Distributions: Normal, Bernoulli, Categorical
using Flux
using Statistics: mean, var

"""
has_softmax_or_sigmoid_final_layer(model::Flux.Chain)

Check if the Flux model ends with a sigmoid or a softmax layer.

Input:
- `model`: the Flux Chain object that represents the neural network.
Return:
- `has_finaliser`: `true` if the final layer is a sigmoid or a softmax, `false` otherwise.

"""
function has_softmax_or_sigmoid_final_layer(model::Flux.Chain)
# Get the last layer of the model
last_layer = last(model.layers)

# Check if the last layer is either softmax or sigmoid
has_finaliser = (last_layer == Flux.sigmoid || last_layer == Flux.softmax)

return has_finaliser
end
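
A minimal, self-contained sketch of how this helper behaves (the layer sizes below are arbitrary and purely illustrative):

```julia
using Flux

# A chain ending in a softmax activation is detected as having a finaliser...
has_softmax_or_sigmoid_final_layer(Chain(Dense(2, 2), softmax))  # true

# ...whereas a chain ending in a plain Dense layer is not.
has_softmax_or_sigmoid_final_layer(Chain(Dense(2, 2)))           # false
```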

"""
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

@@ -20,7 +43,7 @@ Computes the linearized GLM predictive.
- `X::AbstractArray`: Input data.

# Returns

- `normal_distr`: Array of Normal distributions N(fμ, fvar) approximating the predictive distribution p(y|X) given the input data X.
- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.

Expand All @@ -42,13 +65,8 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
fstd = sqrt.(fvar)
normal_distr = [
Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1),
j in 1:size(fμ, 2)
]
#normal_distr = [
#Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)] maybe this one is the correct one
return normal_distr
normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
return (normal_distr, fμ, fvar)
end
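
A short sketch of how the new tuple return value might be consumed; the fitted Laplace object `la` and input batch `X` are assumed here, not constructed:

```julia
using Statistics: mean

# The function now returns the Normal distributions together with the
# mean and variance arrays they were built from.
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)

predictive_means = mean.(normal_distr)  # one entry per observation (column of X)
```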

"""
@@ -61,13 +79,18 @@ Computes predictions from Bayesian neural network.
- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks.
- `predict_proba::Bool=true`: If `true` (default), applies a sigmoid or a softmax function to the output of the Flux model.
- `ret_distr::Bool=false`: If `false` (default), the function outputs either the raw output of the chain or pseudo-probabilities (if `predict_proba=true`).
  If `true`, `predict` returns a Bernoulli distribution for binary classification tasks and a Categorical distribution for multiclass classification tasks.

# Returns

- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.

For classification tasks, LaplaceRedux provides different options:

- if `ret_distr` is `false`: `fμ::AbstractArray`, the mean of the predictive distribution if the link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- if `ret_distr` is `true`: a Bernoulli distribution for binary classification tasks and a Categorical distribution for multiclass classification tasks.

For regression tasks:

- `normal_distr::Distributions.Normal`: the array of Normal distributions computed by `glm_predictive_distribution`.

# Examples

```julia-repl
predict(la, hcat(x...))
```
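
The call below sketches the new `ret_distr` keyword, assuming the same `la` and `x` as above:

```julia-repl
predict(la, hcat(x...); ret_distr=true)
```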
"""
function predict(
la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
la::AbstractLaplace,
X::AbstractArray;
link_approx=:probit,
predict_proba::Bool=true,
ret_distr::Bool=false,
)
normal_distr = glm_predictive_distribution(la, X)
fμ, fvar = mean.(normal_distr), var.(normal_distr)
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)

# Regression:
if la.likelihood == :regression
return normal_distr
return reshape(normal_distr,(:,1))
end

# Classification:
if la.likelihood == :classification
has_finaliser = has_softmax_or_sigmoid_final_layer(la.model)

# case when no softmax/sigmoid function is applied
if has_finaliser == false

# Probit approximation
if link_approx == :probit
z = probit(fμ, fvar)
end

if link_approx == :plugin
z = fμ
end

# Sigmoid/Softmax
if predict_proba
if la.posterior.n_out == 1
p = Flux.sigmoid(z)
if ret_distr
p = map(x -> Bernoulli(x), p)
end

else
p = Flux.softmax(z; dims=1)
if ret_distr
p = mapslices(col -> Categorical(col), p; dims=1)
end
end
else
if ret_distr
@warn "the model does not produce pseudo-probabilities. ret_distr will not work if predict_proba is set to false."
end
p = z
end
else # case when has_finaliser is true
if predict_proba == false
@warn "the model already produce pseudo-probabilities since it has either sigmoid or a softmax layer as a final layer."
end
if ret_distr
if la.posterior.n_out == 1
p = map(x -> Bernoulli(x), fμ)
else
p = mapslices(col -> Categorical(col), fμ; dims=1)
end

else
p = fμ
end
end

return p
40 changes: 23 additions & 17 deletions test/Manifest.toml
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.4"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "fa6672850323ab23f77b8212aabdf7f033fa4213"
project_hash = "ef612be958e3bfb9b8e55d3a8895bfedb9cfa7d1"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
@@ -267,7 +267,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+0"
version = "1.1.1+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
@@ -899,9 +899,14 @@ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
@@ -1088,7 +1093,7 @@ version = "1.1.9"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+0"
version = "2.28.2+1"

[[deps.Measures]]
git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
@@ -1136,7 +1141,7 @@ version = "0.3.4"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.10.11"
version = "2023.1.10"

[[deps.MultivariateStats]]
deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"]
@@ -1208,12 +1213,12 @@ version = "0.2.5"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.21+4"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+0"
version = "0.8.1+2"

[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
@@ -1259,7 +1264,7 @@ version = "1.6.3"
[[deps.PCRE2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15"
version = "10.42.0+0"
version = "10.42.0+1"

[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1323,7 +1328,7 @@ version = "0.43.4+0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.2"
version = "1.10.0"

[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
@@ -1429,7 +1434,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA", "Serialization"]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.RealDot]]
@@ -1555,6 +1560,7 @@ version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SparseInverseSubset]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1621,7 +1627,7 @@ version = "3.4.0"
[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.9.0"
version = "1.10.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
@@ -1697,9 +1703,9 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "5.10.1+6"
version = "7.2.1+1"

[[deps.TOML]]
deps = ["Dates"]
@@ -2057,7 +2063,7 @@ version = "0.10.1"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+0"
version = "1.2.13+1"

[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -2126,7 +2132,7 @@ version = "0.15.1+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+0"
version = "5.8.0+1"

[[deps.libevdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -2172,7 +2178,7 @@ version = "1.52.0+1"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
version = "17.4.0+2"

[[deps.x264_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
1 change: 1 addition & 0 deletions test/Project.toml
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"