partial fixes. fit! function gives trouble
pasq-cat committed Jun 13, 2024
1 parent 19e9b7f commit 9275277
Showing 3 changed files with 46 additions and 30 deletions.
10 changes: 7 additions & 3 deletions src/baselaplace/predicting.jl
@@ -1,3 +1,4 @@
using Distributions
"""
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)
@@ -39,7 +40,9 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
return fμ, fvar
fstd = sqrt.(fvar)
normal_distr= [Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1), j in 1:size(fμ, 2)]
return normal_distr
end
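
For orientation, a minimal consumption sketch of the new return value, assuming a fitted Laplace object la and an input batch X; the variable names are illustrative, not package API:

using Distributions
# glm_predictive_distribution now returns a matrix of Normal distributions,
# one per output dimension and observation, instead of the (fμ, fvar) pair.
dists = glm_predictive_distribution(la, X)
μ = mean.(dists)   # recovers the former fμ
σ = std.(dists)    # predictive standard deviations, i.e. sqrt.(fvar)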

"""
@@ -75,11 +78,12 @@ predict(la, hcat(x...))
function predict(
la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
)
fμ, fvar = glm_predictive_distribution(la, X)
normal_distr = glm_predictive_distribution(la, X)
fμ, fvar = mean.(normal_distr), var.(normal_distr)

# Regression:
if la.likelihood == :regression
return fμ, fvar
return normal_distr
end

# Classification:
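A hedged sketch of the updated predict contract; la_reg and la_clf stand for fitted regression and classification Laplace objects, and the classification branch (not fully shown in this hunk) is assumed to keep returning class probabilities:

using Distributions
# Regression: the Normal distributions are now passed through unchanged.
preds = predict(la_reg, X)             # matrix of Distributions.Normal
point_estimates = mean.(preds)
# Classification: probabilities via the probit link approximation, as before.
probs = predict(la_clf, X; link_approx=:probit, predict_proba=true)
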
47 changes: 30 additions & 17 deletions src/mlj_flux.jl
@@ -6,6 +6,7 @@ using Tables
using Distributions
using LinearAlgebra
using LaplaceRedux
using ComputationalResources
using MLJBase
import MLJBase: @mlj_model, metadata_model, metadata_pkg

@@ -38,13 +39,14 @@ The model is trained using the `fit!` method. The model is defined by the follow
@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic
builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish)
optimiser= Flux.Optimise.Adam()
loss=Flux.Losses.mse
loss=Flux.Losses.mse
epochs::Int=10::(_ > 0)
batch_size::Int=1::(_ > 0)
lambda::Float64=1.0
alpha::Float64=0.0
rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG
optimiser_changes_trigger_retraining::Bool=false::(_ in (true, false))
acceleration::AbstractResource=CPU1()
acceleration=CPU1()::(_ in (CPU1(), CUDALibs()))
subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices=nothing
hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal"))
@@ -88,13 +90,13 @@ A mutable struct representing a Laplace Classification model that extends the ML
finaliser=Flux.softmax
optimiser= Flux.Optimise.Adam()
loss=Flux.crossentropy
epochs::Int=10
epochs::Int=10::(_ > 0)
batch_size::Int=1::(_ > 0)
lambda::Float64=1.0
alpha::Float64=0.0
rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG
optimiser_changes_trigger_retraining::Bool=false
acceleration::AbstractResource=CPU1()
acceleration=CPU1()::(_ in (CPU1(), CUDALibs()))
subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([])
hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal"))
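
For reference, a construction sketch for the two MLJ model types, using keyword names taken from the field lists above; the specific values are assumptions, not recommended defaults:

using Flux, MLJFlux, LaplaceRedux
reg = LaplaceRegression(;
    builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish),
    optimiser=Flux.Optimise.Adam(),
    epochs=20,
    subset_of_weights=:last_layer,
)
clf = LaplaceClassification(; finaliser=Flux.softmax, epochs=20)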
Expand All @@ -113,6 +115,7 @@ const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression}
################################ functions shape and build

function MLJFlux.shape(model::LaplaceClassification, X, y)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
n_input = size(X,2)
levels = unique(y)
n_output = length(levels)
@@ -121,6 +124,7 @@ function MLJFlux.shape(model::LaplaceClassification, X, y)
end

function MLJFlux.shape(model::LaplaceRegression, X, y)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
n_input = size(X,2)
dims = size(y)
if length(dims) == 1
@@ -132,15 +136,23 @@ function MLJFlux.shape(model::LaplaceRegression, X, y)
end
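
A brief sketch of the shape contract these methods are expected to satisfy, reusing reg and clf from the construction sketch above; the exact return tuples are an assumption based on the fragments shown:

X = rand(100, 4)                                     # 4 input features
MLJFlux.shape(clf, X, rand(["a", "b", "c"], 100))    # expected (4, 3): one output per class
MLJFlux.shape(reg, X, rand(100))                     # expected (4, 1): univariate regression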


function MLJFlux.build(model::MLJ_Laplace, shape)
function MLJFlux.build(model::LaplaceClassification,rng, shape)
#chain
chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser)
chain = Flux.Chain(MLJFlux.build(model.builder, rng,shape...), model.finaliser)

return chain
end


function MLJFlux.build(model::LaplaceRegression,rng,shape)
#chain
chain = MLJFlux.build(model.builder,rng , shape...)

return chain
end
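
The two build methods differ only in whether the softmax finaliser is appended; a hedged call sketch, with shape tuples as produced by MLJFlux.shape above:

using Random
rng = Random.MersenneTwister(42)
chain_reg = MLJFlux.build(reg, rng, (4, 1))   # bare Flux.Chain from the builder
chain_clf = MLJFlux.build(clf, rng, (4, 3))   # builder chain followed by the finaliser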

function MLJFlux.fitresult(model::LaplaceClassification, chain, y)
return (chain, length(unique(y_cl)))
return (chain, length(unique(y)))
end

function MLJFlux.fitresult(model::LaplaceRegression, chain, y)
@@ -174,7 +186,8 @@ Fit the LaplaceClassification model using MLJFlux.
- `model::LaplaceClassification`: The fitted LaplaceClassification model.
"""
function MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y)
X = MLJBase.matrix(X, transpose=true)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
X = X'

# Integer encode the target variable y
#y_onehot = unique(y) .== permutedims(y)
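
The data-prep step above turns the MLJ table into a matrix and transposes it because Flux chains expect features-by-observations; a minimal illustration (the table and its sizes are made up, and the one-hot encoding of y is still commented out in this commit):

using MLJBase
Xtable = (x1=rand(5), x2=rand(5))   # a column table: 5 observations, 2 features
Xmat = MLJBase.matrix(Xtable)       # 5×2, observations as rows
Xflux = Xmat'                       # 2×5, observations as columns for Flux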
@@ -286,8 +299,8 @@ Fit the LaplaceRegression model using Flux.jl.
- `model::LaplaceRegression`: The fitted LaplaceRegression model.
"""
function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, verbosity, X, y)

X = MLJBase.matrix(X, transpose=true)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
X = X'
la = LaplaceRedux.Laplace(
chain;
likelihood=:regression,
@@ -373,16 +386,16 @@ function MLJFlux.predict(model::LaplaceRegression, Xnew)
# Predict using Laplace and collect the predictions
yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec]

predictions = []
for row in eachrow(yhat)
#predictions = []
#for row in eachrow(yhat)

mean_val = Float64(row[1][1][1])
std_val = sqrt(Float64(row[1][2][1]))
#mean_val = Float64(row[1][1][1])
#std_val = sqrt(Float64(row[1][2][1]))
# Append a Normal distribution:
push!(predictions, Normal(mean_val, std_val))
end
#push!(predictions, Normal(mean_val, std_val))
#end

return predictions
return yhat

end
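
With the unpacking loop commented out, predict now forwards the distributions produced by glm_predictive_distribution directly; a hedged consumption sketch, where model is assumed to be a LaplaceRegression whose la field was populated during fitting and Xnew holds new data:

using Distributions
yhat = MLJFlux.predict(model, Xnew)    # one entry per new observation
μs = [mean.(d) for d in yhat]          # predictive means
σs = [std.(d) for d in yhat]           # predictive standard deviations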

19 changes: 9 additions & 10 deletions test/mlj_flux_interfacing.jl
@@ -15,16 +15,15 @@ function basictest_regression(X, y, builder, optimiser, threshold)
builder=builder,
optimiser=optimiser,
acceleration=CPUThreads(),
loss= Flux.Losses.mse,
rng=stable_rng,
lambda=-1.0,
alpha=-1.0,
epochs=-1,
batch_size=-1,
likelihood=:incorrect,
subset_of_weights=:incorrect,
hessian_structure=:incorrect,
backend=:incorrect,
link_approx=:incorrect,
backend=:incorrect
)

fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
@@ -92,7 +91,7 @@ ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
optimizer = Flux.Optimise.Adam(0.03)

@test basictest_regression(X, y, builder, optimizer, 0.9)
@test basictest_regression(X, ycont, builder, optimizer, 0.9)



@@ -103,16 +102,16 @@ function basictest_classification(X, y, builder, optimiser, threshold)

stable_rng = StableRNGs.StableRNG(123)

model = LaplaceRegression(;
model = LaplaceClassification(;
builder=builder,
optimiser=optimiser,
acceleration=CPUThreads(),
rng=stable_rng,
lambda=-1.0,
alpha=-1.0,
loss= Flux.crossentropy,
epochs=-1,
batch_size=-1,
likelihood=:incorrect,
lambda=-1.0,
alpha=-1.0,
rng=stable_rng,
acceleration=CPUThreads(),
subset_of_weights=:incorrect,
hessian_structure=:incorrect,
backend=:incorrect,
