101 remaining issue on the MLJ interface #103

Merged · 8 commits · Jul 22, 2024

Changes from all commits
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].

+## Version [1.0.1] - 2024-07-19
+
+### Changed
+- Added the option to return the mean and variance from `predict` in the case of regression [[#101](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/101)]
+- Modified `mlj_flux.jl` by adding the `ret_distr` parameter and fixed `MLJFlux.predict` for both classification and regression tasks.
+
## Version [1.0.0] - 2024-07-17

### Changed
6 changes: 5 additions & 1 deletion src/baselaplace/predicting.jl
@@ -115,7 +115,11 @@ function predict(

     # Regression:
     if la.likelihood == :regression
-        return reshape(normal_distr, (:, 1))
+        if ret_distr
+            return reshape(normal_distr, (:, 1))
+        else
+            return fμ, fvar
+        end
     end

# Classification:
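To see what the two regression return modes look like in practice, here is a minimal sketch against the base `Laplace` API; the network and data below are illustrative assumptions, not code from the PR:

```julia
using Flux, LaplaceRedux

# Toy regression data: 2 features × 100 observations (illustrative)
X = randn(Float32, 2, 100)
y = vec(sum(X; dims=1)) .+ 0.1f0 .* randn(Float32, 100)
data = zip(eachcol(X), y)

nn = Chain(Dense(2, 8, relu), Dense(8, 1))
la = Laplace(nn; likelihood=:regression)
fit!(la, data)

fμ, fvar = predict(la, X)                # default: predictive mean and variance
distrs = predict(la, X; ret_distr=true)  # matrix of Distributions.Normal objects
```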
46 changes: 30 additions & 16 deletions src/mlj_flux.jl
@@ -15,7 +15,8 @@ using Optimisers: Optimisers
 
 A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type.
 It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-The model is trained using the `fit!` method. The model is defined by the following default parameters:
+
+The model is defined by the following default parameters for all `MLJFlux` models:
 
 - `builder`: a Flux model that constructs the neural network.
 - `optimiser`: a Flux optimiser.
@@ -27,13 +28,17 @@ The model is trained using the `fit!` method. The model is defined by the follow
 - `rng`: a random number generator.
 - `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
 - `acceleration`: the computational resource to use.
+
+The model also has the following parameters, which are specific to the Laplace approximation:
+
 - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
 - `subnetwork_indices`: the indices of the subnetworks.
 - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
 - `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
 - `σ`: the standard deviation of the prior distribution.
 - `μ₀`: the mean of the prior distribution.
 - `P₀`: the covariance matrix of the prior distribution.
+- `ret_distr`: a boolean that tells `predict` whether to return a distribution object from Distributions.jl (`true`) or the predictive mean and variance (`false`).
 - `fit_prior_nsteps`: the number of steps used to fit the priors.
 """
MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic
@@ -55,16 +60,17 @@ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilis
     σ::Float64 = 1.0
     μ₀::Float64 = 0.0
     P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+    ret_distr::Bool = false::(_ in (true, false))
     fit_prior_nsteps::Int = 100::(_ > 0)
 end

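For orientation, a rough sketch of driving `LaplaceRegression` through the standard MLJ machine workflow; the data, builder, and epoch count are illustrative assumptions:

```julia
using MLJ, MLJFlux, Flux, LaplaceRedux

X = MLJ.table(rand(Float32, 100, 4))
y = rand(Float32, 100)

model = LaplaceRegression(;
    builder=MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu),
    epochs=20,
    ret_distr=false,  # return (mean, variance) pairs instead of Normal objects
)
mach = machine(model, X, y)
fit!(mach)
yhat = predict(mach, X)
```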
"""
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic

A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
The model is trained using the `fit!` method. The model is defined by the following default parameters:
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.

The model is defined by the following default parameters for all `MLJFlux` models:
- `builder`: a Flux model that constructs the neural network.
- `finaliser`: a Flux model that processes the output of the neural network.
- `optimiser`: a Flux optimiser.
@@ -76,13 +82,19 @@ A mutable struct representing a Laplace Classification model that extends the ML
 - `rng`: a random number generator.
 - `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
 - `acceleration`: the computational resource to use.
+
+The model also has the following parameters, which are specific to the Laplace approximation:
+
 - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
 - `subnetwork_indices`: the indices of the subnetworks.
 - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
 - `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
 - `σ`: the standard deviation of the prior distribution.
 - `μ₀`: the mean of the prior distribution.
 - `P₀`: the covariance matrix of the prior distribution.
 - `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
 - `predict_proba`: a boolean that selects whether to predict probabilities or not.
+- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
 - `fit_prior_nsteps`: the number of steps used to fit the priors.
 """
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic
@@ -107,6 +119,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab
     P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
     link_approx::Symbol = :probit::(_ in (:probit, :plugin))
     predict_proba::Bool = true::(_ in (true, false))
+    ret_distr::Bool = false::(_ in (true, false))
     fit_prior_nsteps::Int = 100::(_ > 0)
 end

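And the classification counterpart, again as a hedged sketch with illustrative data; `link_approx`, `predict_proba`, and `ret_distr` are the Laplace-specific switches documented above:

```julia
using MLJ, MLJFlux, Flux, LaplaceRedux

X = MLJ.table(rand(Float32, 100, 4))
y = coerce(rand(["yes", "no"], 100), Multiclass)

model = LaplaceClassification(;
    builder=MLJFlux.MLP(; hidden=(16,), σ=Flux.relu),
    link_approx=:probit,
    predict_proba=true,   # return class probabilities
    ret_distr=false,      # plain probabilities, no Distributions.jl objects
)
mach = machine(model, X, y)
fit!(mach)
probs = predict(mach, X)
```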
@@ -273,15 +286,11 @@ Predict the output for new input data using a Laplace regression model.
 """
 function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew)
     Xnew = MLJBase.matrix(Xnew)
-
-    model = fitresult[1]
+    la = fitresult[1]
     # convert to a vector of vectors because MLJ asks us to do so
-    X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)]
-    #inizialize output vector yhat
-    yhat = []
+    X_vec = collect(eachrow(Xnew))
     # Predict using Laplace and collect the predictions
-    yhat = [glm_predictive_distribution(model, x_vec) for x_vec in X_vec]
-
+    yhat = [map(x -> LaplaceRedux.predict(la, x; ret_distr=model.ret_distr), X_vec)...]
     return yhat
 end

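Both rewritten `predict` methods swap the explicit row comprehension for `collect(eachrow(Xnew))`; a quick illustrative check that the two forms agree on values (the views avoid copying each row):

```julia
Xnew = [1 2 3; 4 5 6]
old = [Xnew[i, :] for i in 1:size(Xnew, 1)]  # vector of row copies
new = collect(eachrow(Xnew))                 # vector of row views
@assert all(old .== new)                     # same values either way
```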
@@ -448,13 +457,18 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew)
     la = fitresult[1]
     Xnew = MLJBase.matrix(Xnew)
     # convert to a vector of vectors because Laplace asks us to do so
-    X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)]
-
-    # Predict using Laplace and collect the predictions
+    X_vec = collect(eachrow(Xnew))
     predictions = [
-        LaplaceRedux.predict(
-            la, x; link_approx=model.link_approx, predict_proba=model.predict_proba
-        ) for x in X_vec
+        map(
+            x -> LaplaceRedux.predict(
+                la,
+                x;
+                link_approx=model.link_approx,
+                predict_proba=model.predict_proba,
+                ret_distr=model.ret_distr,
+            ),
+            X_vec,
+        )...,
     ]

> **Review comment (Member):** Could just try if broadcasting works here (i.e. `predict.(la, X_new)`).
>
> **Review comment (Member):** Let's just open a new issue for this.

     return predictions
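On the broadcasting idea from the review thread: a hypothetical sketch of that refactor, not what was merged. Wrapping the Laplace object in `Ref` stops it from being broadcast over, and keyword arguments pass through a dot call unchanged:

```julia
predictions = LaplaceRedux.predict.(
    Ref(la), X_vec;
    link_approx=model.link_approx,
    predict_proba=model.predict_proba,
    ret_distr=model.ret_distr,
)
```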
7 changes: 1 addition & 6 deletions test/laplace.jl
@@ -229,12 +229,7 @@ end
     la = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)
     fit!(la, data)
     matrix_normals = Matrix{Normal{T}} where {T<:AbstractFloat}
-    @test typeof(predict(la, X)) <: matrix_normals
-
-    #predict(la, X[1]; link_approx=:plugin)
-    #predict(la, X[1]; link_approx=:probit)
-    #predict(la, X[1]; ret_distr=true)
-    #predict(la, X[1]; ret_distr=true, predict_proba=false)
+    @test typeof(predict(la, X; ret_distr=true)) <: matrix_normals
 end

#testing the function LaplaceRedux.has_softmax_or_sigmoid_final_layer
8 changes: 4 additions & 4 deletions test/mlj_flux_interfacing.jl
@@ -24,6 +24,7 @@ using StableRNGs
         subset_of_weights=:incorrect,
         hessian_structure=:incorrect,
         backend=:incorrect,
+        ret_distr=true,
     )
 
     fitresult, cache, _report = MLJBase.fit(model, 0, X, y)

@@ -46,6 +47,8 @@ using StableRNGs
     history = _report.training_losses
     @test length(history) == model.epochs + 1
 
+    yhat = MLJBase.predict(model, fitresult, X)
+
     # start fresh with small epochs:
     model = LaplaceRegression(;
         builder=builder,
@@ -88,13 +91,10 @@ using StableRNGs
     N = 300
     X = MLJBase.table(rand(Float32, N, 4))
     ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
 
     builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
     optimiser = Flux.Optimise.Adam(0.03)
-
-    y = ycont
-
-    @test basictest_regression(X, y, builder, optimiser, 0.9)
+    @test basictest_regression(X, ycont, builder, optimiser, 0.9)
 end

@testset "Classification" begin