From 8664cf900f2f8575f5192de16a5277e24715fc5a Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 30 May 2024 19:13:14 +0200 Subject: [PATCH 01/32] split LaplaceApproximation. Removed Duplicated MMI.clean! function. Unified MLJFlux.shape! and MLJFlux.build! --- CHANGELOG.md | 12 ++ src/LaplaceRedux.jl | 4 +- src/mlj_flux.jl | 435 ++++++++++++++++++++++++++++++-------------- 3 files changed, 315 insertions(+), 136 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 916da01..f0e3467 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). *Note*: We try to adhere to these practices as of version [v0.2.1]. +## Version [0.2.2] - 2024-05-30 + +### Changed + +- Unified duplicated function MMI.clean!: previously MMI.clean! consisted of two separate functions for handling :classification and :regression types respectively. Now, a single MMI.clean! function handles both cases efficiently.[#39] +- Split LaplaceApproximation struct in two different structs:LaplaceClassification and LaplaceRegression [#39] +- Unified the MLJFlux.shape and the MLJFlux.build functions to handle both :classification and :regression tasks. In particular, shape now handles multi-output regression cases too [[#39](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/39)] + +### Added + + + ## Version [0.2.1] - 2024-05-29 diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 6ac5e95..cc73699 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -20,6 +20,6 @@ export optimize_prior!, glm_predictive_distribution, posterior_covariance, posterior_precision include("mlj_flux.jl") -export LaplaceApproximation - +export LaplaceClassification +export LaplaceRegression end diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 3a766c2..3e49a53 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -6,8 +6,10 @@ using Random using Tables using ComputationalResources using Statistics +using Distributions +using LinearAlgebra -mutable struct LaplaceApproximation{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic +mutable struct LaplaceClassification{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic builder::B finaliser::F optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl @@ -62,7 +64,7 @@ The model also has the following parameters, which are specific to the Laplace a - `link_approx`: the link approximation to use, either `:probit` or `:plugin`. - `fit_params`: additional parameters to pass to the `fit!` method. """ -function LaplaceApproximation(; +function LaplaceClassification(; builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), @@ -85,7 +87,7 @@ function LaplaceApproximation(; link_approx::Symbol=:probit, fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), ) where {B,F,O,L} - model = LaplaceApproximation( + model = LaplaceClassification( builder, finaliser, optimiser, @@ -116,15 +118,143 @@ function LaplaceApproximation(; return model end -function MLJFlux.shape(model::LaplaceApproximation, X, y) +mutable struct LaplaceRegression{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic + builder::B + finaliser::F + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Int64} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + likelihood::Symbol + subset_of_weights::Symbol + subnetwork_indices::Vector{Vector{Int}} + hessian_structure::Union{HessianStructure,Symbol,String} + backend::Symbol + σ::Real + μ₀::Real + P₀::Union{AbstractMatrix,UniformScaling,Nothing} + link_approx::Symbol + fit_params::Dict{Symbol,Any} + la::Union{Nothing,AbstractLaplace} +end + +""" + LaplaceRegression(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) + +A probabilistic regression model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: + +- `builder`: a Flux model that constructs the neural network. +- `finaliser`: a Flux model that processes the output of the neural network. +- `optimiser`: a Flux optimiser. +- `loss`: a loss function that takes the predicted output and the true output as arguments. +- `epochs`: the number of epochs to train the model. +- `batch_size`: the size of a batch. +- `lambda`: the regularization strength. +- `alpha`: the regularization mix (0 for all l2, 1 for all l1). +- `rng`: a random number generator. +- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. +- `acceleration`: the computational resource to use. + +The model also has the following parameters, which are specific to the Laplace approximation: + +- `likelihood`: the likelihood of the model `:regression`. +- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. +- `subnetwork_indices`: the indices of the subnetworks. +- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. +- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. +- `σ`: the standard deviation of the prior distribution. +- `μ₀`: the mean of the prior distribution. +- `P₀`: the covariance matrix of the prior distribution. +- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. +- `fit_params`: additional parameters to pass to the `fit!` method. +""" +function LaplaceRegression(; + builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), + finaliser::F=x->x, + optimiser::O=Flux.Optimise.Adam(), + loss::L=Flux.mse, + epochs::Int=10, + batch_size::Int=1, + lambda::Float64=1.0, + alpha::Float64=0.0, + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, + optimiser_changes_trigger_retraining::Bool=false, + acceleration::AbstractResource=CPU1(), + likelihood::Symbol=:regression, + subset_of_weights::Symbol=:all, + subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]), + hessian_structure::Union{HessianStructure,Symbol,String}=:full, + backend::Symbol=:GGN, + σ::Float64=1.0, + μ₀::Float64=0.0, + P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing, + link_approx::Symbol=:probit, + fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), +) where {B,F,O,L} + model = LaplaceRegression( + builder, + finaliser, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + likelihood, + subset_of_weights, + subnetwork_indices, + hessian_structure, + backend, + σ, + μ₀, + P₀, + link_approx, + fit_params, + nothing, + ) + + message = MMI.clean!(model) + isempty(message) || @warn message + + return model +end + + + + + + + +function MLJFlux.shape(model::Union{LaplaceClassification,LaplaceRegression}, X, y) X = X isa Matrix ? Tables.table(X) : X - levels = MMI.classes(y[1]) - n_output = length(levels) - n_input = length(Tables.schema(X).names) - return (n_input, n_output) + n_input = length(Tables.columnnames(X)) + + if model isa LaplaceClassification + levels = MMI.classes(y[1]) + n_output = length(levels) + return (n_input, n_output) + elseif model isa LaplaceRegression + dims = size(y) + if length(dims) == 1 + n_output= 1 + else + n_output= dims[2] + end + return (n_input, n_output) + end + end -function MLJFlux.build(model::LaplaceApproximation, rng, shape) + +function MLJFlux.build(model::Union{LaplaceClassification,LaplaceRegression}, rng, shape) # Construct the chain chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) # Construct Laplace model and store it in the model object @@ -142,11 +272,75 @@ function MLJFlux.build(model::LaplaceApproximation, rng, shape) return chain end -function MLJFlux.fitresult(model::LaplaceApproximation, chain, y) - return (chain, model.la, MMI.classes(y[1])) +function MLJFlux.fitresult(model::Union{LaplaceClassification,LaplaceRegression}, chain, y) + if model isa LaplaceClassification + return (chain, model.la, MMI.classes(y[1])) + else + return (chain, model.la, size(y) ) + end end -function MLJFlux.train!(model::LaplaceApproximation, penalty, chain, optimiser, X, y) + +function MMI.clean!(model::Union{LaplaceClassification,LaplaceRegression}) + warning = "" + if model.lambda < 0 + warning *= "Need `lambda ≥ 0`. Resetting `lambda = 0`. " + model.lambda = 0 + end + if model.alpha < 0 || model.alpha > 1 + warning *= "Need alpha in the interval `[0, 1]`. " * "Resetting `alpha = 0`. " + model.alpha = 0 + end + if model.epochs < 0 + warning *= "Need `epochs ≥ 0`. Resetting `epochs = 10`. " + model.epochs = 10 + end + if model.batch_size <= 0 + warning *= "Need `batch_size > 0`. Resetting `batch_size = 1`. " + model.batch_size = 1 + end + if model.acceleration isa CUDALibs && gpu_isdead() + warning *= + "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " + end + if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) + warning *= "`Undefined acceleration, falling back to CPU`" + model.acceleration = CPU1() + end + if model.likelihood ∉ (:regression, :classification) + warning *= "Need `likelihood ∈ (:regression, :classification)`. " * + "Resetting to default `likelihood = :regression`. " + model.likelihood = :regression + end + if model.subset_of_weights ∉ (:all, :last_layer, :subnetwork) + warning *= + "Need `subset_of_weights ∈ (:all, :last_layer, :subnetwork)`. " * + "Resetting `subset_of_weights = :all`. " + model.subset_of_weights = :all + end + if String(model.hessian_structure) ∉ ("full", "diagonal") && + !(typeof(model.hessian_structure) <: HessianStructure) + warning *= + "Need `hessian_structure ∈ (:full, :diagonal)` or `hessian_structure ∈ (:full, :diagonal)` or typeof(model.hessian_structure) <: HessianStructure." * + "Resetting `hessian_structure = :full`. " + model.hessian_structure = :full + end + if model.backend ∉ (:GGN, :EmpiricalFisher) + warning *= + "Need `backend ∈ (:GGN, :EmpiricalFisher)`. " * "Resetting `backend = :GGN`. " + model.backend = :GGN + end + if model.link_approx ∉ (:probit, :plugin) + warning *= + "Need `link_approx ∈ (:probit, :plugin)`. " * + "Resetting `link_approx = :probit`. " + model.link_approx = :probit + end + return warning +end +######################################### train , fit and predict for classification + +function MLJFlux.train!(model::LaplaceClassification, penalty, chain, optimiser, X, y) loss = model.loss n_batches = length(y) training_loss = zero(Float32) @@ -164,7 +358,7 @@ function MLJFlux.train!(model::LaplaceApproximation, penalty, chain, optimiser, end function MLJFlux.fit!( - model::LaplaceApproximation, penalty, chain, optimiser, epochs, verbosity, X, y + model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y ) loss = model.loss @@ -210,66 +404,9 @@ function MLJFlux.fit!( return chain, history end -function MMI.clean!(model::LaplaceApproximation) - warning = "" - if model.lambda < 0 - warning *= "Need `lambda ≥ 0`. Resetting `lambda = 0`. " - model.lambda = 0 - end - if model.alpha < 0 || model.alpha > 1 - warning *= "Need alpha in the interval `[0, 1]`. " * "Resetting `alpha = 0`. " - model.alpha = 0 - end - if model.epochs < 0 - warning *= "Need `epochs ≥ 0`. Resetting `epochs = 10`. " - model.epochs = 10 - end - if model.batch_size <= 0 - warning *= "Need `batch_size > 0`. Resetting `batch_size = 1`. " - model.batch_size = 1 - end - if model.acceleration isa CUDALibs && gpu_isdead() - warning *= - "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " - end - if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) - warning *= "`Undefined acceleration, falling back to CPU`" - model.acceleration = CPU1() - end - if model.likelihood ∉ (:classification, :regression) - warning *= - "Need `likelihood ∈ (:classification, :regression)`. " * - "Resetting `likelihood = :classification`. " - model.likelihood = :classification - end - if model.subset_of_weights ∉ (:all, :last_layer, :subnetwork) - warning *= - "Need `subset_of_weights ∈ (:all, :last_layer, :subnetwork)`. " * - "Resetting `subset_of_weights = :all`. " - model.subset_of_weights = :all - end - if String(model.hessian_structure) ∉ ("full", "diagonal") && - !(typeof(model.hessian_structure) <: HessianStructure) - warning *= - "Need `hessian_structure ∈ (:full, :diagonal)` or `hessian_structure ∈ (:full, :diagonal)` or typeof(model.hessian_structure) <: HessianStructure." * - "Resetting `hessian_structure = :full`. " - model.hessian_structure = :full - end - if model.backend ∉ (:GGN, :EmpiricalFisher) - warning *= - "Need `backend ∈ (:GGN, :EmpiricalFisher)`. " * "Resetting `backend = :GGN`. " - model.backend = :GGN - end - if model.link_approx ∉ (:probit, :plugin) - warning *= - "Need `link_approx ∈ (:probit, :plugin)`. " * - "Resetting `link_approx = :probit`. " - model.link_approx = :probit - end - return warning -end -function MMI.predict(model::LaplaceApproximation, fitresult, Xnew) + +function MMI.predict(model::LaplaceClassification, fitresult, Xnew) chain, la, levels = fitresult # re-format Xnew into acceptable input for Laplace: X = MLJFlux.reformat(Xnew) @@ -280,74 +417,104 @@ function MMI.predict(model::LaplaceApproximation, fitresult, Xnew) i in 1:size(X, 2) ]..., ) - if la.likelihood == :classification - # return a UnivariateFinite: - return MMI.UnivariateFinite(levels, yhat) - end - if la.likelihood == :regression - # return a UnivariateNormal: - return MMI.UnivariateNormal(yhat[1], sqrt(yhat[2])) - end -end -function _isdefined(object, name) - pnames = propertynames(object) - fnames = fieldnames(typeof(object)) - name in pnames && !(name in fnames) && return true - return isdefined(object, name) -end + return MMI.UnivariateFinite(levels, yhat) + -function _equal_to_depth_one(x1, x2) - names = propertynames(x1) - names === propertynames(x2) || return false - for name in names - getproperty(x1, name) == getproperty(x2, name) || return false - end - return true end -function MMI.is_same_except( - m1::M1, m2::M2, exceptions::Symbol... -) where {M1<:LaplaceApproximation,M2<:LaplaceApproximation} - typeof(m1) === typeof(m2) || return false - names = propertynames(m1) - propertynames(m2) === names || return false - - for name in names - if !(name in exceptions) && name != :la - if !_isdefined(m1, name) - !_isdefined(m2, name) || return false - elseif _isdefined(m2, name) - if name in MLJFlux.deep_properties(M1) - _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || - return false - else - ( - MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || - getproperty(m1, name) isa AbstractRNG || - getproperty(m2, name) isa AbstractRNG - ) || return false - end - else - return false - end + +######################################################## train ,fit and predict for regression + + + + +function MLJFlux.train!(model::LaplaceRegression, penalty, chain, optimiser, X, y) + loss = model.loss + n_batches = length(y) + training_loss = zero(Float32) + for i in 1:n_batches + parameters = Flux.params(chain) + gs = Flux.gradient(parameters) do + yhat = chain(X[i]) + batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches + training_loss += batch_loss + return batch_loss end + Flux.update!(optimiser, parameters, gs) end - return true + return training_loss / n_batches end -MMI.metadata_model( - LaplaceApproximation; - input=Union{ - AbstractMatrix{MMI.Continuous}, - MMI.Table(MMI.Continuous), - MMI.Table{AbstractVector{MMI.Continuous}}, - }, - target=Union{ - AbstractArray{MMI.Finite}, - AbstractArray{MMI.Continuous}, - AbstractVector{MMI.Finite}, - AbstractVector{MMI.Continuous}, - }, - path="MLJFlux.LaplaceApproximation", +function MLJFlux.fit!( + model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y ) + loss = model.loss + + # intitialize and start progress meter: + meter = Progress( + epochs + 1; + dt=0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + verbosity != 1 || next!(meter) + + # initiate history: + n_batches = length(y) + + parameters = Flux.params(chain) + + # initial loss: + losses = ( + loss(chain(X[i]), y[i]) + penalty(parameters) / n_batches for i in 1:n_batches + ) + history = [mean(losses)] + + for i in 1:epochs + current_loss = MLJFlux.train!( + model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y + ) + verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" + verbosity != 1 || next!(meter) + push!(history, current_loss) + end + + la = model.la + + # fit the Laplace model: + fit!(la, zip(X, y); model.fit_params...) + optimize_prior!(la; verbose=false, n_steps=100) + + model.la = la + + return chain, history +end + + +function MMI.predict(model::LaplaceRegression, fitresult, Xnew) + chain, la, levels = fitresult + # re-format Xnew into acceptable input for Laplace: + X = MLJFlux.reformat(Xnew) + # predict using Laplace: + yhat = vcat( + [ + glm_predictive_distribution(la, MLJFlux.tomat(X[:, i]); link_approx=model.link_approx)' for + i in 1:size(X, 2) + ]..., + ) + println(size(yhat)) + predictions = [] + for row in eachrow(yhat) + + mean_val = Float64(row[1][1]) + std_val = sqrt(Float64(row[2][1])) + # Append a Normal distribution: + push!(predictions, Normal(mean_val, std_val)) + end + + return predictions + +end From 039f172971caf00ee4a3f643f9683b7a5071e5bd Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 30 May 2024 19:13:14 +0200 Subject: [PATCH 02/32] updated metadata and changelog. uploaded Project.toml --- CHANGELOG.md | 13 ++ Project.toml | 1 + src/LaplaceRedux.jl | 4 +- src/mlj_flux.jl | 435 ++++++++++++++++++++++++++++++-------------- 4 files changed, 317 insertions(+), 136 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 916da01..09b9be1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). *Note*: We try to adhere to these practices as of version [v0.2.1]. +## Version [0.2.2] - 2024-05-30 + +### Changed + +- Unified duplicated function MMI.clean!: previously MMI.clean! consisted of two separate functions for handling :classification and :regression types respectively. Now, a single MMI.clean! function handles both cases efficiently.[#39] +- Split LaplaceApproximation struct in two different structs:LaplaceClassification and LaplaceRegression [#39] +- Unified the MLJFlux.shape and the MLJFlux.build functions to handle both :classification and :regression tasks. In particular, shape now handles multi-output regression cases too [[#39](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/39)] +- Changed model metadata for LaplaceClassification and LaplaceRegression + +### Added + Added Distributions to LaplaceRedux dependency ( needed for MMI.predict(model::LaplaceRegression, fitresult, Xnew) ) + + ## Version [0.2.1] - 2024-05-29 diff --git a/Project.toml b/Project.toml index 24355ed..8353012 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.2.1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 6ac5e95..cc73699 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -20,6 +20,6 @@ export optimize_prior!, glm_predictive_distribution, posterior_covariance, posterior_precision include("mlj_flux.jl") -export LaplaceApproximation - +export LaplaceClassification +export LaplaceRegression end diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 3a766c2..3e49a53 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -6,8 +6,10 @@ using Random using Tables using ComputationalResources using Statistics +using Distributions +using LinearAlgebra -mutable struct LaplaceApproximation{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic +mutable struct LaplaceClassification{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic builder::B finaliser::F optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl @@ -62,7 +64,7 @@ The model also has the following parameters, which are specific to the Laplace a - `link_approx`: the link approximation to use, either `:probit` or `:plugin`. - `fit_params`: additional parameters to pass to the `fit!` method. """ -function LaplaceApproximation(; +function LaplaceClassification(; builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), @@ -85,7 +87,7 @@ function LaplaceApproximation(; link_approx::Symbol=:probit, fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), ) where {B,F,O,L} - model = LaplaceApproximation( + model = LaplaceClassification( builder, finaliser, optimiser, @@ -116,15 +118,143 @@ function LaplaceApproximation(; return model end -function MLJFlux.shape(model::LaplaceApproximation, X, y) +mutable struct LaplaceRegression{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic + builder::B + finaliser::F + optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl + loss::L # can be called as in `loss(yhat, y)` + epochs::Int # number of epochs + batch_size::Int # size of a batch + lambda::Float64 # regularization strength + alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) + rng::Union{AbstractRNG,Int64} + optimiser_changes_trigger_retraining::Bool + acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` + likelihood::Symbol + subset_of_weights::Symbol + subnetwork_indices::Vector{Vector{Int}} + hessian_structure::Union{HessianStructure,Symbol,String} + backend::Symbol + σ::Real + μ₀::Real + P₀::Union{AbstractMatrix,UniformScaling,Nothing} + link_approx::Symbol + fit_params::Dict{Symbol,Any} + la::Union{Nothing,AbstractLaplace} +end + +""" + LaplaceRegression(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) + +A probabilistic regression model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: + +- `builder`: a Flux model that constructs the neural network. +- `finaliser`: a Flux model that processes the output of the neural network. +- `optimiser`: a Flux optimiser. +- `loss`: a loss function that takes the predicted output and the true output as arguments. +- `epochs`: the number of epochs to train the model. +- `batch_size`: the size of a batch. +- `lambda`: the regularization strength. +- `alpha`: the regularization mix (0 for all l2, 1 for all l1). +- `rng`: a random number generator. +- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. +- `acceleration`: the computational resource to use. + +The model also has the following parameters, which are specific to the Laplace approximation: + +- `likelihood`: the likelihood of the model `:regression`. +- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. +- `subnetwork_indices`: the indices of the subnetworks. +- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. +- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. +- `σ`: the standard deviation of the prior distribution. +- `μ₀`: the mean of the prior distribution. +- `P₀`: the covariance matrix of the prior distribution. +- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. +- `fit_params`: additional parameters to pass to the `fit!` method. +""" +function LaplaceRegression(; + builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), + finaliser::F=x->x, + optimiser::O=Flux.Optimise.Adam(), + loss::L=Flux.mse, + epochs::Int=10, + batch_size::Int=1, + lambda::Float64=1.0, + alpha::Float64=0.0, + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, + optimiser_changes_trigger_retraining::Bool=false, + acceleration::AbstractResource=CPU1(), + likelihood::Symbol=:regression, + subset_of_weights::Symbol=:all, + subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]), + hessian_structure::Union{HessianStructure,Symbol,String}=:full, + backend::Symbol=:GGN, + σ::Float64=1.0, + μ₀::Float64=0.0, + P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing, + link_approx::Symbol=:probit, + fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), +) where {B,F,O,L} + model = LaplaceRegression( + builder, + finaliser, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + likelihood, + subset_of_weights, + subnetwork_indices, + hessian_structure, + backend, + σ, + μ₀, + P₀, + link_approx, + fit_params, + nothing, + ) + + message = MMI.clean!(model) + isempty(message) || @warn message + + return model +end + + + + + + + +function MLJFlux.shape(model::Union{LaplaceClassification,LaplaceRegression}, X, y) X = X isa Matrix ? Tables.table(X) : X - levels = MMI.classes(y[1]) - n_output = length(levels) - n_input = length(Tables.schema(X).names) - return (n_input, n_output) + n_input = length(Tables.columnnames(X)) + + if model isa LaplaceClassification + levels = MMI.classes(y[1]) + n_output = length(levels) + return (n_input, n_output) + elseif model isa LaplaceRegression + dims = size(y) + if length(dims) == 1 + n_output= 1 + else + n_output= dims[2] + end + return (n_input, n_output) + end + end -function MLJFlux.build(model::LaplaceApproximation, rng, shape) + +function MLJFlux.build(model::Union{LaplaceClassification,LaplaceRegression}, rng, shape) # Construct the chain chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) # Construct Laplace model and store it in the model object @@ -142,11 +272,75 @@ function MLJFlux.build(model::LaplaceApproximation, rng, shape) return chain end -function MLJFlux.fitresult(model::LaplaceApproximation, chain, y) - return (chain, model.la, MMI.classes(y[1])) +function MLJFlux.fitresult(model::Union{LaplaceClassification,LaplaceRegression}, chain, y) + if model isa LaplaceClassification + return (chain, model.la, MMI.classes(y[1])) + else + return (chain, model.la, size(y) ) + end end -function MLJFlux.train!(model::LaplaceApproximation, penalty, chain, optimiser, X, y) + +function MMI.clean!(model::Union{LaplaceClassification,LaplaceRegression}) + warning = "" + if model.lambda < 0 + warning *= "Need `lambda ≥ 0`. Resetting `lambda = 0`. " + model.lambda = 0 + end + if model.alpha < 0 || model.alpha > 1 + warning *= "Need alpha in the interval `[0, 1]`. " * "Resetting `alpha = 0`. " + model.alpha = 0 + end + if model.epochs < 0 + warning *= "Need `epochs ≥ 0`. Resetting `epochs = 10`. " + model.epochs = 10 + end + if model.batch_size <= 0 + warning *= "Need `batch_size > 0`. Resetting `batch_size = 1`. " + model.batch_size = 1 + end + if model.acceleration isa CUDALibs && gpu_isdead() + warning *= + "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " + end + if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) + warning *= "`Undefined acceleration, falling back to CPU`" + model.acceleration = CPU1() + end + if model.likelihood ∉ (:regression, :classification) + warning *= "Need `likelihood ∈ (:regression, :classification)`. " * + "Resetting to default `likelihood = :regression`. " + model.likelihood = :regression + end + if model.subset_of_weights ∉ (:all, :last_layer, :subnetwork) + warning *= + "Need `subset_of_weights ∈ (:all, :last_layer, :subnetwork)`. " * + "Resetting `subset_of_weights = :all`. " + model.subset_of_weights = :all + end + if String(model.hessian_structure) ∉ ("full", "diagonal") && + !(typeof(model.hessian_structure) <: HessianStructure) + warning *= + "Need `hessian_structure ∈ (:full, :diagonal)` or `hessian_structure ∈ (:full, :diagonal)` or typeof(model.hessian_structure) <: HessianStructure." * + "Resetting `hessian_structure = :full`. " + model.hessian_structure = :full + end + if model.backend ∉ (:GGN, :EmpiricalFisher) + warning *= + "Need `backend ∈ (:GGN, :EmpiricalFisher)`. " * "Resetting `backend = :GGN`. " + model.backend = :GGN + end + if model.link_approx ∉ (:probit, :plugin) + warning *= + "Need `link_approx ∈ (:probit, :plugin)`. " * + "Resetting `link_approx = :probit`. " + model.link_approx = :probit + end + return warning +end +######################################### train , fit and predict for classification + +function MLJFlux.train!(model::LaplaceClassification, penalty, chain, optimiser, X, y) loss = model.loss n_batches = length(y) training_loss = zero(Float32) @@ -164,7 +358,7 @@ function MLJFlux.train!(model::LaplaceApproximation, penalty, chain, optimiser, end function MLJFlux.fit!( - model::LaplaceApproximation, penalty, chain, optimiser, epochs, verbosity, X, y + model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y ) loss = model.loss @@ -210,66 +404,9 @@ function MLJFlux.fit!( return chain, history end -function MMI.clean!(model::LaplaceApproximation) - warning = "" - if model.lambda < 0 - warning *= "Need `lambda ≥ 0`. Resetting `lambda = 0`. " - model.lambda = 0 - end - if model.alpha < 0 || model.alpha > 1 - warning *= "Need alpha in the interval `[0, 1]`. " * "Resetting `alpha = 0`. " - model.alpha = 0 - end - if model.epochs < 0 - warning *= "Need `epochs ≥ 0`. Resetting `epochs = 10`. " - model.epochs = 10 - end - if model.batch_size <= 0 - warning *= "Need `batch_size > 0`. Resetting `batch_size = 1`. " - model.batch_size = 1 - end - if model.acceleration isa CUDALibs && gpu_isdead() - warning *= - "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " - end - if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) - warning *= "`Undefined acceleration, falling back to CPU`" - model.acceleration = CPU1() - end - if model.likelihood ∉ (:classification, :regression) - warning *= - "Need `likelihood ∈ (:classification, :regression)`. " * - "Resetting `likelihood = :classification`. " - model.likelihood = :classification - end - if model.subset_of_weights ∉ (:all, :last_layer, :subnetwork) - warning *= - "Need `subset_of_weights ∈ (:all, :last_layer, :subnetwork)`. " * - "Resetting `subset_of_weights = :all`. " - model.subset_of_weights = :all - end - if String(model.hessian_structure) ∉ ("full", "diagonal") && - !(typeof(model.hessian_structure) <: HessianStructure) - warning *= - "Need `hessian_structure ∈ (:full, :diagonal)` or `hessian_structure ∈ (:full, :diagonal)` or typeof(model.hessian_structure) <: HessianStructure." * - "Resetting `hessian_structure = :full`. " - model.hessian_structure = :full - end - if model.backend ∉ (:GGN, :EmpiricalFisher) - warning *= - "Need `backend ∈ (:GGN, :EmpiricalFisher)`. " * "Resetting `backend = :GGN`. " - model.backend = :GGN - end - if model.link_approx ∉ (:probit, :plugin) - warning *= - "Need `link_approx ∈ (:probit, :plugin)`. " * - "Resetting `link_approx = :probit`. " - model.link_approx = :probit - end - return warning -end -function MMI.predict(model::LaplaceApproximation, fitresult, Xnew) + +function MMI.predict(model::LaplaceClassification, fitresult, Xnew) chain, la, levels = fitresult # re-format Xnew into acceptable input for Laplace: X = MLJFlux.reformat(Xnew) @@ -280,74 +417,104 @@ function MMI.predict(model::LaplaceApproximation, fitresult, Xnew) i in 1:size(X, 2) ]..., ) - if la.likelihood == :classification - # return a UnivariateFinite: - return MMI.UnivariateFinite(levels, yhat) - end - if la.likelihood == :regression - # return a UnivariateNormal: - return MMI.UnivariateNormal(yhat[1], sqrt(yhat[2])) - end -end -function _isdefined(object, name) - pnames = propertynames(object) - fnames = fieldnames(typeof(object)) - name in pnames && !(name in fnames) && return true - return isdefined(object, name) -end + return MMI.UnivariateFinite(levels, yhat) + -function _equal_to_depth_one(x1, x2) - names = propertynames(x1) - names === propertynames(x2) || return false - for name in names - getproperty(x1, name) == getproperty(x2, name) || return false - end - return true end -function MMI.is_same_except( - m1::M1, m2::M2, exceptions::Symbol... -) where {M1<:LaplaceApproximation,M2<:LaplaceApproximation} - typeof(m1) === typeof(m2) || return false - names = propertynames(m1) - propertynames(m2) === names || return false - - for name in names - if !(name in exceptions) && name != :la - if !_isdefined(m1, name) - !_isdefined(m2, name) || return false - elseif _isdefined(m2, name) - if name in MLJFlux.deep_properties(M1) - _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || - return false - else - ( - MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || - getproperty(m1, name) isa AbstractRNG || - getproperty(m2, name) isa AbstractRNG - ) || return false - end - else - return false - end + +######################################################## train ,fit and predict for regression + + + + +function MLJFlux.train!(model::LaplaceRegression, penalty, chain, optimiser, X, y) + loss = model.loss + n_batches = length(y) + training_loss = zero(Float32) + for i in 1:n_batches + parameters = Flux.params(chain) + gs = Flux.gradient(parameters) do + yhat = chain(X[i]) + batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches + training_loss += batch_loss + return batch_loss end + Flux.update!(optimiser, parameters, gs) end - return true + return training_loss / n_batches end -MMI.metadata_model( - LaplaceApproximation; - input=Union{ - AbstractMatrix{MMI.Continuous}, - MMI.Table(MMI.Continuous), - MMI.Table{AbstractVector{MMI.Continuous}}, - }, - target=Union{ - AbstractArray{MMI.Finite}, - AbstractArray{MMI.Continuous}, - AbstractVector{MMI.Finite}, - AbstractVector{MMI.Continuous}, - }, - path="MLJFlux.LaplaceApproximation", +function MLJFlux.fit!( + model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y ) + loss = model.loss + + # intitialize and start progress meter: + meter = Progress( + epochs + 1; + dt=0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + verbosity != 1 || next!(meter) + + # initiate history: + n_batches = length(y) + + parameters = Flux.params(chain) + + # initial loss: + losses = ( + loss(chain(X[i]), y[i]) + penalty(parameters) / n_batches for i in 1:n_batches + ) + history = [mean(losses)] + + for i in 1:epochs + current_loss = MLJFlux.train!( + model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y + ) + verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" + verbosity != 1 || next!(meter) + push!(history, current_loss) + end + + la = model.la + + # fit the Laplace model: + fit!(la, zip(X, y); model.fit_params...) + optimize_prior!(la; verbose=false, n_steps=100) + + model.la = la + + return chain, history +end + + +function MMI.predict(model::LaplaceRegression, fitresult, Xnew) + chain, la, levels = fitresult + # re-format Xnew into acceptable input for Laplace: + X = MLJFlux.reformat(Xnew) + # predict using Laplace: + yhat = vcat( + [ + glm_predictive_distribution(la, MLJFlux.tomat(X[:, i]); link_approx=model.link_approx)' for + i in 1:size(X, 2) + ]..., + ) + println(size(yhat)) + predictions = [] + for row in eachrow(yhat) + + mean_val = Float64(row[1][1]) + std_val = sqrt(Float64(row[2][1])) + # Append a Normal distribution: + push!(predictions, Normal(mean_val, std_val)) + end + + return predictions + +end From 9f4edbdbad24a9241e973e25ee1598b663919e35 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 3 Jun 2024 03:16:25 +0200 Subject: [PATCH 03/32] fixed fit! function for regression. classification not complete yet. --- src/mlj_flux.jl | 273 ++++++++++++++++++++++++++---------------------- 1 file changed, 150 insertions(+), 123 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index b290a8c..c851306 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -8,6 +8,7 @@ using ComputationalResources using Statistics using Distributions using LinearAlgebra +using LaplaceRedux mutable struct LaplaceClassification{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic builder::B @@ -35,9 +36,9 @@ mutable struct LaplaceClassification{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic end """ - LaplaceApproximation(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) + LaplaceClassification(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) -A probabilistic model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: +A probabilistic classification model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: - `builder`: a Flux model that constructs the neural network. - `finaliser`: a Flux model that processes the output of the neural network. @@ -53,7 +54,6 @@ A probabilistic model that uses Laplace approximation to estimate the posterior The model also has the following parameters, which are specific to the Laplace approximation: -- `likelihood`: the likelihood of the model, either `:classification` or `:regression`. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -87,6 +87,7 @@ function LaplaceClassification(; fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), ) where {B,F,O,L} likelihood = :classification + la= :classification model = LaplaceClassification( builder, finaliser, @@ -109,7 +110,7 @@ function LaplaceClassification(; P₀, link_approx, fit_params, - nothing, + la, ) message = MMI.clean!(model) @@ -132,7 +133,7 @@ mutable struct LaplaceRegression{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` likelihood::Symbol subset_of_weights::Symbol - subnetwork_indices::Vector{Vector{Int}} + subnetwork_indices::Union{Nothing,Vector{Vector{Int}}} hessian_structure::Union{HessianStructure,Symbol,String} backend::Symbol σ::Real @@ -175,7 +176,7 @@ function LaplaceRegression(; builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), finaliser::F=x->x, optimiser::O=Flux.Optimise.Adam(), - loss::L=Flux.mse, + loss::L=Flux.Losses.mse, epochs::Int=10, batch_size::Int=1, lambda::Float64=1.0, @@ -184,7 +185,7 @@ function LaplaceRegression(; optimiser_changes_trigger_retraining::Bool=false, acceleration::AbstractResource=CPU1(), subset_of_weights::Symbol=:all, - subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]), + subnetwork_indices=nothing, hessian_structure::Union{HessianStructure,Symbol,String}=:full, backend::Symbol=:GGN, σ::Float64=1.0, @@ -193,6 +194,7 @@ function LaplaceRegression(; fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), ) where {B,F,O,L} likelihood=:regression + la= nothing model = LaplaceRegression( builder, finaliser, @@ -214,7 +216,7 @@ function LaplaceRegression(; μ₀, P₀, fit_params, - nothing, + la, ) message = MMI.clean!(model) @@ -251,8 +253,8 @@ end function MLJFlux.build(model::Union{LaplaceClassification,LaplaceRegression}, rng, shape) - # Construct the chain - chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + # Construct the initial chain + chain = MLJFlux.build(model.builder, rng, shape...) # Construct Laplace model and store it in the model object model.la = Laplace( chain; @@ -334,30 +336,33 @@ function MMI.clean!(model::Union{LaplaceClassification,LaplaceRegression}) end return warning end -######################################### train , fit and predict for classification - -function MLJFlux.train!(model::LaplaceClassification, penalty, chain, optimiser, X, y) - loss = model.loss - n_batches = length(y) - training_loss = zero(Float32) - for i in 1:n_batches - parameters = Flux.params(chain) - gs = Flux.gradient(parameters) do - yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches - training_loss += batch_loss - return batch_loss - end - Flux.update!(optimiser, parameters, gs) - end - return training_loss / n_batches -end -function MLJFlux.fit!( - model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y -) - loss = model.loss +######################################################## fit and predict for regression +function MLJFlux.fit!(model::LaplaceRegression, penalty, verbosity, X, y) + + epochs= model.epochs + n_samples= size(X, 1) + + # Determine the shape of the model + shape = MLJFlux.shape(model, X, y) + + # Build the chain + chain = MLJFlux.build(model, model.rng, shape) + la= model.la + + optimiser= model.optimiser + + # Initialize history: + n_samples = size(X, 1) + history = [] + # Define the loss function for Laplace Regression with a custom penalty + function custom_loss( X_batch, y_batch) + preds = chain(X_batch) + data_loss = model.loss(y_batch, preds) + penalty_term = penalty(params(chain)) + return data_loss + penalty_term + end # intitialize and start progress meter: meter = Progress( epochs + 1; @@ -367,150 +372,172 @@ function MLJFlux.fit!( barlen=25, color=:yellow, ) - verbosity != 1 || next!(meter) - - # initiate history: - n_batches = length(y) - + # Create a data loader + loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - - # initial loss: - losses = ( - loss(chain(X[i]), y[i]) + penalty(parameters) / n_batches for i in 1:n_batches - ) - history = [mean(losses)] - for i in 1:epochs - current_loss = MLJFlux.train!( - model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) + epoch_loss = 0.0 + # train the model + for (X_batch, y_batch) in loader + y_batch = reshape(y_batch,1,:) + + # Backward pass + gs = Flux.gradient(parameters) do + batch_loss = Flux.Losses.mse(chain(X_batch), y_batch) + epoch_loss += batch_loss + end + # Update parameters + Flux.update!(optimiser, parameters,gs) + end + epoch_loss /= n_samples + push!(history, epoch_loss) + #verbosity + if verbosity == 1 + next!(meter) + elseif verbosity ==2 + next!(meter) + println( "Loss is $(round(epoch_loss; sigdigits=4))") + end end - la = model.la + # fit the Laplace model: - fit!(la, zip(X, y); model.fit_params...) - optimize_prior!(la; verbose=false, n_steps=100) + LaplaceRedux.fit!(model.la,zip(eachrow(X),y)) + optimize_prior!(model.la; verbose=false, n_steps=100) - model.la = la return chain, history end - -function MMI.predict(model::LaplaceClassification, fitresult, Xnew) +function MMI.predict(model::LaplaceRegression, fitresult, Xnew) chain, la, levels = fitresult # re-format Xnew into acceptable input for Laplace: X = MLJFlux.reformat(Xnew) # predict using Laplace: yhat = vcat( [ - predict(la, MLJFlux.tomat(X[:, i]); link_approx=model.link_approx)' for + glm_predictive_distribution(la, MLJFlux.tomat(X[:, i]))' for i in 1:size(X, 2) ]..., ) - - return MMI.UnivariateFinite(levels, yhat) - + println(size(yhat)) + predictions = [] + for row in eachrow(yhat) + + mean_val = Float64(row[1][1]) + std_val = sqrt(Float64(row[2][1])) + # Append a Normal distribution: + push!(predictions, Normal(mean_val, std_val)) + end + + return predictions end -######################################################## train ,fit and predict for regression +######################################### fit and predict for classification -function MLJFlux.train!(model::LaplaceRegression, penalty, chain, optimiser, X, y) - loss = model.loss - n_batches = length(y) - training_loss = zero(Float32) - for i in 1:n_batches - parameters = Flux.params(chain) - gs = Flux.gradient(parameters) do - yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches - training_loss += batch_loss - return batch_loss - end - Flux.update!(optimiser, parameters, gs) - end - return training_loss / n_batches -end function MLJFlux.fit!( - model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y + model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y ) - loss = model.loss +epochs= model.epochs +n_samples= size(X, 1) +#y encode +y_encoded= unique(y) .== permutedims(y) - # intitialize and start progress meter: - meter = Progress( - epochs + 1; - dt=0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) +#todo - # initiate history: - n_batches = length(y) - parameters = Flux.params(chain) - # initial loss: - losses = ( - loss(chain(X[i]), y[i]) + penalty(parameters) / n_batches for i in 1:n_batches - ) - history = [mean(losses)] - for i in 1:epochs - current_loss = MLJFlux.train!( - model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) + +# Determine the shape of the model +shape = MLJFlux.shape(model, X, y_encoded) + +# Build the chain +chain = MLJFlux.build(model, model.rng, shape) +la= model.la + +optimiser= model.optimiser + +# Initialize history: +n_samples = size(X, 1) +history = [] +# Define the loss function for Laplace Regression with a custom penalty +function custom_loss( X_batch, y_batch) + preds = chain(X_batch) + data_loss = model.loss(y_batch, preds) + penalty_term = penalty(params(chain)) + return data_loss + penalty_term +end +# intitialize and start progress meter: +meter = Progress( + epochs + 1; + dt=0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, +) +# Create a data loader +loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) +parameters = Flux.params(chain) +for i in 1:epochs + epoch_loss = 0.0 + # train the model + for (X_batch, y_batch) in loader + y_batch = reshape(y_batch,1,:) + + # Backward pass + gs = Flux.gradient(parameters) do + batch_loss = Flux.Losses.mse(chain(X_batch), y_batch) + epoch_loss += batch_loss + end + # Update parameters + Flux.update!(optimiser, parameters,gs) end + epoch_loss /= n_samples + push!(history, epoch_loss) + #verbosity + if verbosity == 1 + next!(meter) + elseif verbosity ==2 + next!(meter) + println( "Loss is $(round(epoch_loss; sigdigits=4))") + end +end - la = model.la - # fit the Laplace model: - fit!(la, zip(X, y); model.fit_params...) - optimize_prior!(la; verbose=false, n_steps=100) - model.la = la +# fit the Laplace model: +LaplaceRedux.fit!(model.la,zip(eachrow(X),y)) +optimize_prior!(model.la; verbose=false, n_steps=100) - return chain, history + +return chain, history end -function MMI.predict(model::LaplaceRegression, fitresult, Xnew) + +function MMI.predict(model::LaplaceClassification, fitresult, Xnew) chain, la, levels = fitresult # re-format Xnew into acceptable input for Laplace: X = MLJFlux.reformat(Xnew) # predict using Laplace: yhat = vcat( [ - glm_predictive_distribution(la, MLJFlux.tomat(X[:, i]))' for + predict(la, MLJFlux.tomat(X[:, i]); link_approx=model.link_approx)' for i in 1:size(X, 2) ]..., ) - println(size(yhat)) - predictions = [] - for row in eachrow(yhat) - - mean_val = Float64(row[1][1]) - std_val = sqrt(Float64(row[2][1])) - # Append a Normal distribution: - push!(predictions, Normal(mean_val, std_val)) - end - - return predictions + + return MMI.UnivariateFinite(levels, yhat) + end From 437405bf0aed90b489441b4903b85d3d3eb24fe3 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 8 Jun 2024 04:29:24 +0200 Subject: [PATCH 04/32] updated with the new MLJ @mlj_model macro --- CHANGELOG.md | 12 + src/mlj_flux.jl | 657 ++++++++++++++++++------------------------------ 2 files changed, 254 insertions(+), 415 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d78be0c..b8fda09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.2.1]. +## Version [0.3.0] - 2024-06-8 + +### Changed + +- adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase. +- Changed the fit! method arguments. Now it also accept a Flux chain model instead of retrieving it from the structs. this is due to the fact that MLJ wants only hyperparameters in the struct https://juliaai.github.io/MLJModelInterface.jl/dev/quick_start_guide/ +- Changed the predict functions for both LaplaceClassification and LaplaceRegression. + +### Removed +- Removed the shape , build and clean! functions. + + ## Version [0.2.3] - 2024-05-31 ### Changed diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index c851306..e249ecf 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -1,59 +1,23 @@ using Flux using MLJFlux -import MLJModelInterface as MMI using ProgressMeter: Progress, next!, BarGlyphs using Random using Tables -using ComputationalResources -using Statistics using Distributions using LinearAlgebra using LaplaceRedux - -mutable struct LaplaceClassification{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic - builder::B - finaliser::F - optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl - loss::L # can be called as in `loss(yhat, y)` - epochs::Int # number of epochs - batch_size::Int # size of a batch - lambda::Float64 # regularization strength - alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) - rng::Union{AbstractRNG,Int64} - optimiser_changes_trigger_retraining::Bool - acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` - likelihood::Symbol - subset_of_weights::Symbol - subnetwork_indices::Vector{Vector{Int}} - hessian_structure::Union{HessianStructure,Symbol,String} - backend::Symbol - σ::Real - μ₀::Real - P₀::Union{AbstractMatrix,UniformScaling,Nothing} - link_approx::Symbol - fit_params::Dict{Symbol,Any} - la::Union{Nothing,AbstractLaplace} -end +import MLJBase +import MLJBase: @mlj_model, metadata_model, metadata_pkg """ - LaplaceClassification(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) + @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic -A probabilistic classification model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: +A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. + It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. + The model is trained using the `fit!` method. The model is defined by the following default parameters: -- `builder`: a Flux model that constructs the neural network. -- `finaliser`: a Flux model that processes the output of the neural network. -- `optimiser`: a Flux optimiser. - `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs to train the model. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). - `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. - -The model also has the following parameters, which are specific to the Laplace approximation: - - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -62,107 +26,36 @@ The model also has the following parameters, which are specific to the Laplace a - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `fit_params`: additional parameters to pass to the `fit!` method. +- `predict_proba`:whether to compute the probabilities or not, either true or false. +- `fit_prior_nsteps`: the number of steps used to fit the priors. +- `la`: The fitted Laplace object will be saved here once the model is fitted. It has to be left to nothing, the fit! function will automatically save the Laplace model here. """ -function LaplaceClassification(; - builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), - finaliser::F=Flux.softmax, - optimiser::O=Flux.Optimise.Adam(), - loss::L=Flux.crossentropy, - epochs::Int=10, - batch_size::Int=1, - lambda::Float64=1.0, - alpha::Float64=0.0, - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, - optimiser_changes_trigger_retraining::Bool=false, - acceleration::AbstractResource=CPU1(), - subset_of_weights::Symbol=:all, - subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]), - hessian_structure::Union{HessianStructure,Symbol,String}=:full, - backend::Symbol=:GGN, - σ::Float64=1.0, - μ₀::Float64=0.0, - P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing, - link_approx::Symbol=:probit, - fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), -) where {B,F,O,L} - likelihood = :classification - la= :classification - model = LaplaceClassification( - builder, - finaliser, - optimiser, - loss, - epochs, - batch_size, - lambda, - alpha, - rng, - optimiser_changes_trigger_retraining, - acceleration, - likelihood, - subset_of_weights, - subnetwork_indices, - hessian_structure, - backend, - σ, - μ₀, - P₀, - link_approx, - fit_params, - la, - ) - - message = MMI.clean!(model) - isempty(message) || @warn message - - return model -end - -mutable struct LaplaceRegression{B,F,O,L} <: MLJFlux.MLJFluxProbabilistic - builder::B - finaliser::F - optimiser::O # mutable struct from Flux/src/optimise/optimisers.jl - loss::L # can be called as in `loss(yhat, y)` - epochs::Int # number of epochs - batch_size::Int # size of a batch - lambda::Float64 # regularization strength - alpha::Float64 # regularizaton mix (0 for all l2, 1 for all l1) - rng::Union{AbstractRNG,Int64} - optimiser_changes_trigger_retraining::Bool - acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` - likelihood::Symbol - subset_of_weights::Symbol - subnetwork_indices::Union{Nothing,Vector{Vector{Int}}} - hessian_structure::Union{HessianStructure,Symbol,String} - backend::Symbol - σ::Real - μ₀::Real - P₀::Union{AbstractMatrix,UniformScaling,Nothing} - fit_params::Dict{Symbol,Any} - la::Union{Nothing,AbstractLaplace} +@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic + loss=Flux.crossentropy + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + subset_of_weights::Symbol=:all + subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) + hessian_structure::Union{HessianStructure,Symbol,String}=:full + backend::Symbol=:GGN + σ::Float64=1.0 + μ₀::Float64=0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing + link_approx::Symbol=:probit + predict_proba::Bool= true + fit_prior_nsteps::Int=100 + la::Union{Nothing,AbstractLaplace}= nothing + end """ - LaplaceRegression(; builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration, likelihood, subset_of_weights, subnetwork_indices, hessian_structure, backend, σ, μ₀, P₀, link_approx, fit_params) + @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic -A probabilistic regression model that uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters for all `MLJFlux` models: +A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. +It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. +The model is trained using the `fit!` method. The model is defined by the following default parameters: -- `builder`: a Flux model that constructs the neural network. -- `finaliser`: a Flux model that processes the output of the neural network. -- `optimiser`: a Flux optimiser. - `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs to train the model. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). - `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. - -The model also has the following parameters, which are specific to the Laplace approximation: - -- `likelihood`: the likelihood of the model `:regression`. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -170,220 +63,89 @@ The model also has the following parameters, which are specific to the Laplace a - `σ`: the standard deviation of the prior distribution. - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. -- `fit_params`: additional parameters to pass to the `fit!` method. +- `fit_prior_nsteps`: the number of steps used to fit the priors. +- `la`: The fitted Laplace object will be saved here once the model is fitted. It has to be left to nothing, the fit! function will automatically save the Laplace model here. """ -function LaplaceRegression(; - builder::B=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish), - finaliser::F=x->x, - optimiser::O=Flux.Optimise.Adam(), - loss::L=Flux.Losses.mse, - epochs::Int=10, - batch_size::Int=1, - lambda::Float64=1.0, - alpha::Float64=0.0, - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG, - optimiser_changes_trigger_retraining::Bool=false, - acceleration::AbstractResource=CPU1(), - subset_of_weights::Symbol=:all, - subnetwork_indices=nothing, - hessian_structure::Union{HessianStructure,Symbol,String}=:full, - backend::Symbol=:GGN, - σ::Float64=1.0, - μ₀::Float64=0.0, - P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing, - fit_params::Dict{Symbol,Any}=Dict{Symbol,Any}(:override => true), -) where {B,F,O,L} - likelihood=:regression - la= nothing - model = LaplaceRegression( - builder, - finaliser, - optimiser, - loss, - epochs, - batch_size, - lambda, - alpha, - rng, - optimiser_changes_trigger_retraining, - acceleration, - likelihood, - subset_of_weights, - subnetwork_indices, - hessian_structure, - backend, - σ, - μ₀, - P₀, - fit_params, - la, - ) - - message = MMI.clean!(model) - isempty(message) || @warn message - - return model +@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic + loss=Flux.Losses.mse + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + subset_of_weights::Symbol=:all + subnetwork_indices=nothing + hessian_structure::Union{HessianStructure,Symbol,String}=:full + backend::Symbol=:GGN + σ::Float64=1.0 + μ₀::Float64=0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing + fit_prior_nsteps::Int=100 + la::Union{Nothing,AbstractLaplace}= nothing end +const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression} + +""" + MLJFlux.fit!(model::LaplaceRegression, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) + +Fit the LaplaceRegression model using Flux.jl. + +# Arguments +- `model::LaplaceRegression`: The LaplaceRegression object. +- `penalty`: The penalty term for regularization. +- `chain`: The chain of layers for the model. +- `epochs`: The number of training epochs. +- `batch_size`: The size of each training batch. +- `optimiser`: The optimization algorithm to use. +- `verbosity`: The level of verbosity during training. +- `X`: The input data. +- `y`: The target data. + +# Returns +- `model::LaplaceRegression`: The fitted LaplaceRegression model. +""" +function MLJFlux.fit!(model::LaplaceRegression, penalty,chain, epochs,batch_size, optimiser, verbosity, X, y) - - - - -function MLJFlux.shape(model::Union{LaplaceClassification,LaplaceRegression}, X, y) - X = X isa Matrix ? Tables.table(X) : X - n_input = length(Tables.columnnames(X)) - - if model isa LaplaceClassification - levels = MMI.classes(y[1]) - n_output = length(levels) - return (n_input, n_output) - elseif model isa LaplaceRegression - dims = size(y) - if length(dims) == 1 - n_output= 1 - else - n_output= dims[2] - end - return (n_input, n_output) - end - -end - - -function MLJFlux.build(model::Union{LaplaceClassification,LaplaceRegression}, rng, shape) - # Construct the initial chain - chain = MLJFlux.build(model.builder, rng, shape...) - # Construct Laplace model and store it in the model object + X = MLJBase.matrix(X, transpose=true) model.la = Laplace( chain; - likelihood=model.likelihood, + likelihood=:regression, subset_of_weights=model.subset_of_weights, subnetwork_indices=model.subnetwork_indices, hessian_structure=model.hessian_structure, backend=model.backend, σ=model.σ, μ₀=model.μ₀, - P₀=model.P₀, - ) - return chain -end - -function MLJFlux.fitresult(model::Union{LaplaceClassification,LaplaceRegression}, chain, y) - if model isa LaplaceClassification - return (chain, model.la, MMI.classes(y[1])) - else - return (chain, model.la, size(y) ) - end -end - - -function MMI.clean!(model::Union{LaplaceClassification,LaplaceRegression}) - warning = "" - if model.lambda < 0 - warning *= "Need `lambda ≥ 0`. Resetting `lambda = 0`. " - model.lambda = 0 - end - if model.alpha < 0 || model.alpha > 1 - warning *= "Need alpha in the interval `[0, 1]`. " * "Resetting `alpha = 0`. " - model.alpha = 0 - end - if model.epochs < 0 - warning *= "Need `epochs ≥ 0`. Resetting `epochs = 10`. " - model.epochs = 10 - end - if model.batch_size <= 0 - warning *= "Need `batch_size > 0`. Resetting `batch_size = 1`. " - model.batch_size = 1 - end - if model.acceleration isa CUDALibs && gpu_isdead() - warning *= - "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " - end - if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) - warning *= "`Undefined acceleration, falling back to CPU`" - model.acceleration = CPU1() - end - if model.likelihood ∉ (:regression, :classification) - warning *= "Need `likelihood ∈ (:regression, :classification)`. " * - "Resetting to default `likelihood = :regression`. " - model.likelihood = :regression - end - if model.subset_of_weights ∉ (:all, :last_layer, :subnetwork) - warning *= - "Need `subset_of_weights ∈ (:all, :last_layer, :subnetwork)`. " * - "Resetting `subset_of_weights = :all`. " - model.subset_of_weights = :all - end - if String(model.hessian_structure) ∉ ("full", "diagonal") && - !(typeof(model.hessian_structure) <: HessianStructure) - warning *= - "Need `hessian_structure ∈ (:full, :diagonal)` or `hessian_structure ∈ (:full, :diagonal)` or typeof(model.hessian_structure) <: HessianStructure." * - "Resetting `hessian_structure = :full`. " - model.hessian_structure = :full - end - if model.backend ∉ (:GGN, :EmpiricalFisher) - warning *= - "Need `backend ∈ (:GGN, :EmpiricalFisher)`. " * "Resetting `backend = :GGN`. " - model.backend = :GGN - end - if model.likelihood == :classification && model.link_approx ∉ (:probit, :plugin) - warning *= - "Need `link_approx ∈ (:probit, :plugin)`. " * - "Resetting `link_approx = :probit`. " - model.link_approx = :probit - end - return warning -end - -######################################################## fit and predict for regression - -function MLJFlux.fit!(model::LaplaceRegression, penalty, verbosity, X, y) - - epochs= model.epochs - n_samples= size(X, 1) - - # Determine the shape of the model - shape = MLJFlux.shape(model, X, y) - - # Build the chain - chain = MLJFlux.build(model, model.rng, shape) - la= model.la - - optimiser= model.optimiser - + P₀=model.P₀) + n_samples= size(X,1) # Initialize history: - n_samples = size(X, 1) history = [] + verbose_laplace=false # Define the loss function for Laplace Regression with a custom penalty - function custom_loss( X_batch, y_batch) - preds = chain(X_batch) - data_loss = model.loss(y_batch, preds) - penalty_term = penalty(params(chain)) + function custom_loss( y_pred, y_batch) + data_loss = model.loss( y_pred,y_batch) + penalty_term = penalty(Flux.params(chain)) return data_loss + penalty_term end # intitialize and start progress meter: meter = Progress( epochs + 1; - dt=0, + dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow, ) # Create a data loader - loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) + loader = Flux.Data.DataLoader((data=X, label=y), batchsize=batch_size, shuffle=true) parameters = Flux.params(chain) for i in 1:epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader y_batch = reshape(y_batch,1,:) - + # Backward pass gs = Flux.gradient(parameters) do - batch_loss = Flux.Losses.mse(chain(X_batch), y_batch) + batch_loss = custom_loss(chain(X_batch), y_batch) epoch_loss += batch_loss end # Update parameters @@ -396,6 +158,7 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty, verbosity, X, y) next!(meter) elseif verbosity ==2 next!(meter) + verbose_laplace=true println( "Loss is $(round(epoch_loss; sigdigits=4))") end end @@ -403,31 +166,42 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty, verbosity, X, y) # fit the Laplace model: - LaplaceRedux.fit!(model.la,zip(eachrow(X),y)) - optimize_prior!(model.la; verbose=false, n_steps=100) + LaplaceRedux.fit!(model.la,zip(eachcol(X),y)) + optimize_prior!(model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + cache= nothing + report=[] - - return chain, history + return (model, cache, report) end -function MMI.predict(model::LaplaceRegression, fitresult, Xnew) - chain, la, levels = fitresult - # re-format Xnew into acceptable input for Laplace: - X = MLJFlux.reformat(Xnew) - # predict using Laplace: - yhat = vcat( - [ - glm_predictive_distribution(la, MLJFlux.tomat(X[:, i]))' for - i in 1:size(X, 2) - ]..., - ) - println(size(yhat)) +""" + predict(model::LaplaceRegression, Xnew) + +Predict the output for new input data using a Laplace regression model. + +# Arguments +- `model::LaplaceRegression`: The trained Laplace regression model. +- `Xnew`: The new input data. + +# Returns +- The predicted output for the new input data. + +""" +function MLJFlux.predict(model::LaplaceRegression, Xnew) + Xnew = MLJBase.matrix(Xnew) + #convert in a vector of vectors because MLJ ask to do so + X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] + #inizialize output vector yhat + yhat=[] + # Predict using Laplace and collect the predictions + yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + predictions = [] for row in eachrow(yhat) - - mean_val = Float64(row[1][1]) - std_val = sqrt(Float64(row[2][1])) + + mean_val = Float64(row[1][1][1]) + std_val = sqrt(Float64(row[1][2][1])) # Append a Normal distribution: push!(predictions, Normal(mean_val, std_val)) end @@ -443,101 +217,154 @@ end ######################################### fit and predict for classification -function MLJFlux.fit!( - model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y -) -epochs= model.epochs -n_samples= size(X, 1) -#y encode -y_encoded= unique(y) .== permutedims(y) - -#todo - - - - - -# Determine the shape of the model -shape = MLJFlux.shape(model, X, y_encoded) - -# Build the chain -chain = MLJFlux.build(model, model.rng, shape) -la= model.la +""" + MLJFlux.fit!(model::LaplaceClassification, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) + +Fit the LaplaceClassification model using MLJFlux. + +# Arguments +- `model::LaplaceClassification`: The LaplaceClassification object to fit. +- `penalty`: The penalty to apply during training. +- `chain`: The chain to use during training. +- `epochs`: The number of training epochs. +- `batch_size`: The batch size for training. +- `optimiser`: The optimiser to use during training. +- `verbosity`: The verbosity level for training. +- `X`: The input data for training. +- `y`: The target labels for training. + +# Returns +- `model::LaplaceClassification`: The fitted LaplaceClassification model. +""" +function MLJFlux.fit!(model::LaplaceClassification, penalty,chain, epochs,batch_size, optimiser, verbosity, X, y) + X = MLJBase.matrix(X, transpose=true) -optimiser= model.optimiser + # Integer encode the target variable y + #y_onehot = unique(y) .== permutedims(y) -# Initialize history: -n_samples = size(X, 1) -history = [] -# Define the loss function for Laplace Regression with a custom penalty -function custom_loss( X_batch, y_batch) - preds = chain(X_batch) - data_loss = model.loss(y_batch, preds) - penalty_term = penalty(params(chain)) + model.la = Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀) + n_samples= size(X,1) + verbose_laplace=false + + # Initialize history: + history = [] + # Define the loss function for Laplace Regression with a custom penalty + function custom_loss( y_pred, y_batch) + data_loss = model.loss( y_pred,y_batch) + penalty_term = penalty(Flux.params(chain)) return data_loss + penalty_term -end -# intitialize and start progress meter: -meter = Progress( - epochs + 1; - dt=0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, -) -# Create a data loader -loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) -parameters = Flux.params(chain) -for i in 1:epochs - epoch_loss = 0.0 - # train the model - for (X_batch, y_batch) in loader - y_batch = reshape(y_batch,1,:) - - # Backward pass - gs = Flux.gradient(parameters) do - batch_loss = Flux.Losses.mse(chain(X_batch), y_batch) + end + # intitialize and start progress meter: + meter = Progress( + epochs + 1; + dt=0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + # Create a data loader + loader = Flux.Data.DataLoader((data=X, label=y), batchsize=batch_size, shuffle=true) + parameters = Flux.params(chain) + for i in 1:epochs + epoch_loss = 0.0 + # train the model + for (X_batch, y_batch) in loader + + # Backward pass + gs = Flux.gradient(parameters) do + batch_loss = custom_loss(chain(X_batch), y_batch) epoch_loss += batch_loss + end + # Update parameters + Flux.update!(optimiser, parameters,gs) end - # Update parameters - Flux.update!(optimiser, parameters,gs) - end - epoch_loss /= n_samples - push!(history, epoch_loss) - #verbosity - if verbosity == 1 - next!(meter) - elseif verbosity ==2 - next!(meter) - println( "Loss is $(round(epoch_loss; sigdigits=4))") - end -end - + epoch_loss /= n_samples + push!(history, epoch_loss) + #verbosity + if verbosity == 1 + next!(meter) + elseif verbosity ==2 + next!(meter) + verbose_laplace=true + println( "Loss is $(round(epoch_loss; sigdigits=4))") + end + end + # fit the Laplace model: + LaplaceRedux.fit!(model.la,zip(eachcol(X),y)) + optimize_prior!(model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + cache= nothing + report=[] -# fit the Laplace model: -LaplaceRedux.fit!(model.la,zip(eachrow(X),y)) -optimize_prior!(model.la; verbose=false, n_steps=100) + return (model, cache, report) +end -return chain, history -end +""" + predict(model::LaplaceClassification, Xnew) +Predicts the class labels for new data using the LaplaceClassification model. +# Arguments +- `model::LaplaceClassification`: The trained LaplaceClassification model. +- `Xnew`: The new data to make predictions on. -function MMI.predict(model::LaplaceClassification, fitresult, Xnew) - chain, la, levels = fitresult - # re-format Xnew into acceptable input for Laplace: - X = MLJFlux.reformat(Xnew) - # predict using Laplace: - yhat = vcat( - [ - predict(la, MLJFlux.tomat(X[:, i]); link_approx=model.link_approx)' for - i in 1:size(X, 2) - ]..., - ) +# Returns +An array of predicted class labels. - return MMI.UnivariateFinite(levels, yhat) +""" +function MLJFlux.predict(model::LaplaceClassification, Xnew) + Xnew = MLJBase.matrix(Xnew) + #convert in a vector of vectors because Laplace ask to do so + X_vec= X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] + # Predict using Laplace and collect the predictions + predictions = [LaplaceRedux.predict(model.la, x;link_approx= model.link_approx,predict_proba=model.predict_proba) for x in X_vec] + + return predictions end + +# Then for each model, +MLJBase.metadata_model( + LaplaceClassification; + input=Union{ + AbstractMatrix{MLJBase.Continuous}, + MLJBase.Table(MLJBase.Continuous), + MLJBase.Table{AbstractVector{MLJBase.Continuous}}, + }, + target=Union{ + AbstractArray{MLJBase.Finite}, + AbstractArray{MLJBase.Continuous}, + AbstractVector{MLJBase.Finite}, + AbstractVector{MLJBase.Continuous}, + }, + path="MLJFlux.LaplaceClassification", +) +# Then for each model, +MLJBase.metadata_model( + LaplaceRegression; + input=Union{ + AbstractMatrix{MLJBase.Continuous}, + MLJBase.Table(MLJBase.Continuous), + MLJBase.Table{AbstractVector{MLJBase.Continuous}}, + }, + target=Union{ + AbstractArray{MLJBase.Finite}, + AbstractArray{MLJBase.Continuous}, + AbstractVector{MLJBase.Finite}, + AbstractVector{MLJBase.Continuous}, + }, + path="MLJFlux.LaplaceRegression", +) + \ No newline at end of file From 93140f8d1d0d77ff23d2668d9c632889ec31d1bd Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 8 Jun 2024 04:47:13 +0200 Subject: [PATCH 05/32] add constraints on parameters --- src/mlj_flux.jl | 81 ++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index e249ecf..01dca1e 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -15,8 +15,10 @@ import MLJBase: @mlj_model, metadata_model, metadata_pkg A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model is trained using the `fit!` method. The model is defined by the following default parameters: - - `loss`: a loss function that takes the predicted output and the true output as arguments. +- `optimiser`: a Flux optimiser,default is Flux.Optimise.Adam(). +- `epochs`: the number of epochs to train the model,default is 10. +- `batch_size`: the size of a batch,default is 1. - `rng`: a random number generator. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. @@ -32,17 +34,20 @@ A mutable struct representing a Laplace Classification model that extends the ML """ @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic loss=Flux.crossentropy + optimiser=Flux.Optimise.Adam() + epochs::Int=10::(_ > 0) + batch_size::Int=1::(_ > 0) rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG - subset_of_weights::Symbol=:all - subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String}=:full - backend::Symbol=:GGN + subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in ("full", "diagonal")) + backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) σ::Float64=1.0 μ₀::Float64=0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing - link_approx::Symbol=:probit - predict_proba::Bool= true - fit_prior_nsteps::Int=100 + link_approx::Symbol=:probit::(_ in (:probit,:plugin)) + predict_proba::Bool= true::(_ in (true,false)) + fit_prior_nsteps::Int=100::(_ > 0) la::Union{Nothing,AbstractLaplace}= nothing end @@ -55,6 +60,9 @@ It uses Laplace approximation to estimate the posterior distribution of the weig The model is trained using the `fit!` method. The model is defined by the following default parameters: - `loss`: a loss function that takes the predicted output and the true output as arguments. +- `optimiser`: a Flux optimiser. +- `epochs`: the number of epochs to train the model. +- `batch_size`: the size of a batch. - `rng`: a random number generator. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. @@ -68,15 +76,18 @@ The model is trained using the `fit!` method. The model is defined by the follow """ @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic loss=Flux.Losses.mse + optimiser=Flux.Optimise.Adam() + epochs::Int=10::(_ > 0) + batch_size::Int=1::(_ > 0) rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG - subset_of_weights::Symbol=:all + subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices=nothing - hessian_structure::Union{HessianStructure,Symbol,String}=:full - backend::Symbol=:GGN + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in ("full", "diagonal")) + backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) σ::Float64=1.0 μ₀::Float64=0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing - fit_prior_nsteps::Int=100 + fit_prior_nsteps::Int=100::(_ > 0) la::Union{Nothing,AbstractLaplace}= nothing end const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression} @@ -90,11 +101,7 @@ Fit the LaplaceRegression model using Flux.jl. # Arguments - `model::LaplaceRegression`: The LaplaceRegression object. -- `penalty`: The penalty term for regularization. - `chain`: The chain of layers for the model. -- `epochs`: The number of training epochs. -- `batch_size`: The size of each training batch. -- `optimiser`: The optimization algorithm to use. - `verbosity`: The level of verbosity during training. - `X`: The input data. - `y`: The target data. @@ -102,7 +109,7 @@ Fit the LaplaceRegression model using Flux.jl. # Returns - `model::LaplaceRegression`: The fitted LaplaceRegression model. """ -function MLJFlux.fit!(model::LaplaceRegression, penalty,chain, epochs,batch_size, optimiser, verbosity, X, y) +function MLJFlux.fit!(model::LaplaceRegression, chain, verbosity, X, y) X = MLJBase.matrix(X, transpose=true) model.la = Laplace( @@ -119,15 +126,9 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain, epochs,batch_size # Initialize history: history = [] verbose_laplace=false - # Define the loss function for Laplace Regression with a custom penalty - function custom_loss( y_pred, y_batch) - data_loss = model.loss( y_pred,y_batch) - penalty_term = penalty(Flux.params(chain)) - return data_loss + penalty_term - end # intitialize and start progress meter: meter = Progress( - epochs + 1; + model.epochs + 1; dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), @@ -135,9 +136,9 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain, epochs,batch_size color=:yellow, ) # Create a data loader - loader = Flux.Data.DataLoader((data=X, label=y), batchsize=batch_size, shuffle=true) + loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - for i in 1:epochs + for i in 1:model.epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader @@ -145,11 +146,11 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain, epochs,batch_size # Backward pass gs = Flux.gradient(parameters) do - batch_loss = custom_loss(chain(X_batch), y_batch) + batch_loss = model.loss(chain(X_batch), y_batch) epoch_loss += batch_loss end # Update parameters - Flux.update!(optimiser, parameters,gs) + Flux.update!(model.optimiser, parameters,gs) end epoch_loss /= n_samples push!(history, epoch_loss) @@ -218,17 +219,13 @@ end """ - MLJFlux.fit!(model::LaplaceClassification, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) + MLJFlux.fit!(model::LaplaceClassification, chain, epochs, batch_size, optimiser, verbosity, X, y) Fit the LaplaceClassification model using MLJFlux. # Arguments - `model::LaplaceClassification`: The LaplaceClassification object to fit. -- `penalty`: The penalty to apply during training. - `chain`: The chain to use during training. -- `epochs`: The number of training epochs. -- `batch_size`: The batch size for training. -- `optimiser`: The optimiser to use during training. - `verbosity`: The verbosity level for training. - `X`: The input data for training. - `y`: The target labels for training. @@ -236,7 +233,7 @@ Fit the LaplaceClassification model using MLJFlux. # Returns - `model::LaplaceClassification`: The fitted LaplaceClassification model. """ -function MLJFlux.fit!(model::LaplaceClassification, penalty,chain, epochs,batch_size, optimiser, verbosity, X, y) +function MLJFlux.fit!(model::LaplaceClassification, chain, verbosity, X, y) X = MLJBase.matrix(X, transpose=true) # Integer encode the target variable y @@ -257,15 +254,9 @@ function MLJFlux.fit!(model::LaplaceClassification, penalty,chain, epochs,batch_ # Initialize history: history = [] - # Define the loss function for Laplace Regression with a custom penalty - function custom_loss( y_pred, y_batch) - data_loss = model.loss( y_pred,y_batch) - penalty_term = penalty(Flux.params(chain)) - return data_loss + penalty_term - end # intitialize and start progress meter: meter = Progress( - epochs + 1; + model.epochs + 1; dt=0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), @@ -273,20 +264,20 @@ function MLJFlux.fit!(model::LaplaceClassification, penalty,chain, epochs,batch_ color=:yellow, ) # Create a data loader - loader = Flux.Data.DataLoader((data=X, label=y), batchsize=batch_size, shuffle=true) + loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - for i in 1:epochs + for i in 1:model.epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader # Backward pass gs = Flux.gradient(parameters) do - batch_loss = custom_loss(chain(X_batch), y_batch) + batch_loss = model.loss(chain(X_batch), y_batch) epoch_loss += batch_loss end # Update parameters - Flux.update!(optimiser, parameters,gs) + Flux.update!(model.optimiser, parameters,gs) end epoch_loss /= n_samples push!(history, epoch_loss) From 352c96111adb0a1a9b7f16a94fe8c8326450391c Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 9 Jun 2024 00:19:19 +0200 Subject: [PATCH 06/32] added the project.toml of laplaceredux --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 8353012..64f52f0 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" From 45f9fb727a6298bd67a3853c27e7d279ac6c2bcf Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 12 Jun 2024 03:31:12 +0200 Subject: [PATCH 07/32] updated struct and test. --- src/mlj_flux.jl | 392 ++++++++++++++++++++--------------- test/mlj_flux_interfacing.jl | 100 ++++++++- 2 files changed, 324 insertions(+), 168 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 01dca1e..9555782 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -6,20 +6,26 @@ using Tables using Distributions using LinearAlgebra using LaplaceRedux -import MLJBase +using MLJBase import MLJBase: @mlj_model, metadata_model, metadata_pkg + """ - @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic + @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic -A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. - It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - The model is trained using the `fit!` method. The model is defined by the following default parameters: +A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. +It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. +The model is trained using the `fit!` method. The model is defined by the following default parameters: + +- `builder`: a Flux model that constructs the neural network. +- `optimiser`: a Flux optimiser. - `loss`: a loss function that takes the predicted output and the true output as arguments. -- `optimiser`: a Flux optimiser,default is Flux.Optimise.Adam(). -- `epochs`: the number of epochs to train the model,default is 10. -- `batch_size`: the size of a batch,default is 1. +- `batch_size`: the size of a batch. +- `lambda`: the regularization strength. +- `alpha`: the regularization mix (0 for all l2, 1 for all l1). - `rng`: a random number generator. +- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. +- `acceleration`: the computational resource to use. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -27,43 +33,47 @@ A mutable struct representing a Laplace Classification model that extends the ML - `σ`: the standard deviation of the prior distribution. - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. -- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `predict_proba`:whether to compute the probabilities or not, either true or false. - `fit_prior_nsteps`: the number of steps used to fit the priors. -- `la`: The fitted Laplace object will be saved here once the model is fitted. It has to be left to nothing, the fit! function will automatically save the Laplace model here. """ -@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - loss=Flux.crossentropy - optimiser=Flux.Optimise.Adam() - epochs::Int=10::(_ > 0) +@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic + builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) + optimiser= Flux.Optimise.Adam() + loss=Flux.Losses.mse batch_size::Int=1::(_ > 0) - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + lambda::Float64=1.0 + alpha::Float64=0.0 + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + optimiser_changes_trigger_retraining::Bool=false::(_ in (true, false)) + acceleration::AbstractResource=CPU1() subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in ("full", "diagonal")) + subnetwork_indices=nothing + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64=1.0 + σ::Float64=1.0 μ₀::Float64=0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing - link_approx::Symbol=:probit::(_ in (:probit,:plugin)) - predict_proba::Bool= true::(_ in (true,false)) fit_prior_nsteps::Int=100::(_ > 0) - la::Union{Nothing,AbstractLaplace}= nothing - end + + """ - @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic + @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic -A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. -The model is trained using the `fit!` method. The model is defined by the following default parameters: +A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. + It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. + The model is trained using the `fit!` method. The model is defined by the following default parameters: -- `loss`: a loss function that takes the predicted output and the true output as arguments. +- `builder`: a Flux model that constructs the neural network. +- `finaliser`: a Flux model that processes the output of the neural network. - `optimiser`: a Flux optimiser. -- `epochs`: the number of epochs to train the model. +- `loss`: a loss function that takes the predicted output and the true output as arguments. - `batch_size`: the size of a batch. +- `lambda`: the regularization strength. +- `alpha`: the regularization mix (0 for all l2, 1 for all l1). - `rng`: a random number generator. +- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. +- `acceleration`: the computational resource to use. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -72,145 +82,72 @@ The model is trained using the `fit!` method. The model is defined by the follow - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. -- `la`: The fitted Laplace object will be saved here once the model is fitted. It has to be left to nothing, the fit! function will automatically save the Laplace model here. """ -@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - loss=Flux.Losses.mse - optimiser=Flux.Optimise.Adam() - epochs::Int=10::(_ > 0) +@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic + builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) + finaliser=Flux.softmax + optimiser= Flux.Optimise.Adam() + loss=Flux.crossentropy + epochs::Int=10 batch_size::Int=1::(_ > 0) - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + lambda::Float64=1.0 + alpha::Float64=0.0 + rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG + optimiser_changes_trigger_retraining::Bool=false + acceleration::AbstractResource=CPU1() subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices=nothing - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in ("full", "diagonal")) + subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64=1.0 + σ::Float64=1.0 μ₀::Float64=0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing + link_approx::Symbol=:probit::(_ in (:probit,:plugin)) + predict_proba::Bool= true::(_ in (true,false)) fit_prior_nsteps::Int=100::(_ > 0) - la::Union{Nothing,AbstractLaplace}= nothing + end +############### const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression} - - -""" - MLJFlux.fit!(model::LaplaceRegression, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) - -Fit the LaplaceRegression model using Flux.jl. - -# Arguments -- `model::LaplaceRegression`: The LaplaceRegression object. -- `chain`: The chain of layers for the model. -- `verbosity`: The level of verbosity during training. -- `X`: The input data. -- `y`: The target data. +################################ functions shape and build -# Returns -- `model::LaplaceRegression`: The fitted LaplaceRegression model. -""" -function MLJFlux.fit!(model::LaplaceRegression, chain, verbosity, X, y) +function MLJFlux.shape(model::LaplaceClassification, X, y) + n_input = size(X,2) + levels = unique(y) + n_output = length(levels) + return (n_input, n_output) + +end - X = MLJBase.matrix(X, transpose=true) - model.la = Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀) - n_samples= size(X,1) - # Initialize history: - history = [] - verbose_laplace=false - # intitialize and start progress meter: - meter = Progress( - model.epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - # Create a data loader - loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) - parameters = Flux.params(chain) - for i in 1:model.epochs - epoch_loss = 0.0 - # train the model - for (X_batch, y_batch) in loader - y_batch = reshape(y_batch,1,:) - - # Backward pass - gs = Flux.gradient(parameters) do - batch_loss = model.loss(chain(X_batch), y_batch) - epoch_loss += batch_loss - end - # Update parameters - Flux.update!(model.optimiser, parameters,gs) +function MLJFlux.shape(model::LaplaceRegression, X, y) + n_input = size(X,2) + dims = size(y) + if length(dims) == 1 + n_output= 1 + else + n_output= dims[1] end - epoch_loss /= n_samples - push!(history, epoch_loss) - #verbosity - if verbosity == 1 - next!(meter) - elseif verbosity ==2 - next!(meter) - verbose_laplace=true - println( "Loss is $(round(epoch_loss; sigdigits=4))") - end - end - - - - # fit the Laplace model: - LaplaceRedux.fit!(model.la,zip(eachcol(X),y)) - optimize_prior!(model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - cache= nothing - report=[] - - return (model, cache, report) + return (n_input, n_output) end -""" - predict(model::LaplaceRegression, Xnew) - -Predict the output for new input data using a Laplace regression model. - -# Arguments -- `model::LaplaceRegression`: The trained Laplace regression model. -- `Xnew`: The new input data. - -# Returns -- The predicted output for the new input data. - -""" -function MLJFlux.predict(model::LaplaceRegression, Xnew) - Xnew = MLJBase.matrix(Xnew) - #convert in a vector of vectors because MLJ ask to do so - X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] - #inizialize output vector yhat - yhat=[] - # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] - - predictions = [] - for row in eachrow(yhat) - - mean_val = Float64(row[1][1][1]) - std_val = sqrt(Float64(row[1][2][1])) - # Append a Normal distribution: - push!(predictions, Normal(mean_val, std_val)) - end +function MLJFlux.build(model::MLJ_Laplace, shape) + #chain + chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - return predictions + return chain +end +function MLJFlux.fitresult(model::LaplaceClassification, chain, y) + return (chain, length(unique(y_cl))) end +function MLJFlux.fitresult(model::LaplaceRegression, chain, y) + return (chain, size(y) ) +end + + @@ -219,13 +156,16 @@ end """ - MLJFlux.fit!(model::LaplaceClassification, chain, epochs, batch_size, optimiser, verbosity, X, y) + MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) Fit the LaplaceClassification model using MLJFlux. # Arguments - `model::LaplaceClassification`: The LaplaceClassification object to fit. - `chain`: The chain to use during training. +- `penalty`: a penalty function to add to the loss function during training. +- `optimiser`: the optimiser to use during training. +- `epochs`: the number of epochs use for training. - `verbosity`: The verbosity level for training. - `X`: The input data for training. - `y`: The target labels for training. @@ -233,13 +173,13 @@ Fit the LaplaceClassification model using MLJFlux. # Returns - `model::LaplaceClassification`: The fitted LaplaceClassification model. """ -function MLJFlux.fit!(model::LaplaceClassification, chain, verbosity, X, y) +function MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) X = MLJBase.matrix(X, transpose=true) # Integer encode the target variable y #y_onehot = unique(y) .== permutedims(y) - model.la = Laplace( + la = LaplaceRedux.Laplace( chain; likelihood=:classification, subset_of_weights=model.subset_of_weights, @@ -256,7 +196,7 @@ function MLJFlux.fit!(model::LaplaceClassification, chain, verbosity, X, y) history = [] # intitialize and start progress meter: meter = Progress( - model.epochs + 1; + epochs + 1; dt=0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), @@ -266,18 +206,18 @@ function MLJFlux.fit!(model::LaplaceClassification, chain, verbosity, X, y) # Create a data loader loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - for i in 1:model.epochs + for i in 1:epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader # Backward pass gs = Flux.gradient(parameters) do - batch_loss = model.loss(chain(X_batch), y_batch) + batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(Flux.params(chain)) ) epoch_loss += batch_loss end # Update parameters - Flux.update!(model.optimiser, parameters,gs) + Flux.update!(optimiser, parameters,gs) end epoch_loss /= n_samples push!(history, epoch_loss) @@ -292,12 +232,11 @@ function MLJFlux.fit!(model::LaplaceClassification, chain, verbosity, X, y) end # fit the Laplace model: - LaplaceRedux.fit!(model.la,zip(eachcol(X),y)) - optimize_prior!(model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - cache= nothing + LaplaceRedux.fit!(la,zip(eachcol(X),y)) + optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) report=[] - return (model, cache, report) + return (la, history, report) end @@ -326,19 +265,143 @@ function MLJFlux.predict(model::LaplaceClassification, Xnew) end + + +""" + MLJFlux.fit!(model::LaplaceRegression, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) + +Fit the LaplaceRegression model using Flux.jl. + +# Arguments +- `model::LaplaceRegression`: The LaplaceRegression model. +- `chain`: The chain to use during training. +- `penalty`: a penalty function to add to the loss function during training. +- `optimiser`: the optimiser to use during training. +- `epochs`: the number of epochs use for training. +- `verbosity`: The verbosity level for training. +- `X`: The input data for training. +- `y`: The target labels for training. + +# Returns +- `model::LaplaceRegression`: The fitted LaplaceRegression model. +""" +function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, verbosity, X, y) + + X = MLJBase.matrix(X, transpose=true) + la = LaplaceRedux.Laplace( + chain; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀) + n_samples= size(X,1) + # Initialize history: + history = [] + verbose_laplace=false + # intitialize and start progress meter: + meter = Progress( + epochs + 1; + dt=1.0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + # Create a data loader + loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) + parameters = Flux.params(chain) + for i in 1:epochs + epoch_loss = 0.0 + # train the model + for (X_batch, y_batch) in loader + y_batch = reshape(y_batch,1,:) + + # Backward pass + gs = Flux.gradient(parameters) do + batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(Flux.params(chain)) ) + epoch_loss += batch_loss + end + # Update parameters + Flux.update!(optimiser, parameters,gs) + end + epoch_loss /= n_samples + push!(history, epoch_loss) + #verbosity + if verbosity == 1 + next!(meter) + elseif verbosity ==2 + next!(meter) + verbose_laplace=true + println( "Loss is $(round(epoch_loss; sigdigits=4))") + end + end + + + + # fit the Laplace model: + LaplaceRedux.fit!(la,zip(eachcol(X),y)) + optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + report=[] + + return (la, history, report) +end + + +""" + predict(model::LaplaceRegression, Xnew) + +Predict the output for new input data using a Laplace regression model. + +# Arguments +- `model::LaplaceRegression`: The trained Laplace regression model. +- `Xnew`: The new input data. + +# Returns +- The predicted output for the new input data. + +""" +function MLJFlux.predict(model::LaplaceRegression, Xnew) + Xnew = MLJBase.matrix(Xnew) + #convert in a vector of vectors because MLJ ask to do so + X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] + #inizialize output vector yhat + yhat=[] + # Predict using Laplace and collect the predictions + yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + + predictions = [] + for row in eachrow(yhat) + + mean_val = Float64(row[1][1][1]) + std_val = sqrt(Float64(row[1][2][1])) + # Append a Normal distribution: + push!(predictions, Normal(mean_val, std_val)) + end + + return predictions + +end + + + # Then for each model, MLJBase.metadata_model( LaplaceClassification; - input=Union{ + input=Union{ + AbstractMatrix{MLJBase.Finite}, + MLJBase.Table(MLJBase.Finite), AbstractMatrix{MLJBase.Continuous}, MLJBase.Table(MLJBase.Continuous), MLJBase.Table{AbstractVector{MLJBase.Continuous}}, + MLJBase.Table{AbstractVector{MLJBase.Finite}} }, target=Union{ AbstractArray{MLJBase.Finite}, AbstractArray{MLJBase.Continuous}, - AbstractVector{MLJBase.Finite}, - AbstractVector{MLJBase.Continuous}, }, path="MLJFlux.LaplaceClassification", ) @@ -348,13 +411,14 @@ MLJBase.metadata_model( input=Union{ AbstractMatrix{MLJBase.Continuous}, MLJBase.Table(MLJBase.Continuous), + AbstractMatrix{MLJBase.Finite}, + MLJBase.Table(MLJBase.Finite), MLJBase.Table{AbstractVector{MLJBase.Continuous}}, + MLJBase.Table{AbstractVector{MLJBase.Finite}} }, target=Union{ AbstractArray{MLJBase.Finite}, AbstractArray{MLJBase.Continuous}, - AbstractVector{MLJBase.Finite}, - AbstractVector{MLJBase.Continuous}, }, path="MLJFlux.LaplaceRegression", ) diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index b30d903..078bca0 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -6,12 +6,12 @@ using MLJFlux using Flux using StableRNGs -function basictest(X, y, builder, optimiser, threshold) +function basictest_regression(X, y, builder, optimiser, threshold) optimiser = deepcopy(optimiser) stable_rng = StableRNGs.StableRNG(123) - model = LaplaceApproximation(; + model = LaplaceRegression(; builder=builder, optimiser=optimiser, acceleration=CPUThreads(), @@ -53,7 +53,99 @@ function basictest(X, y, builder, optimiser, threshold) @test length(history) == model.epochs + 1 # start fresh with small epochs: - model = LaplaceApproximation(; + model = LaplaceRegression(; + builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + return true +end + +seed!(1234) +N = 300 +X = MLJBase.table(rand(Float32, N, 4)); +ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) + +builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) +optimizer = Flux.Optimise.Adam(0.03) + +@test basictest_regression(X, y, builder, optimizer, 0.9) + + + + + +function basictest_classification(X, y, builder, optimiser, threshold) + optimiser = deepcopy(optimiser) + + stable_rng = StableRNGs.StableRNG(123) + + model = LaplaceRegression(; + builder=builder, + optimiser=optimiser, + acceleration=CPUThreads(), + rng=stable_rng, + lambda=-1.0, + alpha=-1.0, + epochs=-1, + batch_size=-1, + likelihood=:incorrect, + subset_of_weights=:incorrect, + hessian_structure=:incorrect, + backend=:incorrect, + link_approx=:incorrect, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # test improvement in training loss: + @test history[end] < threshold * history[1] + + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) + + yhat = MLJBase.predict(model, fitresult, X) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # start fresh with small epochs: + model = LaplaceRegression(; builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng ) @@ -105,4 +197,4 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) -@test basictest(X, y, builder, optimizer, 0.9) +@test basictest_classification(X, y, builder, optimizer, 0.9) From 19e9b7fbb74f08d7a7d02f28c1a67870417b23cc Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 12 Jun 2024 03:48:25 +0200 Subject: [PATCH 08/32] ops forgot changelog --- CHANGELOG.md | 4 ++-- docs/Project.toml | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8fda09..c14fce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ## Version [0.3.0] - 2024-06-8 ### Changed - +- fixed test functions - adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase. -- Changed the fit! method arguments. Now it also accept a Flux chain model instead of retrieving it from the structs. this is due to the fact that MLJ wants only hyperparameters in the struct https://juliaai.github.io/MLJModelInterface.jl/dev/quick_start_guide/ +- Changed the fit! method arguments. - Changed the predict functions for both LaplaceClassification and LaplaceRegression. ### Removed diff --git a/docs/Project.toml b/docs/Project.toml index 49942c8..9459681 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,7 @@ [deps] +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" @@ -8,6 +11,9 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240" +Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" From 927527714df228bda74bc4388eaa867e4c7d2a8e Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 13 Jun 2024 13:21:21 +0200 Subject: [PATCH 09/32] partial fixes. fit! function gives trouble --- src/baselaplace/predicting.jl | 10 +++++--- src/mlj_flux.jl | 47 ++++++++++++++++++++++------------- test/mlj_flux_interfacing.jl | 19 +++++++------- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 49773a2..fcde981 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -1,3 +1,4 @@ +using Distributions """ functional_variance(la::AbstractLaplace, 𝐉::AbstractArray) @@ -39,7 +40,9 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) fμ = reshape(fμ, Flux.outputsize(la.model, size(X))) fvar = functional_variance(la, 𝐉) fvar = reshape(fvar, size(fμ)...) - return fμ, fvar + fstd = sqrt.(fvar) + normal_distr= [Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1), j in 1:size(fμ, 2)] + return normal_distr end """ @@ -75,11 +78,12 @@ predict(la, hcat(x...)) function predict( la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true ) - fμ, fvar = glm_predictive_distribution(la, X) + normal_distr = glm_predictive_distribution(la, X) + fμ, fvar = mean.(normal_distr), var.(normal_distr) # Regression: if la.likelihood == :regression - return fμ, fvar + return normal_distr end # Classification: diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 9555782..eba6c1c 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -6,6 +6,7 @@ using Tables using Distributions using LinearAlgebra using LaplaceRedux +using ComputationalResources using MLJBase import MLJBase: @mlj_model, metadata_model, metadata_pkg @@ -38,13 +39,14 @@ The model is trained using the `fit!` method. The model is defined by the follow @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) optimiser= Flux.Optimise.Adam() - loss=Flux.Losses.mse + loss=Flux.Losses.mse + epochs::Int=10::(_ > 0) batch_size::Int=1::(_ > 0) lambda::Float64=1.0 alpha::Float64=0.0 rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG optimiser_changes_trigger_retraining::Bool=false::(_ in (true, false)) - acceleration::AbstractResource=CPU1() + acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices=nothing hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) @@ -88,13 +90,13 @@ A mutable struct representing a Laplace Classification model that extends the ML finaliser=Flux.softmax optimiser= Flux.Optimise.Adam() loss=Flux.crossentropy - epochs::Int=10 + epochs::Int=10::(_ > 0) batch_size::Int=1::(_ > 0) lambda::Float64=1.0 alpha::Float64=0.0 rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG optimiser_changes_trigger_retraining::Bool=false - acceleration::AbstractResource=CPU1() + acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) @@ -113,6 +115,7 @@ const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression} ################################ functions shape and build function MLJFlux.shape(model::LaplaceClassification, X, y) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X,2) levels = unique(y) n_output = length(levels) @@ -121,6 +124,7 @@ function MLJFlux.shape(model::LaplaceClassification, X, y) end function MLJFlux.shape(model::LaplaceRegression, X, y) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X,2) dims = size(y) if length(dims) == 1 @@ -132,15 +136,23 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) end -function MLJFlux.build(model::MLJ_Laplace, shape) +function MLJFlux.build(model::LaplaceClassification,rng, shape) #chain - chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + chain = Flux.Chain(MLJFlux.build(model.builder, rng,shape...), model.finaliser) + + return chain +end + + +function MLJFlux.build(model::LaplaceRegression,rng,shape) + #chain + chain = MLJFlux.build(model.builder,rng , shape...) return chain end function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (chain, length(unique(y_cl))) + return (chain, length(unique(y))) end function MLJFlux.fitresult(model::LaplaceRegression, chain, y) @@ -174,7 +186,8 @@ Fit the LaplaceClassification model using MLJFlux. - `model::LaplaceClassification`: The fitted LaplaceClassification model. """ function MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) - X = MLJBase.matrix(X, transpose=true) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + X = X' # Integer encode the target variable y #y_onehot = unique(y) .== permutedims(y) @@ -286,8 +299,8 @@ Fit the LaplaceRegression model using Flux.jl. - `model::LaplaceRegression`: The fitted LaplaceRegression model. """ function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, verbosity, X, y) - - X = MLJBase.matrix(X, transpose=true) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + X = X' la = LaplaceRedux.Laplace( chain; likelihood=:regression, @@ -373,16 +386,16 @@ function MLJFlux.predict(model::LaplaceRegression, Xnew) # Predict using Laplace and collect the predictions yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] - predictions = [] - for row in eachrow(yhat) + #predictions = [] + #for row in eachrow(yhat) - mean_val = Float64(row[1][1][1]) - std_val = sqrt(Float64(row[1][2][1])) + #mean_val = Float64(row[1][1][1]) + #std_val = sqrt(Float64(row[1][2][1])) # Append a Normal distribution: - push!(predictions, Normal(mean_val, std_val)) - end + #push!(predictions, Normal(mean_val, std_val)) + #end - return predictions + return yhat end diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 078bca0..7d0122e 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -15,16 +15,15 @@ function basictest_regression(X, y, builder, optimiser, threshold) builder=builder, optimiser=optimiser, acceleration=CPUThreads(), + loss= Flux.Losses.mse, rng=stable_rng, lambda=-1.0, alpha=-1.0, epochs=-1, batch_size=-1, - likelihood=:incorrect, subset_of_weights=:incorrect, hessian_structure=:incorrect, - backend=:incorrect, - link_approx=:incorrect, + backend=:incorrect ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) @@ -92,7 +91,7 @@ ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) -@test basictest_regression(X, y, builder, optimizer, 0.9) +@test basictest_regression(X, ycont, builder, optimizer, 0.9) @@ -103,16 +102,16 @@ function basictest_classification(X, y, builder, optimiser, threshold) stable_rng = StableRNGs.StableRNG(123) - model = LaplaceRegression(; + model = LaplaceClassification(; builder=builder, optimiser=optimiser, - acceleration=CPUThreads(), - rng=stable_rng, - lambda=-1.0, - alpha=-1.0, + loss= Flux.crossentropy, epochs=-1, batch_size=-1, - likelihood=:incorrect, + lambda=-1.0, + alpha=-1.0, + rng=stable_rng, + acceleration=CPUThreads(), subset_of_weights=:incorrect, hessian_structure=:incorrect, backend=:incorrect, From 651a9bb0591eba90c51218f57e419b97c2a73177 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 14 Jun 2024 05:01:17 +0200 Subject: [PATCH 10/32] fixed the version of Distribution used in project.toml (compat). this solved one of the issue. the remaining two are still there --- Project.toml | 2 ++ src/mlj_flux.jl | 27 +++++++++++++++++---------- test/mlj_flux_interfacing.jl | 6 ++---- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 64f52f0..729a294 100644 --- a/Project.toml +++ b/Project.toml @@ -26,8 +26,10 @@ Aqua = "0.8" ChainRulesCore = "1.23.0" Compat = "4.7.0" ComputationalResources = "0.3.2" +Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10" +MLJBase = "1.4.0" MLJFlux = "0.2.10, 0.3, 0.4" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index eba6c1c..517977c 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -136,17 +136,17 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) end -function MLJFlux.build(model::LaplaceClassification,rng, shape) +function MLJFlux.build(model::LaplaceClassification, shape) #chain - chain = Flux.Chain(MLJFlux.build(model.builder, rng,shape...), model.finaliser) + chain = Flux.Chain(MLJFlux.build(model.builder, model.rng,shape...), model.finaliser) return chain end -function MLJFlux.build(model::LaplaceRegression,rng,shape) +function MLJFlux.build(model::LaplaceRegression,shape) #chain - chain = MLJFlux.build(model.builder,rng , shape...) + chain = MLJFlux.build(model.builder,model.rng , shape...) return chain end @@ -185,7 +185,7 @@ Fit the LaplaceClassification model using MLJFlux. # Returns - `model::LaplaceClassification`: The fitted LaplaceClassification model. """ -function MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) +function MLJFlux.fit!(model::LaplaceClassification, penalty,chain,optimiser,epochs, verbosity, X, y) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X X = X' @@ -226,7 +226,7 @@ function MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epoc # Backward pass gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(Flux.params(chain)) ) + batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(parameters) ) epoch_loss += batch_loss end # Update parameters @@ -298,9 +298,9 @@ Fit the LaplaceRegression model using Flux.jl. # Returns - `model::LaplaceRegression`: The fitted LaplaceRegression model. """ -function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, verbosity, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - X = X' +function MLJFlux.fit!(model::LaplaceRegression, penalty,chain,optimiser,epochs, verbosity, X, y) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X,transpose=true) : X + #X = MLJ.matrix(X,transpose=true) la = LaplaceRedux.Laplace( chain; likelihood=:regression, @@ -324,6 +324,8 @@ function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, barlen=25, color=:yellow, ) + println("Shape of X prima di loader: ", size(X)) + println("Shape of y prima di loader : ", size(y)) # Create a data loader loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) @@ -331,11 +333,16 @@ function MLJFlux.fit!(model::LaplaceRegression, chain,penalty,optimiser,epochs, epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader + y_batch = reshape(y_batch,1,:) + println("Shape of X_batch: ", size(X_batch)) + println("Shape of y_batch: ", size(y_batch)) + X_batch = Flux.flatten(X_batch) + y_batch = Flux.flatten(y_batch) # Backward pass gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(Flux.params(chain)) ) + batch_loss = (model.loss(chain(X_batch), y_batch) +penalty(parameters) ) epoch_loss += batch_loss end # Update parameters diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 7d0122e..b8808ce 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -53,8 +53,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) # start fresh with small epochs: model = LaplaceRegression(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng - ) + builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) @@ -145,8 +144,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) # start fresh with small epochs: model = LaplaceRegression(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng - ) + builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) From 56bd1de8e97c4f0ed14ae9af7cd1267a4a076d84 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 15 Jun 2024 07:33:19 +0200 Subject: [PATCH 11/32] last attempt. --- src/mlj_flux.jl | 98 +++++++++++++++++++----------------- test/mlj_flux_interfacing.jl | 7 +-- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 517977c..3f45eb8 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -49,7 +49,7 @@ The model is trained using the `fit!` method. The model is defined by the follow acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices=nothing - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (:full, :diagonal)) backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) σ::Float64=1.0 μ₀::Float64=0.0 @@ -99,7 +99,7 @@ A mutable struct representing a Laplace Classification model that extends the ML acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (":full", ":diagonal")) + hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (:full, :diagonal)) backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) σ::Float64=1.0 μ₀::Float64=0.0 @@ -136,17 +136,17 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) end -function MLJFlux.build(model::LaplaceClassification, shape) +function MLJFlux.build(model::LaplaceClassification,rng, shape) #chain - chain = Flux.Chain(MLJFlux.build(model.builder, model.rng,shape...), model.finaliser) + chain = Flux.Chain(MLJFlux.build(model.builder, rng,shape...), model.finaliser) return chain end -function MLJFlux.build(model::LaplaceRegression,shape) +function MLJFlux.build(model::LaplaceRegression,rng,shape) #chain - chain = MLJFlux.build(model.builder,model.rng , shape...) + chain = MLJFlux.build(model.builder,rng , shape...) return chain end @@ -174,23 +174,22 @@ Fit the LaplaceClassification model using MLJFlux. # Arguments - `model::LaplaceClassification`: The LaplaceClassification object to fit. -- `chain`: The chain to use during training. -- `penalty`: a penalty function to add to the loss function during training. -- `optimiser`: the optimiser to use during training. -- `epochs`: the number of epochs use for training. - `verbosity`: The verbosity level for training. - `X`: The input data for training. -- `y`: The target labels for training. +- `y`: The target labels for training one-hot encoded. # Returns -- `model::LaplaceClassification`: The fitted LaplaceClassification model. +- (la,chain), history, report) +where la is the fitted Laplace model. """ -function MLJFlux.fit!(model::LaplaceClassification, penalty,chain,optimiser,epochs, verbosity, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - X = X' +function MLJFlux.fit!(model::LaplaceClassification,verbosity, X, y) + X = MLJBase.matrix(X) + + shape= MLJFlux.shape(model, X,y) + + chain= MLJFlux.build(model,model.rng, shape ) + - # Integer encode the target variable y - #y_onehot = unique(y) .== permutedims(y) la = LaplaceRedux.Laplace( chain; @@ -209,7 +208,7 @@ function MLJFlux.fit!(model::LaplaceClassification, penalty,chain,optimiser,epoc history = [] # intitialize and start progress meter: meter = Progress( - epochs + 1; + model.epochs + 1; dt=0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), @@ -217,20 +216,20 @@ function MLJFlux.fit!(model::LaplaceClassification, penalty,chain,optimiser,epoc color=:yellow, ) # Create a data loader - loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) + loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - for i in 1:epochs + for i in 1: model.epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader # Backward pass gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) + penalty(parameters) ) + batch_loss = (model.loss(chain(X_batch), y_batch) ) epoch_loss += batch_loss end # Update parameters - Flux.update!(optimiser, parameters,gs) + Flux.update!(model.optimiser, parameters,gs) end epoch_loss /= n_samples push!(history, epoch_loss) @@ -245,11 +244,11 @@ function MLJFlux.fit!(model::LaplaceClassification, penalty,chain,optimiser,epoc end # fit the Laplace model: - LaplaceRedux.fit!(la,zip(eachcol(X),y)) + LaplaceRedux.fit!(la,zip(eachrow(X),y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) report=[] - return (la, history, report) + return ((la,chain), history, report) end @@ -260,19 +259,21 @@ Predicts the class labels for new data using the LaplaceClassification model. # Arguments - `model::LaplaceClassification`: The trained LaplaceClassification model. +- fitresult: the fitresult output produced by MLJFlux.fit! - `Xnew`: The new data to make predictions on. # Returns An array of predicted class labels. """ -function MLJFlux.predict(model::LaplaceClassification, Xnew) +function MLJFlux.predict(model::LaplaceClassification,fitresult, Xnew) + la= fitresult[1] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec= X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] # Predict using Laplace and collect the predictions - predictions = [LaplaceRedux.predict(model.la, x;link_approx= model.link_approx,predict_proba=model.predict_proba) for x in X_vec] + predictions = [LaplaceRedux.predict(la, x;link_approx= model.link_approx,predict_proba=model.predict_proba) for x in X_vec] return predictions @@ -287,10 +288,6 @@ Fit the LaplaceRegression model using Flux.jl. # Arguments - `model::LaplaceRegression`: The LaplaceRegression model. -- `chain`: The chain to use during training. -- `penalty`: a penalty function to add to the loss function during training. -- `optimiser`: the optimiser to use during training. -- `epochs`: the number of epochs use for training. - `verbosity`: The verbosity level for training. - `X`: The input data for training. - `y`: The target labels for training. @@ -298,9 +295,16 @@ Fit the LaplaceRegression model using Flux.jl. # Returns - `model::LaplaceRegression`: The fitted LaplaceRegression model. """ -function MLJFlux.fit!(model::LaplaceRegression, penalty,chain,optimiser,epochs, verbosity, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X,transpose=true) : X - #X = MLJ.matrix(X,transpose=true) +function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) + #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + + X = MLJBase.matrix(X) + + shape= MLJFlux.shape(model, X,y) + + chain= MLJFlux.build(model, model.rng,shape ) + + la = LaplaceRedux.Laplace( chain; likelihood=:regression, @@ -317,7 +321,7 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain,optimiser,epochs, verbose_laplace=false # intitialize and start progress meter: meter = Progress( - epochs + 1; + model.epochs + 1; dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), @@ -327,26 +331,24 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain,optimiser,epochs, println("Shape of X prima di loader: ", size(X)) println("Shape of y prima di loader : ", size(y)) # Create a data loader - loader = Flux.Data.DataLoader((data=X, label=y), batchsize=model.batch_size, shuffle=true) + loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) parameters = Flux.params(chain) - for i in 1:epochs + for i in 1:model.epochs epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader y_batch = reshape(y_batch,1,:) - println("Shape of X_batch: ", size(X_batch)) - println("Shape of y_batch: ", size(y_batch)) - X_batch = Flux.flatten(X_batch) - y_batch = Flux.flatten(y_batch) + println("Shape of X_batch dopo di loader: ", size(X_batch)) + println("Shape of y_batch dopo di loader : ", size(y_batch)) # Backward pass gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) +penalty(parameters) ) + batch_loss = (model.loss(chain(X_batch), y_batch) ) epoch_loss += batch_loss end # Update parameters - Flux.update!(optimiser, parameters,gs) + Flux.update!(model.optimiser, parameters,gs) end epoch_loss /= n_samples push!(history, epoch_loss) @@ -363,11 +365,11 @@ function MLJFlux.fit!(model::LaplaceRegression, penalty,chain,optimiser,epochs, # fit the Laplace model: - LaplaceRedux.fit!(la,zip(eachcol(X),y)) + LaplaceRedux.fit!(la,zip(eachrow(X),y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) report=[] - return (la, history, report) + return ((la,chain), history, report) end @@ -378,20 +380,24 @@ Predict the output for new input data using a Laplace regression model. # Arguments - `model::LaplaceRegression`: The trained Laplace regression model. +- the fitresult output produced by MLJFlux.fit! - `Xnew`: The new input data. # Returns - The predicted output for the new input data. """ -function MLJFlux.predict(model::LaplaceRegression, Xnew) +function MLJFlux.predict(model::LaplaceRegression,fitresult, Xnew) + Xnew = MLJBase.matrix(Xnew) + + la= fitresult[1] #convert in a vector of vectors because MLJ ask to do so X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] #inizialize output vector yhat yhat=[] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] #predictions = [] #for row in eachrow(yhat) diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index b8808ce..7dbf0df 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -46,7 +46,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - yhat = MLJBase.predict(model, fitresult, X) + #yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses @test length(history) == model.epochs + 1 @@ -137,7 +137,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - yhat = MLJBase.predict(model, fitresult, X) + #yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses @test length(history) == model.epochs + 1 @@ -193,5 +193,6 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) +y_onehot= transpose(unique(y) .== permutedims(y)) -@test basictest_classification(X, y, builder, optimizer, 0.9) +@test basictest_classification(X, y_onehot, builder, optimizer, 0.9) From a5a327dd331c56f7741cd3dcba5732e468d6f838 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Sat, 15 Jun 2024 09:03:27 +0200 Subject: [PATCH 12/32] formatting --- .github/workflows/FormatCheck.yml | 8 +- CHANGELOG.md | 5 +- src/baselaplace/predicting.jl | 7 +- src/mlj_flux.jl | 325 ++++++++++++++---------------- test/mlj_flux_interfacing.jl | 18 +- 5 files changed, 169 insertions(+), 194 deletions(-) diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index e576a1b..e924c1c 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -25,10 +25,4 @@ jobs: run: | using JuliaFormatter format("."; verbose=true) - shell: julia --color=yes {0} - - name: Suggest formatting changes - uses: reviewdog/action-suggester@v1 - if: github.event_name == 'pull_request' - with: - tool_name: JuliaFormatter - fail_on_error: true \ No newline at end of file + shell: julia --color=yes {0} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index c14fce1..1473dd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,13 +9,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ## Version [0.3.0] - 2024-06-8 ### Changed + - fixed test functions - adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase. - Changed the fit! method arguments. - Changed the predict functions for both LaplaceClassification and LaplaceRegression. ### Removed -- Removed the shape , build and clean! functions. + +- Removed the shape, build and clean! functions. +- Removed Review dog for code format suggestions. [#39] ## Version [0.2.3] - 2024-05-31 diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index fcde981..674c60f 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -41,7 +41,10 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) fvar = functional_variance(la, 𝐉) fvar = reshape(fvar, size(fμ)...) fstd = sqrt.(fvar) - normal_distr= [Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1), j in 1:size(fμ, 2)] + normal_distr = [ + Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1), + j in 1:size(fμ, 2) + ] return normal_distr end @@ -79,7 +82,7 @@ function predict( la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true ) normal_distr = glm_predictive_distribution(la, X) - fμ, fvar = mean.(normal_distr), var.(normal_distr) + fμ, fvar = mean.(normal_distr), var.(normal_distr) # Regression: if la.likelihood == :regression diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 3f45eb8..f502e49 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -10,7 +10,6 @@ using ComputationalResources using MLJBase import MLJBase: @mlj_model, metadata_model, metadata_pkg - """ @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic @@ -37,28 +36,27 @@ The model is trained using the `fit!` method. The model is defined by the follow - `fit_prior_nsteps`: the number of steps used to fit the priors. """ @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - optimiser= Flux.Optimise.Adam() - loss=Flux.Losses.mse - epochs::Int=10::(_ > 0) - batch_size::Int=1::(_ > 0) - lambda::Float64=1.0 - alpha::Float64=0.0 - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool=false::(_ in (true, false)) - acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices=nothing - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (:full, :diagonal)) - backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64=1.0 - μ₀::Float64=0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing - fit_prior_nsteps::Int=100::(_ > 0) + builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) + optimiser = Flux.Optimise.Adam() + loss = Flux.Losses.mse + epochs::Int = 10::(_ > 0) + batch_size::Int = 1::(_ > 0) + lambda::Float64 = 1.0 + alpha::Float64 = 0.0 + rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG + optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) + acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices = nothing + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + fit_prior_nsteps::Int = 100::(_ > 0) end - - """ @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic @@ -86,87 +84,78 @@ A mutable struct representing a Laplace Classification model that extends the ML - `fit_prior_nsteps`: the number of steps used to fit the priors. """ @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - builder=MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - finaliser=Flux.softmax - optimiser= Flux.Optimise.Adam() - loss=Flux.crossentropy - epochs::Int=10::(_ > 0) - batch_size::Int=1::(_ > 0) - lambda::Float64=1.0 - alpha::Float64=0.0 - rng::Union{AbstractRNG,Int64}=Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool=false - acceleration=CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol=:all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices::Vector{Vector{Int}}=Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String}=:full::(_ in (:full, :diagonal)) - backend::Symbol=:GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64=1.0 - μ₀::Float64=0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing}=nothing - link_approx::Symbol=:probit::(_ in (:probit,:plugin)) - predict_proba::Bool= true::(_ in (true,false)) - fit_prior_nsteps::Int=100::(_ > 0) - + builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) + finaliser = Flux.softmax + optimiser = Flux.Optimise.Adam() + loss = Flux.crossentropy + epochs::Int = 10::(_ > 0) + batch_size::Int = 1::(_ > 0) + lambda::Float64 = 1.0 + alpha::Float64 = 0.0 + rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG + optimiser_changes_trigger_retraining::Bool = false + acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + link_approx::Symbol = :probit::(_ in (:probit, :plugin)) + predict_proba::Bool = true::(_ in (true, false)) + fit_prior_nsteps::Int = 100::(_ > 0) end ############### -const MLJ_Laplace= Union{LaplaceClassification,LaplaceRegression} +const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} ################################ functions shape and build function MLJFlux.shape(model::LaplaceClassification, X, y) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X,2) + n_input = size(X, 2) levels = unique(y) n_output = length(levels) return (n_input, n_output) - end function MLJFlux.shape(model::LaplaceRegression, X, y) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X,2) + n_input = size(X, 2) dims = size(y) - if length(dims) == 1 - n_output= 1 - else - n_output= dims[1] - end - return (n_input, n_output) + if length(dims) == 1 + n_output = 1 + else + n_output = dims[1] + end + return (n_input, n_output) end - -function MLJFlux.build(model::LaplaceClassification,rng, shape) +function MLJFlux.build(model::LaplaceClassification, rng, shape) #chain - chain = Flux.Chain(MLJFlux.build(model.builder, rng,shape...), model.finaliser) - + chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + return chain end - -function MLJFlux.build(model::LaplaceRegression,rng,shape) +function MLJFlux.build(model::LaplaceRegression, rng, shape) #chain - chain = MLJFlux.build(model.builder,rng , shape...) - + chain = MLJFlux.build(model.builder, rng, shape...) + return chain end function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (chain, length(unique(y))) + return (chain, length(unique(y))) end function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - return (chain, size(y) ) + return (chain, size(y)) end - - - - - ######################################### fit and predict for classification - """ MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) @@ -182,14 +171,12 @@ Fit the LaplaceClassification model using MLJFlux. - (la,chain), history, report) where la is the fitted Laplace model. """ -function MLJFlux.fit!(model::LaplaceClassification,verbosity, X, y) +function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) X = MLJBase.matrix(X) - shape= MLJFlux.shape(model, X,y) - - chain= MLJFlux.build(model,model.rng, shape ) - + shape = MLJFlux.shape(model, X, y) + chain = MLJFlux.build(model, model.rng, shape) la = LaplaceRedux.Laplace( chain; @@ -200,58 +187,60 @@ function MLJFlux.fit!(model::LaplaceClassification,verbosity, X, y) backend=model.backend, σ=model.σ, μ₀=model.μ₀, - P₀=model.P₀) - n_samples= size(X,1) - verbose_laplace=false - - # Initialize history: - history = [] - # intitialize and start progress meter: - meter = Progress( - model.epochs + 1; - dt=0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - # Create a data loader - loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) - parameters = Flux.params(chain) - for i in 1: model.epochs - epoch_loss = 0.0 - # train the model - for (X_batch, y_batch) in loader - - # Backward pass - gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) ) - epoch_loss += batch_loss + P₀=model.P₀, + ) + n_samples = size(X, 1) + verbose_laplace = false + + # Initialize history: + history = [] + # intitialize and start progress meter: + meter = Progress( + model.epochs + 1; + dt=0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + # Create a data loader + loader = Flux.Data.DataLoader( + (data=X', label=y); batchsize=model.batch_size, shuffle=true + ) + parameters = Flux.params(chain) + for i in 1:(model.epochs) + epoch_loss = 0.0 + # train the model + for (X_batch, y_batch) in loader + + # Backward pass + gs = Flux.gradient(parameters) do + batch_loss = (model.loss(chain(X_batch), y_batch)) + epoch_loss += batch_loss end - # Update parameters - Flux.update!(model.optimiser, parameters,gs) + # Update parameters + Flux.update!(model.optimiser, parameters, gs) end - epoch_loss /= n_samples - push!(history, epoch_loss) - #verbosity - if verbosity == 1 + epoch_loss /= n_samples + push!(history, epoch_loss) + #verbosity + if verbosity == 1 next!(meter) - elseif verbosity ==2 + elseif verbosity == 2 next!(meter) - verbose_laplace=true - println( "Loss is $(round(epoch_loss; sigdigits=4))") + verbose_laplace = true + println("Loss is $(round(epoch_loss; sigdigits=4))") end - end + end - # fit the Laplace model: - LaplaceRedux.fit!(la,zip(eachrow(X),y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - report=[] + # fit the Laplace model: + LaplaceRedux.fit!(la, zip(eachrow(X), y)) + optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + report = [] - return ((la,chain), history, report) + return ((la, chain), history, report) end - """ predict(model::LaplaceClassification, Xnew) @@ -266,21 +255,22 @@ Predicts the class labels for new data using the LaplaceClassification model. An array of predicted class labels. """ -function MLJFlux.predict(model::LaplaceClassification,fitresult, Xnew) - la= fitresult[1] - Xnew = MLJBase.matrix(Xnew) +function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) + la = fitresult[1] + Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so - X_vec= X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] + X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] # Predict using Laplace and collect the predictions - predictions = [LaplaceRedux.predict(la, x;link_approx= model.link_approx,predict_proba=model.predict_proba) for x in X_vec] - - return predictions + predictions = [ + LaplaceRedux.predict( + la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + ) for x in X_vec + ] + return predictions end - - """ MLJFlux.fit!(model::LaplaceRegression, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) @@ -300,10 +290,9 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) X = MLJBase.matrix(X) - shape= MLJFlux.shape(model, X,y) - - chain= MLJFlux.build(model, model.rng,shape ) + shape = MLJFlux.shape(model, X, y) + chain = MLJFlux.build(model, model.rng, shape) la = LaplaceRedux.Laplace( chain; @@ -314,11 +303,12 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) backend=model.backend, σ=model.σ, μ₀=model.μ₀, - P₀=model.P₀) - n_samples= size(X,1) + P₀=model.P₀, + ) + n_samples = size(X, 1) # Initialize history: history = [] - verbose_laplace=false + verbose_laplace = false # intitialize and start progress meter: meter = Progress( model.epochs + 1; @@ -331,48 +321,46 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) println("Shape of X prima di loader: ", size(X)) println("Shape of y prima di loader : ", size(y)) # Create a data loader - loader = Flux.Data.DataLoader((data=X', label=y), batchsize=model.batch_size, shuffle=true) + loader = Flux.Data.DataLoader( + (data=X', label=y); batchsize=model.batch_size, shuffle=true + ) parameters = Flux.params(chain) - for i in 1:model.epochs + for i in 1:(model.epochs) epoch_loss = 0.0 # train the model for (X_batch, y_batch) in loader - - y_batch = reshape(y_batch,1,:) + y_batch = reshape(y_batch, 1, :) println("Shape of X_batch dopo di loader: ", size(X_batch)) println("Shape of y_batch dopo di loader : ", size(y_batch)) - + # Backward pass gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch) ) + batch_loss = (model.loss(chain(X_batch), y_batch)) epoch_loss += batch_loss end # Update parameters - Flux.update!(model.optimiser, parameters,gs) + Flux.update!(model.optimiser, parameters, gs) end epoch_loss /= n_samples push!(history, epoch_loss) #verbosity if verbosity == 1 next!(meter) - elseif verbosity ==2 + elseif verbosity == 2 next!(meter) - verbose_laplace=true - println( "Loss is $(round(epoch_loss; sigdigits=4))") + verbose_laplace = true + println("Loss is $(round(epoch_loss; sigdigits=4))") end end - - # fit the Laplace model: - LaplaceRedux.fit!(la,zip(eachrow(X),y)) + LaplaceRedux.fit!(la, zip(eachrow(X), y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - report=[] + report = [] - return ((la,chain), history, report) + return ((la, chain), history, report) end - """ predict(model::LaplaceRegression, Xnew) @@ -387,48 +375,41 @@ Predict the output for new input data using a Laplace regression model. - The predicted output for the new input data. """ -function MLJFlux.predict(model::LaplaceRegression,fitresult, Xnew) +function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) + Xnew = MLJBase.matrix(Xnew) - Xnew = MLJBase.matrix(Xnew) - - la= fitresult[1] + la = fitresult[1] #convert in a vector of vectors because MLJ ask to do so - X_vec= [Xnew[i,:] for i in 1:size(Xnew, 1)] + X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat - yhat=[] + yhat = [] # Predict using Laplace and collect the predictions yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] #predictions = [] #for row in eachrow(yhat) - #mean_val = Float64(row[1][1][1]) - #std_val = sqrt(Float64(row[1][2][1])) - # Append a Normal distribution: - #push!(predictions, Normal(mean_val, std_val)) + #mean_val = Float64(row[1][1][1]) + #std_val = sqrt(Float64(row[1][2][1])) + # Append a Normal distribution: + #push!(predictions, Normal(mean_val, std_val)) #end - - return yhat + return yhat end - - # Then for each model, MLJBase.metadata_model( LaplaceClassification; - input=Union{ + input=Union{ AbstractMatrix{MLJBase.Finite}, MLJBase.Table(MLJBase.Finite), AbstractMatrix{MLJBase.Continuous}, MLJBase.Table(MLJBase.Continuous), MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}} - }, - target=Union{ - AbstractArray{MLJBase.Finite}, - AbstractArray{MLJBase.Continuous}, + MLJBase.Table{AbstractVector{MLJBase.Finite}}, }, + target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, path="MLJFlux.LaplaceClassification", ) # Then for each model, @@ -440,12 +421,8 @@ MLJBase.metadata_model( AbstractMatrix{MLJBase.Finite}, MLJBase.Table(MLJBase.Finite), MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}} - }, - target=Union{ - AbstractArray{MLJBase.Finite}, - AbstractArray{MLJBase.Continuous}, + MLJBase.Table{AbstractVector{MLJBase.Finite}}, }, + target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, path="MLJFlux.LaplaceRegression", ) - \ No newline at end of file diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 7dbf0df..b39a9cf 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -15,7 +15,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) builder=builder, optimiser=optimiser, acceleration=CPUThreads(), - loss= Flux.Losses.mse, + loss=Flux.Losses.mse, rng=stable_rng, lambda=-1.0, alpha=-1.0, @@ -23,7 +23,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) batch_size=-1, subset_of_weights=:incorrect, hessian_structure=:incorrect, - backend=:incorrect + backend=:incorrect, ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) @@ -53,7 +53,8 @@ function basictest_regression(X, y, builder, optimiser, threshold) # start fresh with small epochs: model = LaplaceRegression(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng) + builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng + ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) @@ -92,10 +93,6 @@ optimizer = Flux.Optimise.Adam(0.03) @test basictest_regression(X, ycont, builder, optimizer, 0.9) - - - - function basictest_classification(X, y, builder, optimiser, threshold) optimiser = deepcopy(optimiser) @@ -104,7 +101,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) model = LaplaceClassification(; builder=builder, optimiser=optimiser, - loss= Flux.crossentropy, + loss=Flux.crossentropy, epochs=-1, batch_size=-1, lambda=-1.0, @@ -144,7 +141,8 @@ function basictest_classification(X, y, builder, optimiser, threshold) # start fresh with small epochs: model = LaplaceRegression(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng) + builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng + ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) @@ -193,6 +191,6 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) -y_onehot= transpose(unique(y) .== permutedims(y)) +y_onehot = transpose(unique(y) .== permutedims(y)) @test basictest_classification(X, y_onehot, builder, optimizer, 0.9) From 66b09a99984815aadb86c254a8487e15458fb2a1 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Sat, 15 Jun 2024 09:46:50 +0200 Subject: [PATCH 13/32] formatter; explicit imports; fixed error (was passing on one-hot encoded y instead of categorical) --- Project.toml | 2 +- dev/notebooks/mlj-interfacing/mlj.ipynb | 8 +- docs/Manifest.toml | 436 ++++++++---------- docs/Project.toml | 1 - src/baselaplace/predicting.jl | 2 +- src/mlj_flux.jl | 12 +- test/Manifest.toml | 560 +++++++++--------------- test/Project.toml | 1 - test/mlj_flux_interfacing.jl | 7 +- 9 files changed, 393 insertions(+), 636 deletions(-) diff --git a/Project.toml b/Project.toml index 729a294..068dbb4 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ ComputationalResources = "0.3.2" Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10" -MLJBase = "1.4.0" +MLJBase = "0, 1.4.0" MLJFlux = "0.2.10, 0.3, 0.4" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" diff --git a/dev/notebooks/mlj-interfacing/mlj.ipynb b/dev/notebooks/mlj-interfacing/mlj.ipynb index 614d870..7cb0bfc 100644 --- a/dev/notebooks/mlj-interfacing/mlj.ipynb +++ b/dev/notebooks/mlj-interfacing/mlj.ipynb @@ -74,7 +74,7 @@ "│ \n", "│ In general, data in `machine(model, data...)` is expected to satisfy\n", "│ \n", - "│ scitype(data) <: MLJ.fit_data_scitype(model)\n", + "│ scitype(data) <: MLJBase.fit_data_scitype(model)\n", "│ \n", "│ In the present case:\n", "│ \n", @@ -194,7 +194,7 @@ } ], "source": [ - "MLJ.fit!(mach)" + "MLJBase.fit!(mach)" ] }, { @@ -214,7 +214,7 @@ } ], "source": [ - "MLJ.predict(mach, MLJBase.table(rand(Float32, 1, 4)))" + "MLJBase.predict(mach, MLJBase.table(rand(Float32, 1, 4)))" ] }, { @@ -323,7 +323,7 @@ } ], "source": [ - "MLJ.predict(mach, X)" + "MLJBase.predict(mach, X)" ] }, { diff --git a/docs/Manifest.toml b/docs/Manifest.toml index bdd743a..1479860 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,20 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "89145376c7506aacdd7535aa1062567925d9d8fb" +project_hash = "3a38b320d6e28531e4489912be325157cb3ab4b9" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" -[[deps.ARFFFiles]] -deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] -git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" -uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" -version = "1.4.1" - [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" @@ -42,10 +36,10 @@ weakdeps = ["StaticArrays"] AdaptStaticArraysExt = "StaticArrays" [[deps.AliasTables]] -deps = ["Random"] -git-tree-sha1 = "07591db28451b3e45f4c0088a2d5e986ae5aa92d" +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" -version = "1.1.1" +version = "1.1.3" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -140,32 +134,37 @@ version = "0.10.14" [[deps.CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "4e33522a036b39fc6f5cb7447ae3b28eb8fbe99b" +git-tree-sha1 = "6e945e876652f2003e6ca74e19a3c45017d3e9f6" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.3.3" -weakdeps = ["ChainRulesCore", "SpecialFunctions"] +version = "5.4.2" [deps.CUDA.extensions] ChainRulesCoreExt = "ChainRulesCore" + EnzymeCoreExt = "EnzymeCore" SpecialFunctionsExt = "SpecialFunctions" + [deps.CUDA.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "dc172b558adbf17952001e15cf0d6364e6d78c2f" +git-tree-sha1 = "c48f9da18efd43b6b7adb7ee1f93fe5f2926c339" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.8.1+0" +version = "0.9.0+0" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] -git-tree-sha1 = "38f830504358e9972d2a0c3e5d51cb865e0733df" +git-tree-sha1 = "5db9da5fdeaa708c22ba86b82c49528f402497f2" uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.2.4" +version = "0.3.3" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "4ca7d6d92075906c2ce871ea8bba971fff20d00c" +git-tree-sha1 = "bcba305388e16aa5c879e896726db9e71b4942c6" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.12.1+0" +version = "0.14.0+1" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -175,9 +174,9 @@ version = "9.0.0+1" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a4c43f59baa34011e303e76f5c8c91bf58415aaf" +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+1" +version = "1.18.0+2" [[deps.Calculus]] deps = ["LinearAlgebra"] @@ -217,15 +216,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "e7d1016142a71c980309114ee30a3e4f870902f4" +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.65.0" +version = "1.69.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" +version = "1.24.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -239,9 +238,9 @@ version = "0.7.4" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "67c1f244b991cad9b0aa4b7540fb758c2488b129" +git-tree-sha1 = "4b270d6465eb21ae89b732182c20dc165f8bf9f2" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.24.0" +version = "3.25.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -261,14 +260,9 @@ weakdeps = ["SpecialFunctions"] [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" - -[[deps.Combinatorics]] -git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" -uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -version = "1.0.2" +version = "0.12.11" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -278,9 +272,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.15.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -289,7 +283,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -314,10 +308,10 @@ uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.4.1" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"] -git-tree-sha1 = "ee084331dcb2772dbd25a7c6afcaa664f36a0f04" +deps = ["CategoricalArrays", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"] +git-tree-sha1 = "520277c269dd9169fb5ae4b6eff45eeae46d4900" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" -version = "0.1.6" +version = "0.1.5" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] @@ -345,18 +339,18 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "Flux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLJBase", "MLJDecisionTreeInterface", "MLUtils", "MultivariateStats", "PackageExtensionCompat", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "TaijaBase", "UUIDs"] -git-tree-sha1 = "0151b29e1c86f205a431c9e03accf9bc7af39dfa" +deps = ["CategoricalArrays", "ChainRulesCore", "DataFrames", "Distributions", "Flux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLJBase", "MLJDecisionTreeInterface", "MLUtils", "MultivariateStats", "PackageExtensionCompat", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "TaijaBase", "UUIDs"] +git-tree-sha1 = "8a68385b6852e9357889aea661536059bc8b6158" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "1.1.5" +version = "1.1.6" [deps.CounterfactualExplanations.extensions] - EvoTreesExt = "EvoTrees" + DecisionTreeExt = "DecisionTree" LaplaceReduxExt = "LaplaceRedux" NeuroTreeExt = "NeuroTreeModels" [deps.CounterfactualExplanations.weakdeps] - EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" + DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" NeuroTreeModels = "1db4e0a5-a364-4b0c-897c-2bd5a4a3a1f2" @@ -437,9 +431,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2" +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.108" +version = "0.25.109" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -459,9 +453,9 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "f15a91e6e3919055efa4f206f942a73fedf5dfe6" +git-tree-sha1 = "5461b2a67beb9089980e2f8f25145186b6d34f91" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.4.0" +version = "1.4.1" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -474,12 +468,6 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" -[[deps.EarlyStopping]] -deps = ["Dates", "Statistics"] -git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" -uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" -version = "0.3.0" - [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" @@ -494,9 +482,9 @@ version = "0.1.10" [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.5.0+0" +version = "2.6.2+0" [[deps.ExprTools]] git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" @@ -544,9 +532,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "57f08d5665e76397e96b168f9acc12ab17c84a68" +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.10.2" +version = "1.11.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -556,9 +544,9 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] [[deps.FixedPointNumbers]] deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" +version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] @@ -579,10 +567,10 @@ version = "0.14.15" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03" +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] +git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.93+0" +version = "2.13.96+0" [[deps.Format]] git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" @@ -601,21 +589,21 @@ weakdeps = ["StaticArrays"] [[deps.FreeType2_jll]] deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d8db6a5a2fe1381c1ea4ef2cab7c69c2de7f9ea0" +git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.1+0" +version = "2.13.2+0" [[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.10+0" +version = "1.0.14+0" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.10" +version = "0.4.11" [[deps.Future]] deps = ["Random"] @@ -629,9 +617,9 @@ version = "3.3.9+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "68e8ff56a4a355a85d2784b94614491f8c900cde" +git-tree-sha1 = "c154546e322a9c73364e8a60430b0f79b812d320" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.1.0" +version = "10.2.0" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -641,21 +629,21 @@ version = "0.1.6" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "1600477fba37c9fc067b9be21f5e8101f24a8865" +git-tree-sha1 = "518ebd058c9895de468a8c255797b0c53fdb44dd" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.4" +version = "0.26.5" [[deps.GR]] -deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "3437ade7073682993e092ca570ad68a2aba26983" +deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] +git-tree-sha1 = "ddda044ca260ee324c5fc07edb6d7cf3f0b9c350" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.3" +version = "0.73.5" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a96d5c713e6aa28c242b0d25c1347e258d6541ab" +git-tree-sha1 = "278e5e0f820178e8a26df3184fcb2280717c79b1" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.3+0" +version = "0.73.5+0" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -677,9 +665,9 @@ version = "2.44.0+2" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "359a1ba2e320790ddbe4ee8b4d54a305c0ea2aff" +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.0+0" +version = "2.80.2+0" [[deps.Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -694,9 +682,9 @@ version = "1.0.2" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "2c3ec1f90bb4a8f7beafb0cffea8a4c3f4e636ab" +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.6" +version = "1.10.8" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -712,15 +700,15 @@ version = "0.3.23" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" +git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.4" +version = "0.2.5" [[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "d05027a62b4c9a2223820a9fdeae1110ad3946a5" +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.13" +version = "0.4.14" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -747,22 +735,16 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" -[[deps.IterationControl]] -deps = ["EarlyStopping", "InteractiveUtils"] -git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726" -uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" -version = "0.5.4" - [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] -git-tree-sha1 = "5ea6acdd53a51d897672edb694e3cc2912f3f8a7" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "bdbe8222d2f5703ad6a7019277d149ec6d78c301" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.46" +version = "0.4.48" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -784,9 +766,9 @@ version = "0.21.4" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3336abae9a713d2210bb57ab484b1e065edd7d23" +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.2+0" +version = "3.0.3+0" [[deps.JuliaNVTXCallbacks_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -802,9 +784,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +git-tree-sha1 = "8e5a339882cc401688d79b811d923a38ba77d50a" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" +version = "0.9.20" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -813,10 +795,10 @@ version = "0.9.18" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.1+0" +version = "3.100.2+0" [[deps.LERC_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -826,9 +808,9 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" +git-tree-sha1 = "389aea28d882a40b5e1747069af71bdbd47a1cae" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.3" +version = "7.2.1" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -852,10 +834,10 @@ uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" version = "15.0.7+0" [[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.1+0" +version = "2.10.2+0" [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" @@ -863,10 +845,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" [[deps.LaplaceRedux]] -deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Flux", "LinearAlgebra", "MLJFlux", "MLJModelInterface", "MLUtils", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] +deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] path = ".." uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "0.2.0" +version = "0.2.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] @@ -882,12 +864,6 @@ version = "0.16.3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" -[[deps.LatinHypercubeSampling]] -deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" -uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.9.0" - [[deps.LazilyInitializedFields]] git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" @@ -943,10 +919,10 @@ uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" version = "1.6.0+0" [[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c333716e46366857753e273ce6a69ee0945a6db9" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.42.0+0" +version = "1.49.0+0" [[deps.Libiconv_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -956,9 +932,9 @@ version = "1.17.0+0" [[deps.Libmount_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4b683b19157282f50bfd5dcaa2efe5295814ea22" +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.0+0" +version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] @@ -968,9 +944,9 @@ version = "4.5.1+1" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "27fd5cc10be85658cacfe11bb81bee216af13eda" +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.0+0" +version = "2.40.1+0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] @@ -978,9 +954,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" +version = "0.3.28" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -1011,18 +987,6 @@ weakdeps = ["CategoricalArrays"] [deps.LossFunctions.extensions] LossFunctionsCategoricalArraysExt = "CategoricalArrays" -[[deps.MLFlowClient]] -deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] -git-tree-sha1 = "5cc2a5453856e79f4772269fbe6b19fcdcba391a" -uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" -version = "0.4.7" - -[[deps.MLJ]] -deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "193f1f1ac77d91eabe1ac81ff48646b378270eef" -uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.19.5" - [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "0b7307d1a7214ec3c0ba305571e713f9492ea984" @@ -1035,47 +999,17 @@ git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.2" -[[deps.MLJEnsembles]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] -git-tree-sha1 = "95b306ef8108067d26dfde9ff3457d59911cc0d6" -uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" -version = "0.3.3" - -[[deps.MLJFlow]] -deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] -git-tree-sha1 = "bceeeb648c9aa2fc6f65f957c688b164d30f2905" -uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" -version = "0.1.1" - [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] git-tree-sha1 = "72935b7de07a7f6b72fd49ecc7898dac79248d46" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" version = "0.4.0" -[[deps.MLJIteration]] -deps = ["IterationControl", "MLJBase", "Random", "Serialization"] -git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" -uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" -version = "0.5.1" - [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" +git-tree-sha1 = "88ef480f46e0506143681b3fb14d86742f3cecb1" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.6" - -[[deps.MLJModels]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072" -uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.17" - -[[deps.MLJTuning]] -deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] -git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" -uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.7.4" +version = "1.10.0" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -1156,16 +1090,16 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] -git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" +deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "816620e3aac93e5b5359e4fdaf23ca4525b00ddf" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.10.2" +version = "0.10.3" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "5055845dd316575ae2fc1f6dcb3545ff15fe547a" +git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.14" +version = "0.9.17" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1216,9 +1150,9 @@ version = "0.2.3" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "ded64ff6d4fdd1cb68dfcbb818c69e144a5b2e4c" +git-tree-sha1 = "e4a9d37f0ee694da969def1f0dd4654642dfb51c" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.16" +version = "0.4.17" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -1246,12 +1180,6 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" -[[deps.OpenML]] -deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] -git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" -uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" -version = "0.3.1" - [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" @@ -1260,9 +1188,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.13+1" +version = "3.0.14+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1329,9 +1257,9 @@ version = "1.3.0" [[deps.Pixman_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.42.2+0" +version = "0.43.4+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1340,9 +1268,9 @@ version = "1.10.0" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "1f03a2d339f42dca4a4da149c7e15e9b896ad899" +git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.1.0" +version = "3.2.0" [[deps.PlotUtils]] deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] @@ -1393,16 +1321,11 @@ git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" version = "0.2.0" -[[deps.PrettyPrinting]] -git-tree-sha1 = "142ee93724a9c5d04d78df7006670a93ed1b244e" -uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" -version = "0.4.2" - [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.1" +version = "2.3.2" [[deps.Printf]] deps = ["Unicode"] @@ -1420,6 +1343,11 @@ git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.0" +[[deps.PtrArrays]] +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.0" + [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] git-tree-sha1 = "37b7bb7aabf9a085e0044307e1717436117f2b3b" @@ -1512,10 +1440,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.7.1" [[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" +version = "0.4.2+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1546,9 +1474,9 @@ version = "1.2.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "0e7508ff27ba32f26cd459474ca2ede1bc10991f" +git-tree-sha1 = "90b4f68892337554d31cdcdbe19e48989f26c7e6" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.1" +version = "1.4.3" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -1603,9 +1531,9 @@ version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -1617,17 +1545,11 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" -[[deps.StableRNGs]] -deps = ["Random"] -git-tree-sha1 = "83e6cce8324d49dfaf9ef059227f91ed4441a8e5" -uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.2" - [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" +version = "1.9.5" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1635,15 +1557,15 @@ weakdeps = ["ChainRulesCore", "Statistics"] StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" +version = "1.4.3" [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] -git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" +git-tree-sha1 = "983c41a0ddd6c19f5607ca87271d7c7620ab5d50" uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.2.0" +version = "3.3.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -1658,9 +1580,9 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -1728,15 +1650,16 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" [[deps.TaijaBase]] -git-tree-sha1 = "f7b57ab4a5c746d83e8883a38454839d0fab8b6f" +deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] +git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.0.2" +version = "1.2.2" [[deps.TaijaPlotting]] deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] -git-tree-sha1 = "e408d0162419206852142c6a08ba72a924379065" +git-tree-sha1 = "2a4fcdf2abd5533d6d24a97ce5e89327391b2dc1" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.1.1" +version = "1.1.2" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -1755,9 +1678,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TimeZones]] deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "96793c9316d6c9f9be4641f2e5b1319a205e6f27" +git-tree-sha1 = "a6ae8d7a27940c33624f8c7bde5528de21ba730d" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.15.0" +version = "1.17.0" weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] @@ -1765,14 +1688,14 @@ weakdeps = ["RecipesBase"] [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" +git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" +version = "0.5.24" [[deps.TranscodingStreams]] -git-tree-sha1 = "71509f04d045ec714c4748c785a59045c3736349" +git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.7" +version = "0.10.9" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] @@ -1798,6 +1721,11 @@ version = "0.4.80" OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" +[[deps.Trapz]] +git-tree-sha1 = "79eb0ed763084a3e7de81fe1838379ac6a23b6a0" +uuid = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" +version = "2.0.3" + [[deps.Tullio]] deps = ["DiffRules", "LinearAlgebra", "Requires"] git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3" @@ -1841,9 +1769,9 @@ version = "0.4.1" [[deps.Unitful]] deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "3c793be6df9dd77a0cf49d80984ef9ff996948fa" +git-tree-sha1 = "dd260903fdabea27d9b6021689b3cd5401a57748" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.19.0" +version = "1.20.0" [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" @@ -1866,9 +1794,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" +version = "0.1.4" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -1906,9 +1834,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "532e22cf7be8462035d092ff21fada7527e2c488" +git-tree-sha1 = "52ff2af32e591541550bd753c0da8b9bc92bb9d9" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.12.6+0" +version = "2.12.7+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -1923,16 +1851,16 @@ uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" version = "5.4.6+0" [[deps.Xorg_libICE_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "e5becd4411063bdcac16be8b66fc2f9f6f1e8fe5" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" -version = "1.0.10+1" +version = "1.1.1+0" [[deps.Xorg_libSM_jll]] -deps = ["Libdl", "Pkg", "Xorg_libICE_jll"] -git-tree-sha1 = "4a9d9e4c180e1e8119b5ffc224a7b59d3a7f7e18" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] +git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" uuid = "c834827a-8449-5923-a945-d239c165b7dd" -version = "1.2.3+0" +version = "1.2.4+0" [[deps.Xorg_libX11_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] @@ -1959,10 +1887,10 @@ uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" version = "1.1.4+0" [[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "b7c0aa8c376b31e4852b360222848637f481f8c3" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.4+4" +version = "1.3.6+0" [[deps.Xorg_libXfixes_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] @@ -1989,10 +1917,10 @@ uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" version = "1.5.2+4" [[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "19560f30fd49f4d4efbe7002a1037f8c43d43b96" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.10+4" +version = "0.9.11+0" [[deps.Xorg_libpthread_stubs_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2079,9 +2007,9 @@ version = "1.5.6+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" +version = "0.6.70" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2101,9 +2029,9 @@ version = "0.2.5" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "1f6a185a8da9bbbc20134b7b935981f70c9b26ad" +git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.3.1" +version = "1.3.2" [[deps.eudev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] @@ -2124,10 +2052,10 @@ uuid = "1a1c6b14-54f6-533d-8383-74cd7377aa70" version = "3.1.1+0" [[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3a2ea60308f0996d26f1e5354e10c24e9ef905d4" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.4.0+0" +version = "3.9.0+0" [[deps.libass_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] diff --git a/docs/Project.toml b/docs/Project.toml index 9459681..b3335de 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,7 +5,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 674c60f..e431fd2 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -1,4 +1,4 @@ -using Distributions +using Distributions: Distributions """ functional_variance(la::AbstractLaplace, 𝐉::AbstractArray) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index f502e49..789e8ca 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -3,15 +3,13 @@ using MLJFlux using ProgressMeter: Progress, next!, BarGlyphs using Random using Tables -using Distributions using LinearAlgebra using LaplaceRedux using ComputationalResources -using MLJBase -import MLJBase: @mlj_model, metadata_model, metadata_pkg +using MLJBase: MLJBase """ - @mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic + MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. @@ -35,7 +33,7 @@ The model is trained using the `fit!` method. The model is defined by the follow - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. """ -@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic +MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) optimiser = Flux.Optimise.Adam() loss = Flux.Losses.mse @@ -58,7 +56,7 @@ The model is trained using the `fit!` method. The model is defined by the follow end """ - @mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic + MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. @@ -83,7 +81,7 @@ A mutable struct representing a Laplace Classification model that extends the ML - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. """ -@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic +MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) finaliser = Flux.softmax optimiser = Flux.Optimise.Adam() diff --git a/test/Manifest.toml b/test/Manifest.toml index e86d802..32e8eb0 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,14 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "f9d09bb4d70c764973f6ddf935071923791db309" - -[[deps.ARFFFiles]] -deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] -git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" -uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" -version = "1.4.1" +project_hash = "fa6672850323ab23f77b8212aabdf7f033fa4213" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -36,6 +30,12 @@ weakdeps = ["StaticArrays"] [deps.Adapt.extensions] AdaptStaticArraysExt = "StaticArrays" +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + [[deps.Aqua]] deps = ["Compat", "Pkg", "Test"] git-tree-sha1 = "12e575f31a6f233ba2485ed86b9325b85df37c61" @@ -144,40 +144,11 @@ git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.14" -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "3dcab8a2c18ca319ea15a41d90e9528b8e93894a" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.3.0" -weakdeps = ["ChainRulesCore", "SpecialFunctions"] - - [deps.CUDA.extensions] - ChainRulesCoreExt = "ChainRulesCore" - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "dc172b558adbf17952001e15cf0d6364e6d78c2f" -uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.8.1+0" - -[[deps.CUDA_Runtime_Discovery]] -deps = ["Libdl"] -git-tree-sha1 = "38f830504358e9972d2a0c3e5d51cb865e0733df" -uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.2.4" - -[[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "4ca7d6d92075906c2ce871ea8bba971fff20d00c" -uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.12.1+0" - [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a4c43f59baa34011e303e76f5c8c91bf58415aaf" +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+1" +version = "1.18.0+2" [[deps.Calculus]] deps = ["LinearAlgebra"] @@ -212,15 +183,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c" +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.63.0" +version = "1.69.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" +version = "1.24.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -246,9 +217,9 @@ version = "0.7.4" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "67c1f244b991cad9b0aa4b7540fb758c2488b129" +git-tree-sha1 = "4b270d6465eb21ae89b732182c20dc165f8bf9f2" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.24.0" +version = "3.25.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -268,9 +239,9 @@ weakdeps = ["SpecialFunctions"] [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" +version = "0.12.11" [[deps.Combinatorics]] git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" @@ -285,9 +256,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.15.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -296,7 +267,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -346,18 +317,18 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "Flux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLJBase", "MLJDecisionTreeInterface", "MLUtils", "MultivariateStats", "PackageExtensionCompat", "Parameters", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "TaijaBase", "UUIDs"] -git-tree-sha1 = "651bb986e1676a84770ea69af8e6beb08791bb31" +deps = ["CategoricalArrays", "ChainRulesCore", "DataFrames", "Distributions", "Flux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLJBase", "MLJDecisionTreeInterface", "MLUtils", "MultivariateStats", "PackageExtensionCompat", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "TaijaBase", "UUIDs"] +git-tree-sha1 = "8a68385b6852e9357889aea661536059bc8b6158" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "1.1.1" +version = "1.1.6" [deps.CounterfactualExplanations.extensions] - EvoTreesExt = "EvoTrees" + DecisionTreeExt = "DecisionTree" LaplaceReduxExt = "LaplaceRedux" NeuroTreeExt = "NeuroTreeModels" [deps.CounterfactualExplanations.weakdeps] - EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" + DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" NeuroTreeModels = "1db4e0a5-a364-4b0c-897c-2bd5a4a3a1f2" @@ -385,9 +356,9 @@ version = "1.6.1" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "97d79461925cdb635ee32116978fc735b9463a39" +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.19" +version = "0.18.20" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -443,10 +414,10 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "7c302d7a5fec5214eb8a5a4c466dcf7a51fcf169" +deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.107" +version = "0.25.109" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -475,12 +446,6 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" -[[deps.EarlyStopping]] -deps = ["Dates", "Statistics"] -git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" -uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" -version = "0.3.0" - [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" @@ -495,14 +460,9 @@ version = "0.1.10" [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.5.0+0" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" +version = "2.6.2+0" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] @@ -545,9 +505,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "bfe82a708416cf00b73a3198db0859c82f741558" +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.10.0" +version = "1.11.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -557,9 +517,9 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] [[deps.FixedPointNumbers]] deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" +version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] @@ -580,10 +540,10 @@ version = "0.14.15" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03" +deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] +git-tree-sha1 = "db16beca600632c95fc8aca29890d83788dd8b23" uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.93+0" +version = "2.13.96+0" [[deps.Format]] git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" @@ -602,21 +562,21 @@ weakdeps = ["StaticArrays"] [[deps.FreeType2_jll]] deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d8db6a5a2fe1381c1ea4ef2cab7c69c2de7f9ea0" +git-tree-sha1 = "5c1d8ae0efc6c2e7b1fc502cbe25def8f661b7bc" uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.1+0" +version = "2.13.2+0" [[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1ed150b39aebcc805c26b93a8d0122c940f64ce2" uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.10+0" +version = "1.0.14+0" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" +git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.10" +version = "0.4.11" [[deps.Future]] deps = ["Random"] @@ -630,9 +590,9 @@ version = "3.3.9+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "68e8ff56a4a355a85d2784b94614491f8c900cde" +git-tree-sha1 = "c154546e322a9c73364e8a60430b0f79b812d320" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.1.0" +version = "10.2.0" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -640,23 +600,17 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "1600477fba37c9fc067b9be21f5e8101f24a8865" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.4" - [[deps.GR]] -deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "3437ade7073682993e092ca570ad68a2aba26983" +deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] +git-tree-sha1 = "ddda044ca260ee324c5fc07edb6d7cf3f0b9c350" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.3" +version = "0.73.5" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a96d5c713e6aa28c242b0d25c1347e258d6541ab" +git-tree-sha1 = "278e5e0f820178e8a26df3184fcb2280717c79b1" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.3+0" +version = "0.73.5+0" [[deps.GZip]] deps = ["Libdl", "Zlib_jll"] @@ -672,9 +626,9 @@ version = "0.21.0+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "359a1ba2e320790ddbe4ee8b4d54a305c0ea2aff" +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.0+0" +version = "2.80.2+0" [[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" @@ -705,16 +659,16 @@ version = "0.17.2" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" [[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.2+1" +version = "1.14.3+3" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "8e59b47b9dc525b70550ca082ce85bcd7f5477cd" +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.5" +version = "1.10.8" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -735,10 +689,10 @@ uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.23" [[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.12" +version = "0.4.14" [[deps.ImageBase]] deps = ["ImageCore", "Reexport"] @@ -789,22 +743,16 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" -[[deps.IterationControl]] -deps = ["EarlyStopping", "InteractiveUtils"] -git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726" -uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" -version = "0.5.4" - [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] -git-tree-sha1 = "5ea6acdd53a51d897672edb694e3cc2912f3f8a7" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] +git-tree-sha1 = "bdbe8222d2f5703ad6a7019277d149ec6d78c301" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.46" +version = "0.4.48" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -838,15 +786,9 @@ version = "1.14.0" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3336abae9a713d2210bb57ab484b1e065edd7d23" +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.2+0" - -[[deps.JuliaNVTXCallbacks_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" -uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" -version = "0.2.1+0" +version = "3.0.3+0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -856,9 +798,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +git-tree-sha1 = "8e5a339882cc401688d79b811d923a38ba77d50a" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" +version = "0.9.20" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -867,10 +809,10 @@ version = "0.9.18" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.1+0" +version = "3.100.2+0" [[deps.LERC_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -880,9 +822,9 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" +git-tree-sha1 = "389aea28d882a40b5e1747069af71bdbd47a1cae" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.3" +version = "7.2.1" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -894,11 +836,6 @@ git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" version = "0.0.29+0" -[[deps.LLVMLoopInfo]] -git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" -uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" -version = "1.0.0" - [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" @@ -906,10 +843,10 @@ uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" version = "15.0.7+0" [[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.1+0" +version = "2.10.2+0" [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" @@ -918,9 +855,9 @@ version = "1.3.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] -git-tree-sha1 = "cad560042a7cc108f5a4c24ea1431a9221f22c1b" +git-tree-sha1 = "e0b5cd21dc1b44ec6e64f351976f961e6f31d6c4" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.16.2" +version = "0.16.3" [deps.Latexify.extensions] DataFramesExt = "DataFrames" @@ -930,12 +867,6 @@ version = "0.16.2" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" -[[deps.LatinHypercubeSampling]] -deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" -uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.9.0" - [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -985,10 +916,10 @@ uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" version = "3.2.2+1" [[deps.Libgcrypt_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] -git-tree-sha1 = "64613c82a59c120435c067c2b809fc61cf5166ae" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] +git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.7+0" +version = "1.8.11+0" [[deps.Libglvnd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] @@ -997,10 +928,10 @@ uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29" version = "1.6.0+0" [[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c333716e46366857753e273ce6a69ee0945a6db9" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.42.0+0" +version = "1.49.0+0" [[deps.Libiconv_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1010,9 +941,9 @@ version = "1.17.0+0" [[deps.Libmount_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "dae976433497a2f841baadea93d27e68f1a12a97" +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.39.3+0" +version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] @@ -1022,9 +953,9 @@ version = "4.5.1+1" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "0a04a1318df1bf510beb2562cf90fb0c386f58c4" +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.39.3+1" +version = "2.40.1+0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] @@ -1032,9 +963,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" +version = "0.3.28" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -1057,9 +988,9 @@ version = "1.0.3" [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" +git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.6" +version = "0.10.7" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] @@ -1067,69 +998,35 @@ git-tree-sha1 = "aab72207b3c687086a400be710650a57494992bd" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" version = "0.7.14" -[[deps.MLFlowClient]] -deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] -git-tree-sha1 = "5cc2a5453856e79f4772269fbe6b19fcdcba391a" -uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" -version = "0.4.7" - -[[deps.MLJ]] -deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "a49aa31103f78b4c13e8d6beb13c5091cce82303" -uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.20.3" - -[[deps.MLJBalancing]] -deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"] -git-tree-sha1 = "f02e28f9f3c54a138db12a97a5d823e5e572c2d6" -uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" -version = "0.1.4" - [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "17d160e8f796ab5ceb4c017bc4019d21fd686a35" +git-tree-sha1 = "24e5d28b2ea86b3feb6af5a5735f012d62e27b65" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "1.2.1" -weakdeps = ["StatisticalMeasures"] +version = "1.4.0" [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" + [deps.MLJBase.weakdeps] + StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" + [[deps.MLJDecisionTreeInterface]] deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.2" -[[deps.MLJEnsembles]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] -git-tree-sha1 = "94403b2c8f692011df6731913376e0e37f6c0fe9" -uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" -version = "0.4.0" - -[[deps.MLJFlow]] -deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] -git-tree-sha1 = "79989f284c1f6c39eef70f6c8a39736e4f8d3d02" -uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" -version = "0.4.1" - [[deps.MLJFlux]] -deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "72935b7de07a7f6b72fd49ecc7898dac79248d46" +deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"] +git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.4.0" - -[[deps.MLJIteration]] -deps = ["IterationControl", "MLJBase", "Random", "Serialization"] -git-tree-sha1 = "1e909ee09417ebd18559c4d9c15febff887192df" -uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" -version = "0.6.1" +version = "0.5.1" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" +git-tree-sha1 = "88ef480f46e0506143681b3fb14d86742f3cecb1" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.6" +version = "1.10.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1137,12 +1034,6 @@ git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" version = "0.16.17" -[[deps.MLJTuning]] -deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"] -git-tree-sha1 = "4a2c14b9529753db3ece53fd635c609220200507" -uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.8.4" - [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -1156,21 +1047,21 @@ version = "0.4.4" [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "656036b9ed6f942d35e536e249600bc31d0f9df8" +git-tree-sha1 = "4099bb6809ac109bfc17d521dad33763bcf026b7" uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.0+0" +version = "4.2.1+1" [[deps.MPIPreferences]] deps = ["Libdl", "Preferences"] -git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9" +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.10" +version = "0.1.11" [[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "77c3bd69fdb024d75af38713e883d0f249ce19c2" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.3.2+0" +version = "5.4.0+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -1208,11 +1099,13 @@ deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "Lazy git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" version = "0.9.3" -weakdeps = ["CUDA"] [deps.Metalhead.extensions] MetalheadCUDAExt = "CUDA" + [deps.Metalhead.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" @@ -1245,16 +1138,16 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] -git-tree-sha1 = "68bf5103e002c44adfd71fea6bd770b3f0586843" +deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "816620e3aac93e5b5359e4fdaf23ca4525b00ddf" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.10.2" +version = "0.10.3" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "1fa1a14766c60e66ab22e242d45c1857c83a3805" +git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.13" +version = "0.9.17" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1274,18 +1167,6 @@ git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" version = "0.4.3" -[[deps.NVTX]] -deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] -git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" -uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" -version = "0.3.4" - -[[deps.NVTX_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" -uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" -version = "3.1.0+2" - [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" @@ -1333,29 +1214,23 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" -[[deps.OpenML]] -deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] -git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" -uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" -version = "0.3.1" - [[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "PMIx_jll", "TOML", "Zlib_jll", "libevent_jll", "prrte_jll"] -git-tree-sha1 = "f46caf663e069027a06942d00dced37f1eb3d8ad" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "5.0.2+0" +version = "4.1.6+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "af81a32750ebc831ee28bdaaba6e1067decef51e" +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.2" +version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.13+1" +version = "3.0.14+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1391,12 +1266,6 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.31" -[[deps.PMIx_jll]] -deps = ["Artifacts", "Hwloc_jll", "JLLWrappers", "Libdl", "Zlib_jll", "libevent_jll"] -git-tree-sha1 = "360f48126b5f2c2f0c833be960097f7c62705976" -uuid = "32165bc3-0280-59bc-8c0b-c33b6203efab" -version = "4.2.9+0" - [[deps.PackageExtensionCompat]] git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" @@ -1434,10 +1303,10 @@ uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" version = "1.2.1" [[deps.Pickle]] -deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e6a34eb1dc0c498f0774bbfbbbeff2de101f4235" +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.2" +version = "0.3.5" [[deps.Pipe]] git-tree-sha1 = "6842804e7867b115ca9de748a0cf6b364523c16d" @@ -1446,9 +1315,9 @@ version = "1.3.0" [[deps.Pixman_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.42.2+0" +version = "0.43.4+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1457,9 +1326,9 @@ version = "1.10.0" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "1f03a2d339f42dca4a4da149c7e15e9b896ad899" +git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.1.0" +version = "3.2.0" [[deps.PlotUtils]] deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] @@ -1517,9 +1386,9 @@ version = "0.4.2" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.1" +version = "2.3.2" [[deps.Printf]] deps = ["Unicode"] @@ -1537,6 +1406,11 @@ git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.0" +[[deps.PtrArrays]] +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.0" + [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] git-tree-sha1 = "37b7bb7aabf9a085e0044307e1717436117f2b3b" @@ -1557,18 +1431,6 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.7.0" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - [[deps.RealDot]] deps = ["LinearAlgebra"] git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" @@ -1611,10 +1473,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.7.1" [[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" +version = "0.4.2+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1645,9 +1507,9 @@ version = "1.2.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "0e7508ff27ba32f26cd459474ca2ede1bc10991f" +git-tree-sha1 = "90b4f68892337554d31cdcdbe19e48989f26c7e6" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.1" +version = "1.4.3" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -1702,9 +1564,9 @@ version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -1717,10 +1579,10 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StableRNGs]] -deps = ["Random", "Test"] -git-tree-sha1 = "ddc1a7b85e760b5285b50b882fa91e40c603be47" +deps = ["Random"] +git-tree-sha1 = "83e6cce8324d49dfaf9ef059227f91ed4441a8e5" uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.1" +version = "1.0.2" [[deps.StackViews]] deps = ["OffsetArrays"] @@ -1730,9 +1592,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.3" +version = "1.9.5" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1740,23 +1602,9 @@ weakdeps = ["ChainRulesCore", "Statistics"] StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.StatisticalMeasures]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"] -git-tree-sha1 = "8b5a165b0ee2b361d692636bfb423b19abfd92b3" -uuid = "a19d573c-0a75-4610-95b3-7071388c7541" -version = "0.1.6" - - [deps.StatisticalMeasures.extensions] - LossFunctionsExt = "LossFunctions" - ScientificTypesExt = "ScientificTypes" - - [deps.StatisticalMeasures.weakdeps] - LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" - ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" +version = "1.4.3" [[deps.StatisticalMeasuresBase]] deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"] @@ -1766,9 +1614,9 @@ version = "0.1.1" [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] -git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" +git-tree-sha1 = "983c41a0ddd6c19f5607ca87271d7c7620ab5d50" uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.2.0" +version = "3.3.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -1801,11 +1649,17 @@ version = "1.3.1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" -[[deps.Strided]] -deps = ["LinearAlgebra", "TupleTools"] -git-tree-sha1 = "a7a664c91104329c88222aa20264e1a05b6ad138" -uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" -version = "1.2.3" +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version = "0.2.2" + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + + [deps.StridedViews.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [[deps.StringEncodings]] deps = ["Libiconv_jll"] @@ -1865,9 +1719,10 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" [[deps.TaijaBase]] -git-tree-sha1 = "f7b57ab4a5c746d83e8883a38454839d0fab8b6f" +deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] +git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.0.2" +version = "1.2.2" [[deps.TaijaData]] deps = ["CSV", "CounterfactualExplanations", "DataAPI", "DataFrames", "Flux", "LazyArtifacts", "MLDatasets", "MLJBase", "MLJModels", "Random", "StatsBase"] @@ -1890,16 +1745,10 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" - [[deps.TranscodingStreams]] -git-tree-sha1 = "71509f04d045ec714c4748c785a59045c3736349" +git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.7" +version = "0.10.9" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] @@ -1943,11 +1792,6 @@ version = "0.3.7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -[[deps.TupleTools]] -git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" -uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.5.0" - [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" @@ -1973,9 +1817,9 @@ version = "0.4.1" [[deps.Unitful]] deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "3c793be6df9dd77a0cf49d80984ef9ff996948fa" +git-tree-sha1 = "dd260903fdabea27d9b6021689b3cd5401a57748" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.19.0" +version = "1.20.0" [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" @@ -2004,9 +1848,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" +version = "0.1.4" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -2044,9 +1888,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "532e22cf7be8462035d092ff21fada7527e2c488" +git-tree-sha1 = "52ff2af32e591541550bd753c0da8b9bc92bb9d9" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.12.6+0" +version = "2.12.7+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -2061,16 +1905,16 @@ uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" version = "5.4.6+0" [[deps.Xorg_libICE_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "e5becd4411063bdcac16be8b66fc2f9f6f1e8fe5" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "326b4fea307b0b39892b3e85fa451692eda8d46c" uuid = "f67eecfb-183a-506d-b269-f58e52b52d7c" -version = "1.0.10+1" +version = "1.1.1+0" [[deps.Xorg_libSM_jll]] -deps = ["Libdl", "Pkg", "Xorg_libICE_jll"] -git-tree-sha1 = "4a9d9e4c180e1e8119b5ffc224a7b59d3a7f7e18" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libICE_jll"] +git-tree-sha1 = "3796722887072218eabafb494a13c963209754ce" uuid = "c834827a-8449-5923-a945-d239c165b7dd" -version = "1.2.3+0" +version = "1.2.4+0" [[deps.Xorg_libX11_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] @@ -2097,10 +1941,10 @@ uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" version = "1.1.4+0" [[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "b7c0aa8c376b31e4852b360222848637f481f8c3" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "d2d1a5c49fae4ba39983f63de6afcbea47194e85" uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.4+4" +version = "1.3.6+0" [[deps.Xorg_libXfixes_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] @@ -2127,10 +1971,10 @@ uuid = "ec84b674-ba8e-5d96-8ba1-2a689ba10484" version = "1.5.2+4" [[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "19560f30fd49f4d4efbe7002a1037f8c43d43b96" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] +git-tree-sha1 = "47e45cd78224c53109495b3e324df0c37bb61fbe" uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.10+4" +version = "0.9.11+0" [[deps.Xorg_libpthread_stubs_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2223,9 +2067,9 @@ version = "1.5.6+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" +version = "0.6.70" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2268,10 +2112,10 @@ uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" version = "1.1.2+0" [[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3a2ea60308f0996d26f1e5354e10c24e9ef905d4" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.4.0+0" +version = "3.9.0+0" [[deps.libass_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] @@ -2290,12 +2134,6 @@ git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" version = "1.11.0+0" -[[deps.libevent_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "OpenSSL_jll"] -git-tree-sha1 = "f04ec6d9a186115fb38f858f05c0c4e1b7fc9dcb" -uuid = "1080aeaf-3a6a-583e-a51c-c537b09f60ec" -version = "2.1.13+1" - [[deps.libfdk_aac_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" @@ -2336,12 +2174,6 @@ deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" version = "17.4.0+2" -[[deps.prrte_jll]] -deps = ["Artifacts", "Hwloc_jll", "JLLWrappers", "Libdl", "PMIx_jll", "libevent_jll"] -git-tree-sha1 = "5adb2d7a18a30280feb66cad6f1a1dfdca2dc7b0" -uuid = "eb928a42-fffd-568d-ab9c-3f5d54fc65b9" -version = "3.0.2+0" - [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" diff --git a/test/Project.toml b/test/Project.toml index 7af07ce..588aa08 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,7 +6,6 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index b39a9cf..5c7b98d 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -1,6 +1,5 @@ using Random: Random import Random.seed! -using MLJ using MLJBase using MLJFlux using Flux @@ -114,6 +113,8 @@ function basictest_classification(X, y, builder, optimiser, threshold) link_approx=:incorrect, ) + # Test that shape is correct: + @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) history = _report.training_losses @@ -140,7 +141,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) @test length(history) == model.epochs + 1 # start fresh with small epochs: - model = LaplaceRegression(; + model = LaplaceClassification(; builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng ) @@ -193,4 +194,4 @@ builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) y_onehot = transpose(unique(y) .== permutedims(y)) -@test basictest_classification(X, y_onehot, builder, optimizer, 0.9) +@test basictest_classification(X, y, builder, optimizer, 0.9) From f95a4847ee5ddefaa0df4a638670025e0d7a0212 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Sat, 15 Jun 2024 09:58:48 +0200 Subject: [PATCH 14/32] fix error related to mean/var --- src/baselaplace/predicting.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index e431fd2..b2ef540 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -1,4 +1,5 @@ using Distributions: Distributions +using Statistics: mean, var """ functional_variance(la::AbstractLaplace, 𝐉::AbstractArray) From 5d47558087a864bcf8ad868f21317a61083ca2b0 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Sat, 15 Jun 2024 10:19:21 +0200 Subject: [PATCH 15/32] small changes to CI --- .github/workflows/CI.yml | 15 ++++----------- CHANGELOG.md | 4 ++-- Project.toml | 8 ++++---- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 138d43b..50507c9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,8 +21,6 @@ jobs: matrix: version: - '1.7' - - '1.8' - - '1.9' - '1.10' os: - ubuntu-latest @@ -33,9 +31,6 @@ jobs: - os: windows-latest version: '1.7' arch: x64 - - os: windows-latest - version: '1.8' - arch: x64 - os: windows-latest version: '1' arch: x64 @@ -45,9 +40,6 @@ jobs: - os: macOS-latest version: '1.7' arch: x64 - - os: macOS-latest - version: '1.8' - arch: x64 steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -60,9 +52,10 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v2 - with: - files: lcov.info + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} docs: name: Documentation runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 1473dd4..6e6fe29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Changed +- Updated codecov workflow in CI.yml. [#39] - fixed test functions - adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase. - Changed the fit! method arguments. @@ -20,13 +21,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), - Removed the shape, build and clean! functions. - Removed Review dog for code format suggestions. [#39] - ## Version [0.2.3] - 2024-05-31 ### Changed - Removed the link_approx parameter in LaplaceRegression since it is not required. -- Changed MMI.clean! to check the value of link_approx only in the case likelihood is set to :classification +- Changed MMI.clean! to check the value of link_approx only in the case likelihood is set to `:classification` - Now the likelihood type in LaplaceClassification and LaplaceRegression is automatically set by the inner constructor. The user is not required to provide it as a parameter anymore. diff --git a/Project.toml b/Project.toml index 068dbb4..a04ac41 100644 --- a/Project.toml +++ b/Project.toml @@ -28,19 +28,19 @@ Compat = "4.7.0" ComputationalResources = "0.3.2" Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" -LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10" +LinearAlgebra = "1.7, 1.10" MLJBase = "0, 1.4.0" MLJFlux = "0.2.10, 0.3, 0.4" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" ProgressMeter = "1.7.2" -Random = "1.6, 1.7, 1.8, 1.9, 1.10" +Random = "1.7, 1.10" Statistics = "1" Tables = "1.10.1" -Test = "1.6, 1.7, 1.8, 1.9, 1.10" +Test = "1.7, 1.10" Tullio = "0.3.5" Zygote = "0.6" -julia = "1.6, 1.7, 1.8, 1.9, 1.10" +julia = "1.7, 1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" From 48d647f4a8a28ba32463cc3c67b96a0c7fc4ec54 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 16 Jun 2024 13:32:44 +0200 Subject: [PATCH 16/32] added docstrings, removed prints and comments. there is still the predict issue. --- CHANGELOG.md | 10 +-- src/baselaplace/predicting.jl | 2 + src/mlj_flux.jl | 127 +++++++++++++++++++++++++++------- test/mlj_flux_interfacing.jl | 6 +- 4 files changed, 110 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e6fe29..30cc94d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,14 +11,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Changed - Updated codecov workflow in CI.yml. [#39] -- fixed test functions -- adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase. -- Changed the fit! method arguments. -- Changed the predict functions for both LaplaceClassification and LaplaceRegression. +- fixed test functions [#39] +- adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase.[#39] +- Changed the fit! method arguments. [#39] +- Changed the predict functions for both LaplaceClassification and LaplaceRegression.[#39] ### Removed -- Removed the shape, build and clean! functions. +- Removed the shape, build and clean! functions.[#39] - Removed Review dog for code format suggestions. [#39] ## Version [0.2.3] - 2024-05-31 diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index b2ef540..45f8eeb 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -46,6 +46,8 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) Distributions.Normal(fμ[i, j], fstd[i, j]) for i in 1:size(fμ, 1), j in 1:size(fμ, 2) ] + #normal_distr = [ + #Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)] maybe this one is the correct one return normal_distr end diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 789e8ca..4996c26 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -105,10 +105,22 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab predict_proba::Bool = true::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) end -############### + const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} -################################ functions shape and build +""" + MLJFlux.shape(model::LaplaceClassification, X, y) + +Compute the the number of features of the dataset X and the number of unique classes in y. + +# Arguments +- `model::LaplaceClassification`: The LaplaceClassification model to fit. +- `X`: The input data for training. +- `y`: The target labels for training one-hot encoded. + +# Returns +- (input size, output size) +""" function MLJFlux.shape(model::LaplaceClassification, X, y) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X @@ -118,6 +130,21 @@ function MLJFlux.shape(model::LaplaceClassification, X, y) return (n_input, n_output) end + + +""" + MLJFlux.shape(model::LaplaceRegression, X, y) + +Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. + +# Arguments +- `model::LaplaceRegression`: The LaplaceRegression model to fit. +- `X`: The input data for training. +- `y`: The target labels for training one-hot encoded. + +# Returns +- (input size, output size) +""" function MLJFlux.shape(model::LaplaceRegression, X, y) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X, 2) @@ -130,6 +157,21 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) return (n_input, n_output) end + + +""" + MLJFlux.build(model::LaplaceClassification, rng, shape) + +Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`. + +# Arguments +- `model::LaplaceClassification`: The Laplace classification model. +- `rng`: A random number generator to ensure reproducibility. +- `shape`: A tuple or array specifying the dimensions of the input and output layers. + +# Returns +- The constructed MLJFlux model, compatible with the specified input and output dimensions. +""" function MLJFlux.build(model::LaplaceClassification, rng, shape) #chain chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) @@ -137,22 +179,68 @@ function MLJFlux.build(model::LaplaceClassification, rng, shape) return chain end + + +""" + MLJFlux.build(model::LaplaceRegression, rng, shape) + +Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by `shape`. + +# Arguments +- `model::LaplaceRegression`: The Laplace regression model. +- `rng`: A random number generator to ensure reproducibility. +- `shape`: A tuple or array specifying the dimensions of the input and output layers. + +# Returns +- The constructed MLJFlux model, compatible with the specified input and output dimensions. +""" function MLJFlux.build(model::LaplaceRegression, rng, shape) - #chain + chain = MLJFlux.build(model.builder, rng, shape...) return chain end + +""" + MLJFlux.fitresult(model::LaplaceClassification, chain, y) + +Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data. + +# Arguments +- `model::LaplaceClassification`: The Laplace classification model to be evaluated. +- `chain`: The trained model chain. +- `y`: The target data, typically a vector of class labels. + +# Returns +- A tuple containing: + - The model chain. + - The number of unique classes in the target data `y`. +""" function MLJFlux.fitresult(model::LaplaceClassification, chain, y) return (chain, length(unique(y))) end + +""" + MLJFlux.fitresult(model::LaplaceRegression, chain, y) + +Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data. + +# Arguments +- `model::LaplaceRegression`: The Laplace Regression model to be evaluated. +- `chain`: The trained model chain. +- `y`: The target data, typically a vector of class labels. + +# Returns +- A tuple containing: + - The model chain. + - The number of unique classes in the target data `y`. +""" function MLJFlux.fitresult(model::LaplaceRegression, chain, y) return (chain, size(y)) end -######################################### fit and predict for classification """ MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) @@ -166,7 +254,7 @@ Fit the LaplaceClassification model using MLJFlux. - `y`: The target labels for training one-hot encoded. # Returns -- (la,chain), history, report) +- (chain,la), history, report) where la is the fitted Laplace model. """ function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) @@ -236,7 +324,7 @@ function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) report = [] - return ((la, chain), history, report) + return ((chain,la), history, report) end """ @@ -254,7 +342,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - la = fitresult[1] + la = fitresult[2] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] @@ -281,10 +369,10 @@ Fit the LaplaceRegression model using Flux.jl. - `y`: The target labels for training. # Returns -- `model::LaplaceRegression`: The fitted LaplaceRegression model. +- (chain,la), loss history, report ) +where la is the fitted Laplace model. """ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) - #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X X = MLJBase.matrix(X) @@ -316,8 +404,6 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) barlen=25, color=:yellow, ) - println("Shape of X prima di loader: ", size(X)) - println("Shape of y prima di loader : ", size(y)) # Create a data loader loader = Flux.Data.DataLoader( (data=X', label=y); batchsize=model.batch_size, shuffle=true @@ -328,8 +414,6 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) # train the model for (X_batch, y_batch) in loader y_batch = reshape(y_batch, 1, :) - println("Shape of X_batch dopo di loader: ", size(X_batch)) - println("Shape of y_batch dopo di loader : ", size(y_batch)) # Backward pass gs = Flux.gradient(parameters) do @@ -356,7 +440,7 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) report = [] - return ((la, chain), history, report) + return ((chain,la), history, report) end """ @@ -376,7 +460,7 @@ Predict the output for new input data using a Laplace regression model. function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) - la = fitresult[1] + la = fitresult[2] #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat @@ -384,19 +468,10 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) # Predict using Laplace and collect the predictions yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] - #predictions = [] - #for row in eachrow(yhat) - - #mean_val = Float64(row[1][1][1]) - #std_val = sqrt(Float64(row[1][2][1])) - # Append a Normal distribution: - #push!(predictions, Normal(mean_val, std_val)) - #end - return yhat end -# Then for each model, +# metadata for each model, MLJBase.metadata_model( LaplaceClassification; input=Union{ @@ -410,7 +485,7 @@ MLJBase.metadata_model( target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, path="MLJFlux.LaplaceClassification", ) -# Then for each model, +# metadata for each model, MLJBase.metadata_model( LaplaceRegression; input=Union{ diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 5c7b98d..019a28e 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -45,7 +45,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - #yhat = MLJBase.predict(model, fitresult, X) + yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses @test length(history) == model.epochs + 1 @@ -135,7 +135,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - #yhat = MLJBase.predict(model, fitresult, X) + yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses @test length(history) == model.epochs + 1 @@ -192,6 +192,4 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) -y_onehot = transpose(unique(y) .== permutedims(y)) - @test basictest_classification(X, y, builder, optimizer, 0.9) From 1b05992a84411a356f232c401f134e005834e8e0 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 23 Jun 2024 20:08:27 +0200 Subject: [PATCH 17/32] partial (the update method is missing) --- src/mlj_flux.jl | 195 ++++++++++++++++------------------- test/mlj_flux_interfacing.jl | 83 +++------------ 2 files changed, 105 insertions(+), 173 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 4996c26..2a87490 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -7,6 +7,7 @@ using LinearAlgebra using LaplaceRedux using ComputationalResources using MLJBase: MLJBase +import MLJModelInterface as MMI """ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic @@ -32,6 +33,7 @@ The model is trained using the `fit!` method. The model is defined by the follow - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. +- `la`: the Laplace model. """ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) @@ -53,6 +55,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilis μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing fit_prior_nsteps::Int = 100::(_ > 0) + la::Union{Nothing,AbstractLaplace} = nothing end """ @@ -80,6 +83,7 @@ A mutable struct representing a Laplace Classification model that extends the ML - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. +- `la`: the Laplace model. """ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) @@ -104,6 +108,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab link_approx::Symbol = :probit::(_ in (:probit, :plugin)) predict_proba::Bool = true::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) + la::Union{Nothing,AbstractLaplace} = nothing end const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} @@ -123,6 +128,7 @@ Compute the the number of features of the dataset X and the number of unique cl """ function MLJFlux.shape(model::LaplaceClassification, X, y) + println("shapeclass") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X, 2) levels = unique(y) @@ -130,8 +136,6 @@ function MLJFlux.shape(model::LaplaceClassification, X, y) return (n_input, n_output) end - - """ MLJFlux.shape(model::LaplaceRegression, X, y) @@ -146,6 +150,7 @@ Compute the the number of features of the X input dataset and the number of var - (input size, output size) """ function MLJFlux.shape(model::LaplaceRegression, X, y) + println("shapereg") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X, 2) dims = size(y) @@ -157,8 +162,6 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) return (n_input, n_output) end - - """ MLJFlux.build(model::LaplaceClassification, rng, shape) @@ -173,14 +176,23 @@ Builds an MLJFlux model for Laplace classification compatible with the dimension - The constructed MLJFlux model, compatible with the specified input and output dimensions. """ function MLJFlux.build(model::LaplaceClassification, rng, shape) - #chain + println("buildclass") chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + model.la = Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) return chain end - - """ MLJFlux.build(model::LaplaceRegression, rng, shape) @@ -195,13 +207,24 @@ Builds an MLJFlux model for Laplace regression compatible with the dimensions of - The constructed MLJFlux model, compatible with the specified input and output dimensions. """ function MLJFlux.build(model::LaplaceRegression, rng, shape) + println("buildreg") chain = MLJFlux.build(model.builder, rng, shape...) + model.la = Laplace( + chain; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) return chain end - """ MLJFlux.fitresult(model::LaplaceClassification, chain, y) @@ -218,10 +241,10 @@ Computes the fit result for a Laplace classification model, returning the model - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (chain, length(unique(y))) + println("fitresultclass") + return (chain, model.la, length(unique(y))) end - """ MLJFlux.fitresult(model::LaplaceRegression, chain, y) @@ -238,31 +261,26 @@ Computes the fit result for a Laplace Regression model, returning the model chai - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - return (chain, size(y)) + println("fitresultregre") + if y isa AbstractArray + target_column_names = nothing + else + target_column_names = Tables.schema(y).names + end + return (chain, model.la, size(y)) end +function MLJFlux.fit!( + model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y +) + println("fitclass") + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X -""" - MLJFlux.fit!(model::LaplaceClassification, chain,penalty,optimiser,epochs, verbosity, X, y) - -Fit the LaplaceClassification model using MLJFlux. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification object to fit. -- `verbosity`: The verbosity level for training. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (chain,la), history, report) -where la is the fitted Laplace model. -""" -function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) - X = MLJBase.matrix(X) + #X = MLJBase.matrix(X) - shape = MLJFlux.shape(model, X, y) + #shape = MLJFlux.shape(model, X, y) - chain = MLJFlux.build(model, model.rng, shape) + #chain = MLJFlux.build(model, model.rng, shape) la = LaplaceRedux.Laplace( chain; @@ -275,7 +293,6 @@ function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) μ₀=model.μ₀, P₀=model.P₀, ) - n_samples = size(X, 1) verbose_laplace = false # Initialize history: @@ -289,44 +306,27 @@ function MLJFlux.fit!(model::LaplaceClassification, verbosity, X, y) barlen=25, color=:yellow, ) - # Create a data loader - loader = Flux.Data.DataLoader( - (data=X', label=y); batchsize=model.batch_size, shuffle=true - ) - parameters = Flux.params(chain) - for i in 1:(model.epochs) - epoch_loss = 0.0 - # train the model - for (X_batch, y_batch) in loader - - # Backward pass - gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch)) - epoch_loss += batch_loss - end - # Update parameters - Flux.update!(model.optimiser, parameters, gs) - end - epoch_loss /= n_samples - push!(history, epoch_loss) - #verbosity - if verbosity == 1 - next!(meter) - elseif verbosity == 2 - next!(meter) - verbose_laplace = true - println("Loss is $(round(epoch_loss; sigdigits=4))") - end + for i in 1:epochs + current_loss = MLJFlux.train!(model, penalty, chain, optimiser, X, y) + verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" + verbosity != 1 || next!(meter) + append!(history, current_loss) end # fit the Laplace model: - LaplaceRedux.fit!(la, zip(eachrow(X), y)) + LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - report = [] + model.la = la - return ((chain,la), history, report) -end + cache = () + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + report = history + + #return cache, report + + return (fitresult, report, cache) +end """ predict(model::LaplaceClassification, Xnew) @@ -342,6 +342,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) + println("predictclass") la = fitresult[2] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so @@ -369,16 +370,16 @@ Fit the LaplaceRegression model using Flux.jl. - `y`: The target labels for training. # Returns -- (chain,la), loss history, report ) +- where la is the fitted Laplace model. """ -function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) - - X = MLJBase.matrix(X) - - shape = MLJFlux.shape(model, X, y) +#model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y +function MLJFlux.fit!( + model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y +) + println("fitregre") - chain = MLJFlux.build(model, model.rng, shape) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X la = LaplaceRedux.Laplace( chain; @@ -391,56 +392,41 @@ function MLJFlux.fit!(model::LaplaceRegression, verbosity, X, y) μ₀=model.μ₀, P₀=model.P₀, ) - n_samples = size(X, 1) + # Initialize history: history = [] verbose_laplace = false # intitialize and start progress meter: meter = Progress( - model.epochs + 1; + epochs + 1; dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow, ) - # Create a data loader - loader = Flux.Data.DataLoader( - (data=X', label=y); batchsize=model.batch_size, shuffle=true - ) - parameters = Flux.params(chain) - for i in 1:(model.epochs) - epoch_loss = 0.0 - # train the model - for (X_batch, y_batch) in loader - y_batch = reshape(y_batch, 1, :) - - # Backward pass - gs = Flux.gradient(parameters) do - batch_loss = (model.loss(chain(X_batch), y_batch)) - epoch_loss += batch_loss - end - # Update parameters - Flux.update!(model.optimiser, parameters, gs) - end - epoch_loss /= n_samples - push!(history, epoch_loss) - #verbosity - if verbosity == 1 - next!(meter) - elseif verbosity == 2 - next!(meter) - verbose_laplace = true - println("Loss is $(round(epoch_loss; sigdigits=4))") - end + verbosity != 1 || next!(meter) + + for i in 1:epochs + current_loss = MLJFlux.train!(model, penalty, chain, optimiser, X, y) + verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" + verbosity != 1 || next!(meter) + push!(history, current_loss) end # fit the Laplace model: - LaplaceRedux.fit!(la, zip(eachrow(X), y)) + LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - report = [] + model.la = la + + cache = () + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + + report = history + + #return cache, report - return ((chain,la), history, report) + return (fitresult, report, cache) end """ @@ -458,6 +444,7 @@ Predict the output for new input data using a Laplace regression model. """ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) + println("predictregre") Xnew = MLJBase.matrix(Xnew) la = fitresult[2] diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 019a28e..351d233 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -24,11 +24,15 @@ function basictest_regression(X, y, builder, optimiser, threshold) hessian_structure=:incorrect, backend=:incorrect, ) - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + println("ayyyyy") + #println(boh) + #println(cache) + println(_report) history = _report.training_losses - @test length(history) == model.epochs + 1 + #println(fitresult) + @test length(history) == model.epochs # test improvement in training loss: @test history[end] < threshold * history[1] @@ -36,19 +40,11 @@ function basictest_regression(X, y, builder, optimiser, threshold) # increase iterations and check update is incremental: model.epochs = model.epochs + 3 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses - @test length(history) == model.epochs + 1 # start fresh with small epochs: model = LaplaceRegression(; @@ -57,41 +53,9 @@ function basictest_regression(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - return true end -seed!(1234) -N = 300 -X = MLJBase.table(rand(Float32, N, 4)); -ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) - -builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) -optimizer = Flux.Optimise.Adam(0.03) - -@test basictest_regression(X, ycont, builder, optimizer, 0.9) - function basictest_classification(X, y, builder, optimiser, threshold) optimiser = deepcopy(optimiser) @@ -100,25 +64,22 @@ function basictest_classification(X, y, builder, optimiser, threshold) model = LaplaceClassification(; builder=builder, optimiser=optimiser, - loss=Flux.crossentropy, - epochs=-1, - batch_size=-1, + acceleration=CPUThreads(), + rng=stable_rng, lambda=-1.0, alpha=-1.0, - rng=stable_rng, - acceleration=CPUThreads(), + epochs=-1, + batch_size=-1, subset_of_weights=:incorrect, hessian_structure=:incorrect, backend=:incorrect, link_approx=:incorrect, ) - # Test that shape is correct: - @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) history = _report.training_losses - @test length(history) == model.epochs + 1 + @test length(history) == model.epochs # test improvement in training loss: @test history[end] < threshold * history[1] @@ -126,19 +87,11 @@ function basictest_classification(X, y, builder, optimiser, threshold) # increase iterations and check update is incremental: model.epochs = model.epochs + 3 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses - @test length(history) == model.epochs + 1 # start fresh with small epochs: model = LaplaceClassification(; @@ -149,25 +102,14 @@ function basictest_classification(X, y, builder, optimiser, threshold) # change batch_size and check it performs cold restart: model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) # change learning rate and check it does *not* restart: model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) # set `optimiser_changes_trigger_retraining = true` and change # learning rate and check it does restart: model.optimiser_changes_trigger_retraining = true model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) return true end @@ -192,4 +134,7 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) + +@test basictest_regression(X, ycont, builder, optimizer, 0.9) + @test basictest_classification(X, y, builder, optimizer, 0.9) From 0bf4ea8ba19f37126e969c2674b20ee83dffbebc Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 24 Jun 2024 12:19:35 +0200 Subject: [PATCH 18/32] [internet-not working]: changed fitresult and the position of report --- src/mlj_flux.jl | 29 +++++++++++++---------------- test/mlj_flux_interfacing.jl | 1 - 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 2a87490..75249dc 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -237,12 +237,12 @@ Computes the fit result for a Laplace classification model, returning the model # Returns - A tuple containing: - - The model chain. + - The model. - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) println("fitresultclass") - return (chain, model.la, length(unique(y))) + return ( model, length(unique(y))) end """ @@ -257,17 +257,13 @@ Computes the fit result for a Laplace Regression model, returning the model chai # Returns - A tuple containing: - - The model chain. - - The number of unique classes in the target data `y`. + - The model. + - The size `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) println("fitresultregre") - if y isa AbstractArray - target_column_names = nothing - else - target_column_names = Tables.schema(y).names - end - return (chain, model.la, size(y)) + + return ( model, size(y)) end function MLJFlux.fit!( @@ -325,7 +321,7 @@ function MLJFlux.fit!( #return cache, report - return (fitresult, report, cache) + return (fitresult, cache, report) end """ predict(model::LaplaceClassification, Xnew) @@ -343,7 +339,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) println("predictclass") - la = fitresult[2] + model = fitresult[1] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] @@ -351,7 +347,7 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) # Predict using Laplace and collect the predictions predictions = [ LaplaceRedux.predict( - la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba ) for x in X_vec ] @@ -426,7 +422,7 @@ function MLJFlux.fit!( #return cache, report - return (fitresult, report, cache) + return (fitresult, cache, report) end """ @@ -447,13 +443,14 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) println("predictregre") Xnew = MLJBase.matrix(Xnew) - la = fitresult[2] + model = fitresult[1] + #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat yhat = [] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] + yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] return yhat end diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 351d233..7b8680b 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -26,7 +26,6 @@ function basictest_regression(X, y, builder, optimiser, threshold) ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - println("ayyyyy") #println(boh) #println(cache) println(_report) From 7b8f6ab33b077fbb778b470e7bff90785ac3dc72 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 24 Jun 2024 18:19:04 +0200 Subject: [PATCH 19/32] Revert "[internet-not working]: changed fitresult and the position of report" This reverts commit 0bf4ea8ba19f37126e969c2674b20ee83dffbebc. --- src/mlj_flux.jl | 29 ++++++++++++++++------------- test/mlj_flux_interfacing.jl | 1 + 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 75249dc..2a87490 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -237,12 +237,12 @@ Computes the fit result for a Laplace classification model, returning the model # Returns - A tuple containing: - - The model. + - The model chain. - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) println("fitresultclass") - return ( model, length(unique(y))) + return (chain, model.la, length(unique(y))) end """ @@ -257,13 +257,17 @@ Computes the fit result for a Laplace Regression model, returning the model chai # Returns - A tuple containing: - - The model. - - The size `y`. + - The model chain. + - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) println("fitresultregre") - - return ( model, size(y)) + if y isa AbstractArray + target_column_names = nothing + else + target_column_names = Tables.schema(y).names + end + return (chain, model.la, size(y)) end function MLJFlux.fit!( @@ -321,7 +325,7 @@ function MLJFlux.fit!( #return cache, report - return (fitresult, cache, report) + return (fitresult, report, cache) end """ predict(model::LaplaceClassification, Xnew) @@ -339,7 +343,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) println("predictclass") - model = fitresult[1] + la = fitresult[2] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] @@ -347,7 +351,7 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) # Predict using Laplace and collect the predictions predictions = [ LaplaceRedux.predict( - model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + la, x; link_approx=model.link_approx, predict_proba=model.predict_proba ) for x in X_vec ] @@ -422,7 +426,7 @@ function MLJFlux.fit!( #return cache, report - return (fitresult, cache, report) + return (fitresult, report, cache) end """ @@ -443,14 +447,13 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) println("predictregre") Xnew = MLJBase.matrix(Xnew) - model = fitresult[1] - + la = fitresult[2] #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat yhat = [] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] return yhat end diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 7b8680b..351d233 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -26,6 +26,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + println("ayyyyy") #println(boh) #println(cache) println(_report) From 506883b383b290ea9cef3d9c7edb67e20f4842ff Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Tue, 25 Jun 2024 21:07:42 +0200 Subject: [PATCH 20/32] trying to modify mljflux.update. Added Optimisers package to project, updated mljflux to 0.5.1 --- Project.toml | 3 +- src/mlj_flux.jl | 69 +++++++++++++++++++++++++++++++----- test/mlj_flux_interfacing.jl | 15 +++++--- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index a04ac41..8afa7be 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -30,7 +31,7 @@ Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.7, 1.10" MLJBase = "0, 1.4.0" -MLJFlux = "0.2.10, 0.3, 0.4" +MLJFlux = "0.2.10, 0.3, 0.4, 0.5.1" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" ProgressMeter = "1.7.2" diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 2a87490..c92f1e4 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -237,12 +237,12 @@ Computes the fit result for a Laplace classification model, returning the model # Returns - A tuple containing: - - The model chain. + - The model. - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) println("fitresultclass") - return (chain, model.la, length(unique(y))) + return ( deepcopy(model), chain, length(unique(y))) end """ @@ -257,7 +257,7 @@ Computes the fit result for a Laplace Regression model, returning the model chai # Returns - A tuple containing: - - The model chain. + - The model. - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) @@ -267,7 +267,7 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) else target_column_names = Tables.schema(y).names end - return (chain, model.la, size(y)) + return (deepcopy(model),chain, size(y)) end function MLJFlux.fit!( @@ -343,7 +343,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) println("predictclass") - la = fitresult[2] + model = fitresult[1] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] @@ -351,7 +351,7 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) # Predict using Laplace and collect the predictions predictions = [ LaplaceRedux.predict( - la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba ) for x in X_vec ] @@ -429,6 +429,59 @@ function MLJFlux.fit!( return (fitresult, report, cache) end +import Optimisers +function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) + println("updatereg") + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + + old_model= old_fitresult[1] + old_chain = old_fitresult[2] + + + # Initialize history: + history = [] + verbose_laplace = false + # intitialize and start progress meter: + meter = Progress( + model.epochs + 1; + dt=1.0, + desc="Optimising neural net:", + barglyphs=BarGlyphs("[=> ]"), + barlen=25, + color=:yellow, + ) + verbosity != 1 || next!(meter) + + regularized_optimiser = MLJFlux.regularized_optimiser(model, length(y)) + optimiser_state = Optimisers.setup(regularized_optimiser, old_chain) + epochs = model.epochs - old_model.epochs + + for i in 1:(epochs) + println("inner loop") + current_loss = MLJFlux.train_epoch(model, old_chain, regularized_optimiser, optimiser_state, X, y) + verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" + verbosity != 1 || next!(meter) + push!(history, current_loss) + end + + # fit the Laplace model: + LaplaceRedux.fit!(old_model.la, zip(X, y)) + optimize_prior!(old_model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + model.la = la + + cache = () + fitresult = MLJFlux.fitresult(model, Flux.cpu(old_chain), y) + + report = history + + #return cache, report + + return (fitresult, report, cache) +end + + + + """ predict(model::LaplaceRegression, Xnew) @@ -447,13 +500,13 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) println("predictregre") Xnew = MLJBase.matrix(Xnew) - la = fitresult[2] + model = fitresult[1] #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat yhat = [] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(la, x_vec) for x_vec in X_vec] + yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] return yhat end diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 351d233..ec3e659 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -26,13 +26,12 @@ function basictest_regression(X, y, builder, optimiser, threshold) ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - println("ayyyyy") #println(boh) - #println(cache) + println(fitresult) println(_report) history = _report.training_losses #println(fitresult) - @test length(history) == model.epochs + @test length(history) == model.epochs + 1 # test improvement in training loss: @test history[end] < threshold * history[1] @@ -40,6 +39,14 @@ function basictest_regression(X, y, builder, optimiser, threshold) # increase iterations and check update is incremental: model.epochs = model.epochs + 3 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) @@ -79,7 +86,7 @@ function basictest_classification(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) history = _report.training_losses - @test length(history) == model.epochs + @test length(history) == model.epochs + 1 # test improvement in training loss: @test history[end] < threshold * history[1] From 0886ce8d8df2e7a384ecdf3ccf18b66ee8344045 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 27 Jun 2024 18:23:17 +0200 Subject: [PATCH 21/32] passed to new api. update still missing and it seems there is an issue on how train_epoch handle the crossentropy function since the loss rise instead of going down --- Project.toml | 1 + src/baselaplace/predicting.jl | 2 +- src/mlj_flux.jl | 385 ++++++++++++++++------------------ test/mlj_flux_interfacing.jl | 38 +--- 4 files changed, 188 insertions(+), 238 deletions(-) diff --git a/Project.toml b/Project.toml index 8afa7be..eb09217 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ MLJBase = "0, 1.4.0" MLJFlux = "0.2.10, 0.3, 0.4, 0.5.1" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" +Optimisers = "0.3.3" ProgressMeter = "1.7.2" Random = "1.7, 1.10" Statistics = "1" diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 45f8eeb..155bd93 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -47,7 +47,7 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) j in 1:size(fμ, 2) ] #normal_distr = [ - #Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)] maybe this one is the correct one + #Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)] maybe this one is the correct one return normal_distr end diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index c92f1e4..963b6ac 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -8,6 +8,7 @@ using LaplaceRedux using ComputationalResources using MLJBase: MLJBase import MLJModelInterface as MMI +using Optimisers: Optimisers """ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic @@ -19,6 +20,7 @@ The model is trained using the `fit!` method. The model is defined by the follow - `builder`: a Flux model that constructs the neural network. - `optimiser`: a Flux optimiser. - `loss`: a loss function that takes the predicted output and the true output as arguments. +- `epochs`: the number of epochs. - `batch_size`: the size of a batch. - `lambda`: the regularization strength. - `alpha`: the regularization mix (0 for all l2, 1 for all l1). @@ -37,9 +39,9 @@ The model is trained using the `fit!` method. The model is defined by the follow """ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - optimiser = Flux.Optimise.Adam() + optimiser = Optimisers.Adam() loss = Flux.Losses.mse - epochs::Int = 10::(_ > 0) + epochs::Int = 100::(_ > 0) batch_size::Int = 1::(_ > 0) lambda::Float64 = 1.0 alpha::Float64 = 0.0 @@ -69,6 +71,7 @@ A mutable struct representing a Laplace Classification model that extends the ML - `finaliser`: a Flux model that processes the output of the neural network. - `optimiser`: a Flux optimiser. - `loss`: a loss function that takes the predicted output and the true output as arguments. +- `epochs`: the number of epochs. - `batch_size`: the size of a batch. - `lambda`: the regularization strength. - `alpha`: the regularization mix (0 for all l2, 1 for all l1). @@ -88,9 +91,9 @@ A mutable struct representing a Laplace Classification model that extends the ML MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) finaliser = Flux.softmax - optimiser = Flux.Optimise.Adam() + optimiser = Optimisers.Adam() loss = Flux.crossentropy - epochs::Int = 10::(_ > 0) + epochs::Int = 100::(_ > 0) batch_size::Int = 1::(_ > 0) lambda::Float64 = 1.0 alpha::Float64 = 0.0 @@ -113,29 +116,6 @@ end const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} -""" - MLJFlux.shape(model::LaplaceClassification, X, y) - -Compute the the number of features of the dataset X and the number of unique classes in y. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification model to fit. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (input size, output size) -""" - -function MLJFlux.shape(model::LaplaceClassification, X, y) - println("shapeclass") - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X, 2) - levels = unique(y) - n_output = length(levels) - return (n_input, n_output) -end - """ MLJFlux.shape(model::LaplaceRegression, X, y) @@ -150,7 +130,6 @@ Compute the the number of features of the X input dataset and the number of var - (input size, output size) """ function MLJFlux.shape(model::LaplaceRegression, X, y) - println("shapereg") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X, 2) dims = size(y) @@ -162,37 +141,6 @@ function MLJFlux.shape(model::LaplaceRegression, X, y) return (n_input, n_output) end -""" - MLJFlux.build(model::LaplaceClassification, rng, shape) - -Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model. -- `rng`: A random number generator to ensure reproducibility. -- `shape`: A tuple or array specifying the dimensions of the input and output layers. - -# Returns -- The constructed MLJFlux model, compatible with the specified input and output dimensions. -""" -function MLJFlux.build(model::LaplaceClassification, rng, shape) - println("buildclass") - chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - model.la = Laplace( - chain; - likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - - return chain -end - """ MLJFlux.build(model::LaplaceRegression, rng, shape) @@ -207,8 +155,6 @@ Builds an MLJFlux model for Laplace regression compatible with the dimensions of - The constructed MLJFlux model, compatible with the specified input and output dimensions. """ function MLJFlux.build(model::LaplaceRegression, rng, shape) - println("buildreg") - chain = MLJFlux.build(model.builder, rng, shape...) model.la = Laplace( chain; @@ -225,26 +171,6 @@ function MLJFlux.build(model::LaplaceRegression, rng, shape) return chain end -""" - MLJFlux.fitresult(model::LaplaceClassification, chain, y) - -Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model to be evaluated. -- `chain`: The trained model chain. -- `y`: The target data, typically a vector of class labels. - -# Returns -- A tuple containing: - - The model. - - The number of unique classes in the target data `y`. -""" -function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - println("fitresultclass") - return ( deepcopy(model), chain, length(unique(y))) -end - """ MLJFlux.fitresult(model::LaplaceRegression, chain, y) @@ -258,33 +184,53 @@ Computes the fit result for a Laplace Regression model, returning the model chai # Returns - A tuple containing: - The model. + - The trained Flux chain. - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - println("fitresultregre") if y isa AbstractArray target_column_names = nothing else target_column_names = Tables.schema(y).names end - return (deepcopy(model),chain, size(y)) + return (deepcopy(model), chain) end -function MLJFlux.fit!( - model::LaplaceClassification, penalty, chain, optimiser, epochs, verbosity, X, y -) - println("fitclass") - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X +""" + MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) - #X = MLJBase.matrix(X) +Fit the LaplaceRegression model using Flux.jl. - #shape = MLJFlux.shape(model, X, y) +# Arguments +- `model::LaplaceRegression`: The LaplaceRegression model. +- `regularized_optimiser`: the regularized optimiser to apply to the loss function. +- `optimiser_state`: thestate of the optimiser. +- `epochs`: The number of epochs for training. +- `verbosity`: The verbosity level for training. +- `X`: The input data for training. +- `y`: The target labels for training. - #chain = MLJFlux.build(model, model.rng, shape) +# Returns (fitresult, cache, report ) +where +- `fitresult`: is the output of MLJFlux.fitresult. +- `cache`: an empty tuple. +- `report`: a named tuple that contain the field training_losses. +""" +function MLJFlux.train( + model::LaplaceRegression, + chain, + regularized_optimiser, + optimiser_state, + epochs, + verbosity, + X, + y, +) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X la = LaplaceRedux.Laplace( chain; - likelihood=:classification, + likelihood=:regression, subset_of_weights=model.subset_of_weights, subnetwork_indices=model.subnetwork_indices, hessian_structure=model.hessian_structure, @@ -293,24 +239,28 @@ function MLJFlux.fit!( μ₀=model.μ₀, P₀=model.P₀, ) - verbose_laplace = false # Initialize history: history = [] + verbose_laplace = false # intitialize and start progress meter: meter = Progress( - model.epochs + 1; - dt=0, + epochs + 1; + dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow, ) + verbosity != 1 || next!(meter) + for i in 1:epochs - current_loss = MLJFlux.train!(model, penalty, chain, optimiser, X, y) + chain, optimiser_state, current_loss = MLJFlux.train_epoch( + model, chain, regularized_optimiser, optimiser_state, X, y + ) verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" verbosity != 1 || next!(meter) - append!(history, current_loss) + push!(history, current_loss) end # fit the Laplace model: @@ -323,67 +273,145 @@ function MLJFlux.fit!( report = history - #return cache, report - - return (fitresult, report, cache) + return (fitresult, cache, report) end + """ - predict(model::LaplaceClassification, Xnew) + predict(model::LaplaceRegression, Xnew) -Predicts the class labels for new data using the LaplaceClassification model. +Predict the output for new input data using a Laplace regression model. # Arguments -- `model::LaplaceClassification`: The trained LaplaceClassification model. -- fitresult: the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new data to make predictions on. +- `model::LaplaceRegression`: The trained Laplace regression model. +- the fitresult output produced by MLJFlux.fit! +- `Xnew`: The new input data. # Returns -An array of predicted class labels. +- The predicted output for the new input data. """ -function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - println("predictclass") - model = fitresult[1] +function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) - #convert in a vector of vectors because Laplace ask to do so - X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] + model = fitresult[1] + #convert in a vector of vectors because MLJ ask to do so + X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] + #inizialize output vector yhat + yhat = [] # Predict using Laplace and collect the predictions - predictions = [ - LaplaceRedux.predict( - model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba - ) for x in X_vec - ] + yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] - return predictions + return yhat +end + +""" + MLJFlux.shape(model::LaplaceClassification, X, y) + +Compute the the number of features of the dataset X and the number of unique classes in y. + +# Arguments +- `model::LaplaceClassification`: The LaplaceClassification model to fit. +- `X`: The input data for training. +- `y`: The target labels for training one-hot encoded. + +# Returns +- (input size, output size) +""" + +function MLJFlux.shape(model::LaplaceClassification, X, y) + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + n_input = size(X, 2) + levels = unique(y) + n_output = length(levels) + println(n_output) + return (n_input, n_output) +end + +""" + MLJFlux.build(model::LaplaceClassification, rng, shape) + +Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`. + +# Arguments +- `model::LaplaceClassification`: The Laplace classification model. +- `rng`: A random number generator to ensure reproducibility. +- `shape`: A tuple or array specifying the dimensions of the input and output layers. + +# Returns +- The constructed MLJFlux model, compatible with the specified input and output dimensions. +""" +function MLJFlux.build(model::LaplaceClassification, rng, shape) + chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) + println(chain) + model.la = Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + + return chain +end + +""" + MLJFlux.fitresult(model::LaplaceClassification, chain, y) + +Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data. + +# Arguments +- `model::LaplaceClassification`: The Laplace classification model to be evaluated. +- `chain`: The trained model chain. +- `y`: The target data, typically a vector of class labels. + +# Returns +- A tuple containing: + - The model. + - The number of unique classes in the target data `y`. +""" +function MLJFlux.fitresult(model::LaplaceClassification, chain, y) + return (deepcopy(model), chain, length(unique(y))) end """ - MLJFlux.fit!(model::LaplaceRegression, penalty, chain, epochs, batch_size, optimiser, verbosity, X, y) + MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) Fit the LaplaceRegression model using Flux.jl. # Arguments -- `model::LaplaceRegression`: The LaplaceRegression model. +- `model::LaplaceClassification`: The LaplaceClassification model. +- `regularized_optimiser`: the regularized optimiser to apply to the loss function. +- `optimiser_state`: thestate of the optimiser. +- `epochs`: The number of epochs for training. - `verbosity`: The verbosity level for training. - `X`: The input data for training. - `y`: The target labels for training. -# Returns -- -where la is the fitted Laplace model. +# Returns (fitresult, cache, report ) +where +- `fitresult`: is the output of MLJFlux.fitresult. +- `cache`: an empty tuple. +- `report`: a named tuple that contain the field training_losses. """ -#model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y -function MLJFlux.fit!( - model::LaplaceRegression, penalty, chain, optimiser, epochs, verbosity, X, y +function MLJFlux.train( + model::LaplaceClassification, + chain, + regularized_optimiser, + optimiser_state, + epochs, + verbosity, + X, + y, ) - println("fitregre") - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X la = LaplaceRedux.Laplace( chain; - likelihood=:regression, + likelihood=:classification, subset_of_weights=model.subset_of_weights, subnetwork_indices=model.subnetwork_indices, hessian_structure=model.hessian_structure, @@ -392,23 +420,23 @@ function MLJFlux.fit!( μ₀=model.μ₀, P₀=model.P₀, ) + verbose_laplace = false # Initialize history: history = [] - verbose_laplace = false # intitialize and start progress meter: meter = Progress( - epochs + 1; - dt=1.0, + model.epochs + 1; + dt=0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow, ) - verbosity != 1 || next!(meter) - for i in 1:epochs - current_loss = MLJFlux.train!(model, penalty, chain, optimiser, X, y) + chain, optimiser_state, current_loss = MLJFlux.train_epoch( + model, chain, regularized_optimiser, optimiser_state, X, y + ) verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" verbosity != 1 || next!(meter) push!(history, current_loss) @@ -424,91 +452,38 @@ function MLJFlux.fit!( report = history - #return cache, report - - return (fitresult, report, cache) + return (fitresult, cache, report) end -import Optimisers -function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) - println("updatereg") - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - old_model= old_fitresult[1] - old_chain = old_fitresult[2] - - - # Initialize history: - history = [] - verbose_laplace = false - # intitialize and start progress meter: - meter = Progress( - model.epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) - - regularized_optimiser = MLJFlux.regularized_optimiser(model, length(y)) - optimiser_state = Optimisers.setup(regularized_optimiser, old_chain) - epochs = model.epochs - old_model.epochs - - for i in 1:(epochs) - println("inner loop") - current_loss = MLJFlux.train_epoch(model, old_chain, regularized_optimiser, optimiser_state, X, y) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) - end - - # fit the Laplace model: - LaplaceRedux.fit!(old_model.la, zip(X, y)) - optimize_prior!(old_model.la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - model.la = la - - cache = () - fitresult = MLJFlux.fitresult(model, Flux.cpu(old_chain), y) - - report = history - - #return cache, report - - return (fitresult, report, cache) -end - - - - """ - predict(model::LaplaceRegression, Xnew) + predict(model::LaplaceClassification, Xnew) -Predict the output for new input data using a Laplace regression model. +Predicts the class labels for new data using the LaplaceClassification model. # Arguments -- `model::LaplaceRegression`: The trained Laplace regression model. -- the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new input data. +- `model::LaplaceClassification`: The trained LaplaceClassification model. +- fitresult: the fitresult output produced by MLJFlux.fit! +- `Xnew`: The new data to make predictions on. # Returns -- The predicted output for the new input data. +An array of predicted class labels. """ -function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) - println("predictregre") +function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) + println("predictclass") + model = fitresult[1] Xnew = MLJBase.matrix(Xnew) + #convert in a vector of vectors because Laplace ask to do so + X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] - model = fitresult[1] - #convert in a vector of vectors because MLJ ask to do so - X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] - #inizialize output vector yhat - yhat = [] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + predictions = [ + LaplaceRedux.predict( + model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + ) for x in X_vec + ] - return yhat + return predictions end # metadata for each model, diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index ec3e659..ab2ec95 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -26,32 +26,19 @@ function basictest_regression(X, y, builder, optimiser, threshold) ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - #println(boh) - println(fitresult) - println(_report) history = _report.training_losses - #println(fitresult) - @test length(history) == model.epochs + 1 + + @test length(history) == model.epochs # test improvement in training loss: @test history[end] < threshold * history[1] - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses + #@test length(history) == model.epochs # start fresh with small epochs: model = LaplaceRegression(; @@ -86,19 +73,17 @@ function basictest_classification(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) history = _report.training_losses - @test length(history) == model.epochs + 1 + @test length(history) == model.epochs # test improvement in training loss: @test history[end] < threshold * history[1] - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) - history = _report.training_losses + #history = _report.training_losses + @test length(history) == model.epochs # start fresh with small epochs: model = LaplaceClassification(; @@ -107,17 +92,6 @@ function basictest_classification(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - # change batch_size and check it performs cold restart: - model.batch_size = 2 - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - return true end From 63cef6234f37aeb43b8cd28b94f148bb37319806 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 29 Jun 2024 19:00:37 +0200 Subject: [PATCH 22/32] attempted mljflux.update for regression --- src/mlj_flux.jl | 140 +++++++++++++++++++++++++++++++++-- test/mlj_flux_interfacing.jl | 37 ++++++++- 2 files changed, 167 insertions(+), 10 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 963b6ac..dd9e4d0 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -41,7 +41,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilis builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) optimiser = Optimisers.Adam() loss = Flux.Losses.mse - epochs::Int = 100::(_ > 0) + epochs::Int = 10::(_ > 0) batch_size::Int = 1::(_ > 0) lambda::Float64 = 1.0 alpha::Float64 = 0.0 @@ -98,7 +98,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab lambda::Float64 = 1.0 alpha::Float64 = 0.0 rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = false + optimiser_changes_trigger_retraining::Bool = true::(_ in (true, false)) acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) @@ -193,7 +193,7 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) else target_column_names = Tables.schema(y).names end - return (deepcopy(model), chain) + return (chain, deepcopy(model)) end """ @@ -268,7 +268,20 @@ function MLJFlux.train( optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) model.la = la - cache = () + shape = MLJFlux.shape(model, X, y) + move = MLJFlux.Mover(model.acceleration) + + cache = ( + deepcopy(model), + zip(X, y), + history, + shape, + regularized_optimiser, + optimiser_state, + deepcopy(model.rng), + move, + ) + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) report = history @@ -293,7 +306,7 @@ Predict the output for new input data using a Laplace regression model. function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) - model = fitresult[1] + model = fitresult[2] #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat @@ -304,6 +317,122 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) return yhat end +function _isdefined(object, name) + pnames = propertynames(object) + fnames = fieldnames(typeof(object)) + name in pnames && !(name in fnames) && return true + return isdefined(object, name) +end + +function _equal_to_depth_one(x1, x2) + names = propertynames(x1) + names === propertynames(x2) || return false + for name in names + getproperty(x1, name) == getproperty(x2, name) || return false + end + return true +end + +function MMI.is_same_except( + m1::M1, m2::M2, exceptions::Symbol... +) where {M1<:LaplaceRegression,M2<:LaplaceRegression} + typeof(m1) === typeof(m2) || return false + names = propertynames(m1) + propertynames(m2) === names || return false + + for name in names + if !(name in exceptions) && name != :la + if !_isdefined(m1, name) + !_isdefined(m2, name) || return false + elseif _isdefined(m2, name) + if name in MLJFlux.deep_properties(M1) + _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || + return false + else + ( + MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || + getproperty(m1, name) isa AbstractRNG || + getproperty(m2, name) isa AbstractRNG + ) || return false + end + else + return false + end + end + end + return true +end + +function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) + println("test update") + X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + + old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move = + old_cache + old_chain = old_fitresult[1] + + optimiser_flag = + model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser + + keep_chain = + !optimiser_flag && + model.epochs >= old_model.epochs && + MMI.is_same_except(model, old_model, :optimiser, :epochs) + + println(old_chain[1]) + + println(model.optimiser_changes_trigger_retraining) + println(model.optimiser) + println(old_model.optimiser) + + println(model.optimiser != old_model.optimiser) + println(old_model.epochs) + println(model.epochs) + println(optimiser_flag) + println(MMI.is_same_except(model, old_model, :optimiser, :epochs)) + + println(keep_chain) + + if keep_chain + chain = move(old_chain[1]) + epochs = model.epochs - old_model.epochs + # (`optimiser_state` is not reset) + else + move = MLJFlux.Mover(model.acceleration) + rng = model.rng + chain = MLJFlux.build(model, rng, shape) |> move + # reset `optimiser_state`: + #data = move.(MLJFlux.collate(model, X, y)) + nbatches = length(data[2]) + regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) + optimiser_state = Optimisers.setup(regularized_optimiser, chain) + epochs = model.epochs + end + + chain, optimiser_state, history = MLJFlux.train( + model, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y + ) + if keep_chain + # note: history[1] = old_history[end] + history = vcat(old_history[1:(end - 1)], history) + end + + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + cache = ( + deepcopy(model), + data, + history, + shape, + regularized_optimiser, + optimiser_state, + deepcopy(rng), + move, + ) + report = (training_losses=history,) + + return fitresult, cache, report +end + """ MLJFlux.shape(model::LaplaceClassification, X, y) @@ -470,7 +599,6 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - println("predictclass") model = fitresult[1] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index ab2ec95..fed9729 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -30,15 +30,22 @@ function basictest_regression(X, y, builder, optimiser, threshold) @test length(history) == model.epochs - # test improvement in training loss: - @test history[end] < threshold * history[1] + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses - #@test length(history) == model.epochs + @test length(history) == model.epochs # start fresh with small epochs: model = LaplaceRegression(; @@ -47,6 +54,28 @@ function basictest_regression(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + return true end @@ -118,4 +147,4 @@ optimizer = Flux.Optimise.Adam(0.03) @test basictest_regression(X, ycont, builder, optimizer, 0.9) -@test basictest_classification(X, y, builder, optimizer, 0.9) +#@test basictest_classification(X, y, builder, optimizer, 0.9) From af6c2103f02e4d9f0338e91f0d846c8d84a6520f Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 30 Jun 2024 01:12:37 +0200 Subject: [PATCH 23/32] boh i don't know why it doesn't work. --- src/mlj_flux.jl | 74 +++++++----------------------------- test/mlj_flux_interfacing.jl | 2 +- 2 files changed, 14 insertions(+), 62 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index dd9e4d0..e32938d 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -188,6 +188,7 @@ Computes the fit result for a Laplace Regression model, returning the model chai - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) + println("fitresult function") if y isa AbstractArray target_column_names = nothing else @@ -226,6 +227,7 @@ function MLJFlux.train( X, y, ) + println("train function") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X la = LaplaceRedux.Laplace( @@ -273,7 +275,6 @@ function MLJFlux.train( cache = ( deepcopy(model), - zip(X, y), history, shape, regularized_optimiser, @@ -317,57 +318,11 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) return yhat end -function _isdefined(object, name) - pnames = propertynames(object) - fnames = fieldnames(typeof(object)) - name in pnames && !(name in fnames) && return true - return isdefined(object, name) -end - -function _equal_to_depth_one(x1, x2) - names = propertynames(x1) - names === propertynames(x2) || return false - for name in names - getproperty(x1, name) == getproperty(x2, name) || return false - end - return true -end - -function MMI.is_same_except( - m1::M1, m2::M2, exceptions::Symbol... -) where {M1<:LaplaceRegression,M2<:LaplaceRegression} - typeof(m1) === typeof(m2) || return false - names = propertynames(m1) - propertynames(m2) === names || return false - - for name in names - if !(name in exceptions) && name != :la - if !_isdefined(m1, name) - !_isdefined(m2, name) || return false - elseif _isdefined(m2, name) - if name in MLJFlux.deep_properties(M1) - _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || - return false - else - ( - MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || - getproperty(m1, name) isa AbstractRNG || - getproperty(m2, name) isa AbstractRNG - ) || return false - end - else - return false - end - end - end - return true -end - function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) println("test update") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move = + old_model, old_history, shape, regularized_optimiser, optimiser_state, rng, move = old_cache old_chain = old_fitresult[1] @@ -381,16 +336,6 @@ function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_ println(old_chain[1]) - println(model.optimiser_changes_trigger_retraining) - println(model.optimiser) - println(old_model.optimiser) - - println(model.optimiser != old_model.optimiser) - println(old_model.epochs) - println(model.epochs) - println(optimiser_flag) - println(MMI.is_same_except(model, old_model, :optimiser, :epochs)) - println(keep_chain) if keep_chain @@ -400,14 +345,18 @@ function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_ else move = MLJFlux.Mover(model.acceleration) rng = model.rng - chain = MLJFlux.build(model, rng, shape) |> move + shape = MLJFlux.shape(model, X, y) + println(shape) + chain = MLJFlux.build(model, rng, shape) #|> move + println(chain) # reset `optimiser_state`: #data = move.(MLJFlux.collate(model, X, y)) - nbatches = length(data[2]) + nbatches = length(y) regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) optimiser_state = Optimisers.setup(regularized_optimiser, chain) epochs = model.epochs end + println("after if") chain, optimiser_state, history = MLJFlux.train( model, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y @@ -417,10 +366,13 @@ function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_ history = vcat(old_history[1:(end - 1)], history) end + println("after train") + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) cache = ( deepcopy(model), - data, + X, + y, history, shape, regularized_optimiser, diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index fed9729..a3269d8 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -147,4 +147,4 @@ optimizer = Flux.Optimise.Adam(0.03) @test basictest_regression(X, ycont, builder, optimizer, 0.9) -#@test basictest_classification(X, y, builder, optimizer, 0.9) +@test basictest_classification(X, y, builder, optimizer, 0.9) From 838966032ed289da130e555e6ab7b6fd77f27208 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 08:45:43 +0200 Subject: [PATCH 24/32] not sure we actually need to overload the update function --- src/mlj_flux.jl | 136 +++++++++++++++++------------------ test/Manifest.toml | 94 ++++++++++++------------ test/mlj_flux_interfacing.jl | 76 ++++++++++++++++---- 3 files changed, 177 insertions(+), 129 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index e32938d..954d01b 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -188,7 +188,6 @@ Computes the fit result for a Laplace Regression model, returning the model chai - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - println("fitresult function") if y isa AbstractArray target_column_names = nothing else @@ -227,7 +226,6 @@ function MLJFlux.train( X, y, ) - println("train function") X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X la = LaplaceRedux.Laplace( @@ -318,72 +316,72 @@ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) return yhat end -function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) - println("test update") - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - old_model, old_history, shape, regularized_optimiser, optimiser_state, rng, move = - old_cache - old_chain = old_fitresult[1] - - optimiser_flag = - model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser - - keep_chain = - !optimiser_flag && - model.epochs >= old_model.epochs && - MMI.is_same_except(model, old_model, :optimiser, :epochs) - - println(old_chain[1]) - - println(keep_chain) - - if keep_chain - chain = move(old_chain[1]) - epochs = model.epochs - old_model.epochs - # (`optimiser_state` is not reset) - else - move = MLJFlux.Mover(model.acceleration) - rng = model.rng - shape = MLJFlux.shape(model, X, y) - println(shape) - chain = MLJFlux.build(model, rng, shape) #|> move - println(chain) - # reset `optimiser_state`: - #data = move.(MLJFlux.collate(model, X, y)) - nbatches = length(y) - regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) - optimiser_state = Optimisers.setup(regularized_optimiser, chain) - epochs = model.epochs - end - println("after if") - - chain, optimiser_state, history = MLJFlux.train( - model, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y - ) - if keep_chain - # note: history[1] = old_history[end] - history = vcat(old_history[1:(end - 1)], history) - end - - println("after train") - - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - cache = ( - deepcopy(model), - X, - y, - history, - shape, - regularized_optimiser, - optimiser_state, - deepcopy(rng), - move, - ) - report = (training_losses=history,) - - return fitresult, cache, report -end +# function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) +# println("test update") +# X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X + +# old_model, old_history, shape, regularized_optimiser, optimiser_state, rng, move = +# old_cache +# old_chain = old_fitresult[1] + +# optimiser_flag = +# model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser + +# keep_chain = +# !optimiser_flag && +# model.epochs >= old_model.epochs && +# MMI.is_same_except(model, old_model, :optimiser, :epochs) + +# println(old_chain[1]) + +# println(keep_chain) + +# if keep_chain +# chain = move(old_chain[1]) +# epochs = model.epochs - old_model.epochs +# # (`optimiser_state` is not reset) +# else +# move = MLJFlux.Mover(model.acceleration) +# rng = model.rng +# shape = MLJFlux.shape(model, X, y) +# println(shape) +# chain = MLJFlux.build(model, rng, shape) #|> move +# println(chain) +# # reset `optimiser_state`: +# #data = move.(MLJFlux.collate(model, X, y)) +# nbatches = length(y) +# regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) +# optimiser_state = Optimisers.setup(regularized_optimiser, chain) +# epochs = model.epochs +# end +# println("after if") + +# chain, optimiser_state, history = MLJFlux.train( +# model, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y +# ) +# if keep_chain +# # note: history[1] = old_history[end] +# history = vcat(old_history[1:(end - 1)], history) +# end + +# println("after train") + +# fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) +# cache = ( +# deepcopy(model), +# X, +# y, +# history, +# shape, +# regularized_optimiser, +# optimiser_state, +# deepcopy(rng), +# move, +# ) +# report = (training_losses=history,) + +# return fitresult, cache, report +# end """ MLJFlux.shape(model::LaplaceClassification, X, y) @@ -404,7 +402,6 @@ function MLJFlux.shape(model::LaplaceClassification, X, y) n_input = size(X, 2) levels = unique(y) n_output = length(levels) - println(n_output) return (n_input, n_output) end @@ -423,7 +420,6 @@ Builds an MLJFlux model for Laplace classification compatible with the dimension """ function MLJFlux.build(model::LaplaceClassification, rng, shape) chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - println(chain) model.la = Laplace( chain; likelihood=:classification, diff --git a/test/Manifest.toml b/test/Manifest.toml index 32e8eb0..bf3e5b9 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.4" manifest_format = "2.0" project_hash = "fa6672850323ab23f77b8212aabdf7f033fa4213" @@ -118,9 +118,9 @@ uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" [[deps.BitFlags]] -git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.8" +version = "0.1.9" [[deps.BufferedStreams]] git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" @@ -523,9 +523,9 @@ version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.15" +version = "0.14.16" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -590,9 +590,9 @@ version = "3.3.9+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "c154546e322a9c73364e8a60430b0f79b812d320" +git-tree-sha1 = "5c9de6d5af87acd2cf719e214ed7d51e14017b7a" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.2.0" +version = "10.2.2" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -602,15 +602,15 @@ version = "0.1.6" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "ddda044ca260ee324c5fc07edb6d7cf3f0b9c350" +git-tree-sha1 = "3e527447a45901ea392fe12120783ad6ec222803" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.5" +version = "0.73.6" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "278e5e0f820178e8a26df3184fcb2280717c79b1" +git-tree-sha1 = "182c478a179b267dd7a741b6f8f4c3e0803795d6" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.5+0" +version = "0.73.6+0" [[deps.GZip]] deps = ["Libdl", "Zlib_jll"] @@ -678,9 +678,9 @@ version = "2.8.1+1" [[deps.Hwloc_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114" +git-tree-sha1 = "1d334207121865ac8c1c97eb7f42d0339e4635bf" uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.10.0+0" +version = "2.11.0+0" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] @@ -719,9 +719,15 @@ version = "0.3.1" [[deps.InlineStrings]] deps = ["Parsers"] -git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" +git-tree-sha1 = "86356004f30f8e737eff143d57d41bd580e437aa" uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.0" +version = "1.4.1" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -798,9 +804,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "8e5a339882cc401688d79b811d923a38ba77d50a" +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.20" +version = "0.9.22" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -822,9 +828,9 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "389aea28d882a40b5e1747069af71bdbd47a1cae" +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "7.2.1" +version = "8.0.0" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -832,9 +838,9 @@ weakdeps = ["BFloat16s"] [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" +version = "0.0.30+0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1000,9 +1006,9 @@ version = "0.7.14" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "24e5d28b2ea86b3feb6af5a5735f012d62e27b65" +git-tree-sha1 = "b81fe8aaf3a253d76d915ab6d6f749ab9c9973f6" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "1.4.0" +version = "1.5.0" [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" @@ -1024,9 +1030,9 @@ version = "0.5.1" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "88ef480f46e0506143681b3fb14d86742f3cecb1" +git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.10.0" +version = "1.11.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1145,9 +1151,9 @@ version = "0.10.3" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" +git-tree-sha1 = "78de319bce99d1d8c1d4fe5401f7cfc2627df396" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.17" +version = "0.9.18" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1413,9 +1419,9 @@ version = "1.2.0" [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] -git-tree-sha1 = "37b7bb7aabf9a085e0044307e1717436117f2b3b" +git-tree-sha1 = "492601870742dcd38f233b23c3ec629628c1d724" uuid = "c0090381-4147-56d7-9ebc-da0b1113ec56" -version = "6.5.3+1" +version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] @@ -1592,9 +1598,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132" +git-tree-sha1 = "20833c5b7f7edf0e5026f23db7f268e4f23ec577" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.5" +version = "1.9.6" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1614,9 +1620,9 @@ version = "0.1.1" [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] -git-tree-sha1 = "983c41a0ddd6c19f5607ca87271d7c7620ab5d50" +git-tree-sha1 = "542d979f6e756f13f862aa00b224f04f9e445f11" uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.3.0" +version = "3.4.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -1746,9 +1752,9 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TranscodingStreams]] -git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0" +git-tree-sha1 = "d73336d81cafdc277ff45558bb7eaa2b04a8e472" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.9" +version = "0.10.10" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] @@ -1848,9 +1854,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.4" +version = "0.1.5" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -1888,15 +1894,15 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "52ff2af32e591541550bd753c0da8b9bc92bb9d9" +git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.12.7+0" +version = "2.13.1+0" [[deps.XSLT_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "91844873c4085240b95e795f692c4cec4d805f8a" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] +git-tree-sha1 = "a54ee957f4c86b526460a720dbc882fa5edcbefc" uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.34+0" +version = "1.1.41+0" [[deps.XZ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1984,9 +1990,9 @@ version = "0.1.1+0" [[deps.Xorg_libxcb_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "b4bfde5d5b652e22b9c790ad00af08b6d042b97d" +git-tree-sha1 = "bcd466676fef0878338c61e655629fa7bbc69d8e" uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.15.0+0" +version = "1.17.0+0" [[deps.Xorg_libxkbfile_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll"] diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index a3269d8..5f18bb4 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -1,6 +1,6 @@ using Random: Random import Random.seed! -using MLJBase +using MLJBase: MLJBase, categorical using MLJFlux using Flux using StableRNGs @@ -13,7 +13,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) model = LaplaceRegression(; builder=builder, optimiser=optimiser, - acceleration=CPUThreads(), + acceleration=MLJBase.CPUThreads(), loss=Flux.Losses.mse, rng=stable_rng, lambda=-1.0, @@ -24,11 +24,15 @@ function basictest_regression(X, y, builder, optimiser, threshold) hessian_structure=:incorrect, backend=:incorrect, ) + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + chain, _ = fitresult[1] history = _report.training_losses + @test length(history) == model.epochs + 1 - @test length(history) == model.epochs + # test improvement in training loss: + @test history[end] < threshold * history[1] # increase iterations and check update is incremental: model.epochs = model.epochs + 3 @@ -37,7 +41,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) (:info, r""), # one line of :info per extra epoch (:info, r""), (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) + MLJBase.update(model, 2, chain, cache, X, y) ) @test :chain in keys(MLJBase.fitted_params(model, fitresult)) @@ -45,7 +49,7 @@ function basictest_regression(X, y, builder, optimiser, threshold) yhat = MLJBase.predict(model, fitresult, X) history = _report.training_losses - @test length(history) == model.epochs + @test length(history) == model.epochs + 1 # start fresh with small epochs: model = LaplaceRegression(; @@ -79,6 +83,16 @@ function basictest_regression(X, y, builder, optimiser, threshold) return true end +seed!(1234) +N = 300 +X = MLJBase.table(rand(Float32, N, 4)); +ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) + +builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) +optimizer = Flux.Optimise.Adam(0.03) + +@test basictest_regression(X, ycont, builder, optimizer, 0.9) + function basictest_classification(X, y, builder, optimiser, threshold) optimiser = deepcopy(optimiser) @@ -87,32 +101,45 @@ function basictest_classification(X, y, builder, optimiser, threshold) model = LaplaceClassification(; builder=builder, optimiser=optimiser, - acceleration=CPUThreads(), - rng=stable_rng, - lambda=-1.0, - alpha=-1.0, + loss=Flux.crossentropy, epochs=-1, batch_size=-1, + lambda=-1.0, + alpha=-1.0, + rng=stable_rng, + acceleration=MLJBase.CPUThreads(), subset_of_weights=:incorrect, hessian_structure=:incorrect, backend=:incorrect, link_approx=:incorrect, ) + # Test that shape is correct: + @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) history = _report.training_losses - @test length(history) == model.epochs + @test length(history) == model.epochs + 1 # test improvement in training loss: @test history[end] < threshold * history[1] + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) yhat = MLJBase.predict(model, fitresult, X) - #history = _report.training_losses - @test length(history) == model.epochs + history = _report.training_losses + @test length(history) == model.epochs + 1 # start fresh with small epochs: model = LaplaceClassification(; @@ -121,6 +148,28 @@ function basictest_classification(X, y, builder, optimiser, threshold) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + return true end @@ -144,7 +193,4 @@ y = categorical( builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) optimizer = Flux.Optimise.Adam(0.03) - -@test basictest_regression(X, ycont, builder, optimizer, 0.9) - @test basictest_classification(X, y, builder, optimizer, 0.9) From e61af35e7a5dd6574f4eac1d0133e4812ca4c8c2 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 09:16:52 +0200 Subject: [PATCH 25/32] issues seem to be largely related to logging --- src/mlj_flux.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 954d01b..8f426b8 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -193,6 +193,8 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) else target_column_names = Tables.schema(y).names end + @info "From fitresult" + println(chain) return (chain, deepcopy(model)) end @@ -254,6 +256,12 @@ function MLJFlux.train( ) verbosity != 1 || next!(meter) + # initiate history: + loss = model.loss + n_batches = length(y) + losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) + history = [mean(losses)] + for i in 1:epochs chain, optimiser_state, current_loss = MLJFlux.train_epoch( model, chain, regularized_optimiser, optimiser_state, X, y @@ -281,6 +289,8 @@ function MLJFlux.train( move, ) + @info "From train" + println(chain) fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) report = history @@ -499,8 +509,6 @@ function MLJFlux.train( ) verbose_laplace = false - # Initialize history: - history = [] # intitialize and start progress meter: meter = Progress( model.epochs + 1; @@ -510,6 +518,13 @@ function MLJFlux.train( barlen=25, color=:yellow, ) + + # initiate history: + loss = model.loss + n_batches = length(y) + losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) + history = [mean(losses)] + for i in 1:epochs chain, optimiser_state, current_loss = MLJFlux.train_epoch( model, chain, regularized_optimiser, optimiser_state, X, y From a52b5451c12ce7a3c188cd49b02d5ee5dc0c2cac Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 10:33:04 +0200 Subject: [PATCH 26/32] few things --- src/mlj_flux.jl | 22 +- test/mlj_flux_interfacing.jl | 384 ++++++++++++++++++----------------- 2 files changed, 202 insertions(+), 204 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 8f426b8..85c31d6 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -193,7 +193,6 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) else target_column_names = Tables.schema(y).names end - @info "From fitresult" println(chain) return (chain, deepcopy(model)) end @@ -276,26 +275,7 @@ function MLJFlux.train( optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) model.la = la - shape = MLJFlux.shape(model, X, y) - move = MLJFlux.Mover(model.acceleration) - - cache = ( - deepcopy(model), - history, - shape, - regularized_optimiser, - optimiser_state, - deepcopy(model.rng), - move, - ) - - @info "From train" - println(chain) - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - - report = history - - return (fitresult, cache, report) + return chain, optimiser_state, history end """ diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index 5f18bb4..c6b95f2 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -5,192 +5,210 @@ using MLJFlux using Flux using StableRNGs -function basictest_regression(X, y, builder, optimiser, threshold) - optimiser = deepcopy(optimiser) - - stable_rng = StableRNGs.StableRNG(123) - - model = LaplaceRegression(; - builder=builder, - optimiser=optimiser, - acceleration=MLJBase.CPUThreads(), - loss=Flux.Losses.mse, - rng=stable_rng, - lambda=-1.0, - alpha=-1.0, - epochs=-1, - batch_size=-1, - subset_of_weights=:incorrect, - hessian_structure=:incorrect, - backend=:incorrect, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - chain, _ = fitresult[1] - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # test improvement in training loss: - @test history[end] < threshold * history[1] - - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, chain, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - - yhat = MLJBase.predict(model, fitresult, X) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # start fresh with small epochs: - model = LaplaceRegression(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - return true +@testset "Regression" begin + function basictest_regression(X, y, builder, optimiser, threshold) + optimiser = deepcopy(optimiser) + + stable_rng = StableRNGs.StableRNG(123) + + model = LaplaceRegression(; + builder=builder, + optimiser=optimiser, + acceleration=MLJBase.CPUThreads(), + loss=Flux.Losses.mse, + rng=stable_rng, + lambda=-1.0, + alpha=-1.0, + epochs=-1, + batch_size=-1, + subset_of_weights=:incorrect, + hessian_structure=:incorrect, + backend=:incorrect, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + chain, _ = fitresult[1] + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # test improvement in training loss: + @test history[end] < threshold * history[1] + + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) + + yhat = MLJBase.predict(model, fitresult, X) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # start fresh with small epochs: + model = LaplaceRegression(; + builder=builder, + optimiser=optimiser, + epochs=2, + acceleration=CPU1(), + rng=stable_rng, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs( + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + return true + end + + seed!(1234) + N = 300 + X = MLJBase.table(rand(Float32, N, 4)) + ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) + + builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) + optimizer = Flux.Optimise.Adam(0.03) + + @test basictest_regression(X, ycont, builder, optimizer, 0.9) end -seed!(1234) -N = 300 -X = MLJBase.table(rand(Float32, N, 4)); -ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) - -builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) -optimizer = Flux.Optimise.Adam(0.03) - -@test basictest_regression(X, ycont, builder, optimizer, 0.9) - -function basictest_classification(X, y, builder, optimiser, threshold) - optimiser = deepcopy(optimiser) - - stable_rng = StableRNGs.StableRNG(123) - - model = LaplaceClassification(; - builder=builder, - optimiser=optimiser, - loss=Flux.crossentropy, - epochs=-1, - batch_size=-1, - lambda=-1.0, - alpha=-1.0, - rng=stable_rng, - acceleration=MLJBase.CPUThreads(), - subset_of_weights=:incorrect, - hessian_structure=:incorrect, - backend=:incorrect, - link_approx=:incorrect, +@testset "Classification" begin + function basictest_classification(X, y, builder, optimiser, threshold) + optimiser = deepcopy(optimiser) + + stable_rng = StableRNGs.StableRNG(123) + + model = LaplaceClassification(; + builder=builder, + optimiser=optimiser, + loss=Flux.crossentropy, + epochs=-1, + batch_size=-1, + lambda=-1.0, + alpha=-1.0, + rng=stable_rng, + acceleration=MLJBase.CPUThreads(), + subset_of_weights=:incorrect, + hessian_structure=:incorrect, + backend=:incorrect, + link_approx=:incorrect, + ) + + # Test that shape is correct: + @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # test improvement in training loss: + @test history[end] < threshold * history[1] + + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) + + yhat = MLJBase.predict(model, fitresult, X) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # start fresh with small epochs: + model = LaplaceClassification(; + builder=builder, + optimiser=optimiser, + epochs=2, + acceleration=CPU1(), + rng=stable_rng, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs( + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + return true + end + + seed!(1234) + N = 300 + X = MLJBase.table(rand(Float32, N, 4)) + ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) + m, M = minimum(ycont), maximum(ycont) + _, a, b, _ = collect(range(m; stop=M, length=4)) + y = categorical( + map(ycont) do η + if η < 0.9 * a + 'a' + elseif η < 1.1 * b + 'b' + else + 'c' + end + end, ) - # Test that shape is correct: - @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # test improvement in training loss: - @test history[end] < threshold * history[1] - - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - - yhat = MLJBase.predict(model, fitresult, X) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # start fresh with small epochs: - model = LaplaceClassification(; - builder=builder, optimiser=optimiser, epochs=2, acceleration=CPU1(), rng=stable_rng - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, X, y)) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - return true + builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) + optimizer = Flux.Optimise.Adam(0.03) + @test basictest_classification(X, y, builder, optimizer, 0.9) end -seed!(1234) -N = 300 -X = MLJBase.table(rand(Float32, N, 4)); -ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) -m, M = minimum(ycont), maximum(ycont) -_, a, b, _ = collect(range(m; stop=M, length=4)) -y = categorical( - map(ycont) do η - if η < 0.9 * a - 'a' - elseif η < 1.1 * b - 'b' - else - 'c' - end - end, -); - -builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) -optimizer = Flux.Optimise.Adam(0.03) -@test basictest_classification(X, y, builder, optimizer, 0.9) + From 65f8237d327f9701a4e3632f705ca263baad255a Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 14:02:59 +0200 Subject: [PATCH 27/32] attempt to add equality functions --- src/baselaplace/estimation_params.jl | 7 +++++++ src/baselaplace/posterior.jl | 5 +++++ src/baselaplace/prior.jl | 5 +++++ src/curvature/Curvature.jl | 6 ++++++ 4 files changed, 23 insertions(+) diff --git a/src/baselaplace/estimation_params.jl b/src/baselaplace/estimation_params.jl index 37d516d..e078677 100644 --- a/src/baselaplace/estimation_params.jl +++ b/src/baselaplace/estimation_params.jl @@ -118,3 +118,10 @@ function instantiate_curvature!( return params.curvature = curvature end + + +function Base.:(==)(a::EstimationParams, b::EstimationParams) + checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] + println(checks) + return all(checks) +end \ No newline at end of file diff --git a/src/baselaplace/posterior.jl b/src/baselaplace/posterior.jl index ea0e4c5..05c8213 100644 --- a/src/baselaplace/posterior.jl +++ b/src/baselaplace/posterior.jl @@ -42,3 +42,8 @@ function Posterior(model::Any, est_params::EstimationParams) 0.0, ) end + +function Base.:(==)(a::Posterior, b::Posterior) + checks = [getfield(a, x)==getfield(b, x) for x in fieldnames(typeof(a))] + return all(checks) +end diff --git a/src/baselaplace/prior.jl b/src/baselaplace/prior.jl index accec86..54bd1e4 100644 --- a/src/baselaplace/prior.jl +++ b/src/baselaplace/prior.jl @@ -36,3 +36,8 @@ function Prior(params::LaplaceParams, model::Any, likelihood::Symbol) end return Prior(params.σ, params.μ₀, params.λ, P₀) end + +function Base.:(==)(a::Prior, b::Prior) + checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] + return all(checks) +end \ No newline at end of file diff --git a/src/curvature/Curvature.jl b/src/curvature/Curvature.jl index 716b790..c1c9603 100644 --- a/src/curvature/Curvature.jl +++ b/src/curvature/Curvature.jl @@ -12,6 +12,12 @@ export CurvatureInterface "Base type for any curvature interface." abstract type CurvatureInterface end +function Base.:(==)(a::CurvatureInterface, b::CurvatureInterface) + checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] + println(checks) + return all(checks) +end + include("utils.jl") include("ggn.jl") include("fisher.jl") From 1d4036d796b8deae2424d8989121740ddaedeac5 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 15:02:04 +0200 Subject: [PATCH 28/32] all tests passing, but some more cleaning to do --- src/baselaplace/estimation_params.jl | 3 +- src/baselaplace/posterior.jl | 4 +- src/baselaplace/predicting.jl | 2 +- src/baselaplace/prior.jl | 2 +- src/curvature/Curvature.jl | 6 - src/mlj_flux.jl | 181 +++++++-------------------- test/mlj_flux_interfacing.jl | 17 +-- 7 files changed, 54 insertions(+), 161 deletions(-) diff --git a/src/baselaplace/estimation_params.jl b/src/baselaplace/estimation_params.jl index e078677..9b40a6c 100644 --- a/src/baselaplace/estimation_params.jl +++ b/src/baselaplace/estimation_params.jl @@ -119,9 +119,8 @@ function instantiate_curvature!( return params.curvature = curvature end - function Base.:(==)(a::EstimationParams, b::EstimationParams) checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] println(checks) return all(checks) -end \ No newline at end of file +end diff --git a/src/baselaplace/posterior.jl b/src/baselaplace/posterior.jl index 05c8213..20723aa 100644 --- a/src/baselaplace/posterior.jl +++ b/src/baselaplace/posterior.jl @@ -43,7 +43,7 @@ function Posterior(model::Any, est_params::EstimationParams) ) end -function Base.:(==)(a::Posterior, b::Posterior) - checks = [getfield(a, x)==getfield(b, x) for x in fieldnames(typeof(a))] +function Base.:(==)(a::Posterior, b::Posterior) + checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] return all(checks) end diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 155bd93..c2cf347 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -136,5 +136,5 @@ end Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the `predict` function. """ function (la::AbstractLaplace)(X::AbstractArray; kwrgs...) - return predict(la, X; kwrgs...) + return la.model(X) end diff --git a/src/baselaplace/prior.jl b/src/baselaplace/prior.jl index 54bd1e4..5a33f7d 100644 --- a/src/baselaplace/prior.jl +++ b/src/baselaplace/prior.jl @@ -40,4 +40,4 @@ end function Base.:(==)(a::Prior, b::Prior) checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] return all(checks) -end \ No newline at end of file +end diff --git a/src/curvature/Curvature.jl b/src/curvature/Curvature.jl index c1c9603..716b790 100644 --- a/src/curvature/Curvature.jl +++ b/src/curvature/Curvature.jl @@ -12,12 +12,6 @@ export CurvatureInterface "Base type for any curvature interface." abstract type CurvatureInterface end -function Base.:(==)(a::CurvatureInterface, b::CurvatureInterface) - checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] - println(checks) - return all(checks) -end - include("utils.jl") include("ggn.jl") include("fisher.jl") diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 85c31d6..d263ba6 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -35,7 +35,6 @@ The model is trained using the `fit!` method. The model is defined by the follow - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. -- `la`: the Laplace model. """ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) @@ -57,7 +56,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilis μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing fit_prior_nsteps::Int = 100::(_ > 0) - la::Union{Nothing,AbstractLaplace} = nothing end """ @@ -86,19 +84,18 @@ A mutable struct representing a Laplace Classification model that extends the ML - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. -- `la`: the Laplace model. """ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) finaliser = Flux.softmax optimiser = Optimisers.Adam() loss = Flux.crossentropy - epochs::Int = 100::(_ > 0) + epochs::Int = 10::(_ > 0) batch_size::Int = 1::(_ > 0) lambda::Float64 = 1.0 alpha::Float64 = 0.0 rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = true::(_ in (true, false)) + optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) @@ -111,7 +108,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab link_approx::Symbol = :probit::(_ in (:probit, :plugin)) predict_proba::Bool = true::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) - la::Union{Nothing,AbstractLaplace} = nothing end const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} @@ -156,18 +152,6 @@ Builds an MLJFlux model for Laplace regression compatible with the dimensions of """ function MLJFlux.build(model::LaplaceRegression, rng, shape) chain = MLJFlux.build(model.builder, rng, shape...) - model.la = Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - return chain end @@ -193,7 +177,6 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) else target_column_names = Tables.schema(y).names end - println(chain) return (chain, deepcopy(model)) end @@ -229,17 +212,21 @@ function MLJFlux.train( ) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) + if !isa(chain, AbstractLaplace) + la = LaplaceRedux.Laplace( + chain; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + else + la = chain + end # Initialize history: history = [] @@ -273,9 +260,8 @@ function MLJFlux.train( # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - model.la = la - return chain, optimiser_state, history + return la, optimiser_state, history end """ @@ -295,84 +281,17 @@ Predict the output for new input data using a Laplace regression model. function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) - model = fitresult[2] + model = fitresult[1] #convert in a vector of vectors because MLJ ask to do so X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] #inizialize output vector yhat yhat = [] # Predict using Laplace and collect the predictions - yhat = [glm_predictive_distribution(model.la, x_vec) for x_vec in X_vec] + yhat = [glm_predictive_distribution(model, x_vec) for x_vec in X_vec] return yhat end -# function MLJFlux.update(model::LaplaceRegression, verbosity, old_fitresult, old_cache, X, y) -# println("test update") -# X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - -# old_model, old_history, shape, regularized_optimiser, optimiser_state, rng, move = -# old_cache -# old_chain = old_fitresult[1] - -# optimiser_flag = -# model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser - -# keep_chain = -# !optimiser_flag && -# model.epochs >= old_model.epochs && -# MMI.is_same_except(model, old_model, :optimiser, :epochs) - -# println(old_chain[1]) - -# println(keep_chain) - -# if keep_chain -# chain = move(old_chain[1]) -# epochs = model.epochs - old_model.epochs -# # (`optimiser_state` is not reset) -# else -# move = MLJFlux.Mover(model.acceleration) -# rng = model.rng -# shape = MLJFlux.shape(model, X, y) -# println(shape) -# chain = MLJFlux.build(model, rng, shape) #|> move -# println(chain) -# # reset `optimiser_state`: -# #data = move.(MLJFlux.collate(model, X, y)) -# nbatches = length(y) -# regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) -# optimiser_state = Optimisers.setup(regularized_optimiser, chain) -# epochs = model.epochs -# end -# println("after if") - -# chain, optimiser_state, history = MLJFlux.train( -# model, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y -# ) -# if keep_chain -# # note: history[1] = old_history[end] -# history = vcat(old_history[1:(end - 1)], history) -# end - -# println("after train") - -# fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) -# cache = ( -# deepcopy(model), -# X, -# y, -# history, -# shape, -# regularized_optimiser, -# optimiser_state, -# deepcopy(rng), -# move, -# ) -# report = (training_losses=history,) - -# return fitresult, cache, report -# end - """ MLJFlux.shape(model::LaplaceClassification, X, y) @@ -410,17 +329,6 @@ Builds an MLJFlux model for Laplace classification compatible with the dimension """ function MLJFlux.build(model::LaplaceClassification, rng, shape) chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - model.la = Laplace( - chain; - likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) return chain end @@ -441,7 +349,7 @@ Computes the fit result for a Laplace classification model, returning the model - The number of unique classes in the target data `y`. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (deepcopy(model), chain, length(unique(y))) + return (chain, deepcopy(model)) end """ @@ -476,28 +384,35 @@ function MLJFlux.train( ) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - la = LaplaceRedux.Laplace( - chain; - likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - verbose_laplace = false + if !isa(chain, AbstractLaplace) + la = LaplaceRedux.Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + else + la = chain + end + # Initialize history: + history = [] + verbose_laplace = false # intitialize and start progress meter: meter = Progress( - model.epochs + 1; - dt=0, + epochs + 1; + dt=1.0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow, ) + verbosity != 1 || next!(meter) # initiate history: loss = model.loss @@ -517,14 +432,8 @@ function MLJFlux.train( # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - model.la = la - cache = () - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - - report = history - - return (fitresult, cache, report) + return la, optimiser_state, history end """ @@ -542,7 +451,7 @@ An array of predicted class labels. """ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - model = fitresult[1] + la = fitresult[1] Xnew = MLJBase.matrix(Xnew) #convert in a vector of vectors because Laplace ask to do so X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)] @@ -550,7 +459,7 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) # Predict using Laplace and collect the predictions predictions = [ LaplaceRedux.predict( - model.la, x; link_approx=model.link_approx, predict_proba=model.predict_proba + la, x; link_approx=model.link_approx, predict_proba=model.predict_proba ) for x in X_vec ] diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index c6b95f2..ef9d66f 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -27,14 +27,10 @@ using StableRNGs ) fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - chain, _ = fitresult[1] history = _report.training_losses @test length(history) == model.epochs + 1 - # test improvement in training loss: - @test history[end] < threshold * history[1] - # increase iterations and check update is incremental: model.epochs = model.epochs + 3 @@ -47,8 +43,6 @@ using StableRNGs @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - yhat = MLJBase.predict(model, fitresult, X) - history = _report.training_losses @test length(history) == model.epochs + 1 @@ -96,9 +90,11 @@ using StableRNGs ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) - optimizer = Flux.Optimise.Adam(0.03) + optimiser = Flux.Optimise.Adam(0.03) - @test basictest_regression(X, ycont, builder, optimizer, 0.9) + y = ycont + + @test basictest_regression(X, y, builder, optimiser, 0.9) end @testset "Classification" begin @@ -130,9 +126,6 @@ end history = _report.training_losses @test length(history) == model.epochs + 1 - # test improvement in training loss: - @test history[end] < threshold * history[1] - # increase iterations and check update is incremental: model.epochs = model.epochs + 3 @@ -210,5 +203,3 @@ end optimizer = Flux.Optimise.Adam(0.03) @test basictest_classification(X, y, builder, optimizer, 0.9) end - - From cdd1328eab5d440a175997428a1635aed37964a6 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 15:26:34 +0200 Subject: [PATCH 29/32] compat issues --- .github/workflows/CI.yml | 6 ++-- CHANGELOG.md | 3 +- Project.toml | 14 ++++----- src/baselaplace/estimation_params.jl | 6 ---- src/baselaplace/posterior.jl | 5 ---- src/baselaplace/prior.jl | 5 ---- test/Manifest.toml | 44 ++++++++++++---------------- 7 files changed, 31 insertions(+), 52 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 50507c9..54810f9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: fail-fast: false matrix: version: - - '1.7' + - '1.9' - '1.10' os: - ubuntu-latest @@ -29,7 +29,7 @@ jobs: - x64 include: - os: windows-latest - version: '1.7' + version: '1.9' arch: x64 - os: windows-latest version: '1' @@ -38,7 +38,7 @@ jobs: version: '1' arch: x64 - os: macOS-latest - version: '1.7' + version: '1.9' arch: x64 steps: - uses: actions/checkout@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 30cc94d..881e9a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.2.1]. -## Version [0.3.0] - 2024-06-8 +## Version [0.3.0] - 2024-07-01 ### Changed +- Removed support for `v1.7`, now `v1.9` as lower bound. This is because we are now overloading the `MLJFlux.train` and `MLJFlux.train_epoch` functions, which were added in version `v0.5.0` of that package, which is lower-bounded at `v1.9`. [#39] - Updated codecov workflow in CI.yml. [#39] - fixed test functions [#39] - adapted the LaplaceClassification and the LaplaceRegression struct to use the new @mlj_model macro from MLJBase.[#39] diff --git a/Project.toml b/Project.toml index eb09217..43d50cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "0.2.1" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -31,18 +31,18 @@ Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.7, 1.10" MLJBase = "0, 1.4.0" -MLJFlux = "0.2.10, 0.3, 0.4, 0.5.1" +MLJFlux = "0.5" MLJModelInterface = "1.8.0" -MLUtils = "0.4.3" -Optimisers = "0.3.3" +MLUtils = "0.4" +Optimisers = "0.2, 0.3" ProgressMeter = "1.7.2" -Random = "1.7, 1.10" +Random = "1.9, 1.10" Statistics = "1" Tables = "1.10.1" -Test = "1.7, 1.10" +Test = "1.9, 1.10" Tullio = "0.3.5" Zygote = "0.6" -julia = "1.7, 1.10" +julia = "1.9, 1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/src/baselaplace/estimation_params.jl b/src/baselaplace/estimation_params.jl index 9b40a6c..37d516d 100644 --- a/src/baselaplace/estimation_params.jl +++ b/src/baselaplace/estimation_params.jl @@ -118,9 +118,3 @@ function instantiate_curvature!( return params.curvature = curvature end - -function Base.:(==)(a::EstimationParams, b::EstimationParams) - checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] - println(checks) - return all(checks) -end diff --git a/src/baselaplace/posterior.jl b/src/baselaplace/posterior.jl index 20723aa..ea0e4c5 100644 --- a/src/baselaplace/posterior.jl +++ b/src/baselaplace/posterior.jl @@ -42,8 +42,3 @@ function Posterior(model::Any, est_params::EstimationParams) 0.0, ) end - -function Base.:(==)(a::Posterior, b::Posterior) - checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] - return all(checks) -end diff --git a/src/baselaplace/prior.jl b/src/baselaplace/prior.jl index 5a33f7d..accec86 100644 --- a/src/baselaplace/prior.jl +++ b/src/baselaplace/prior.jl @@ -36,8 +36,3 @@ function Prior(params::LaplaceParams, model::Any, likelihood::Symbol) end return Prior(params.σ, params.μ₀, params.λ, P₀) end - -function Base.:(==)(a::Prior, b::Prior) - checks = [getfield(a, x) == getfield(b, x) for x in fieldnames(typeof(a))] - return all(checks) -end diff --git a/test/Manifest.toml b/test/Manifest.toml index bf3e5b9..8c53eb6 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.4" +julia_version = "1.9.4" manifest_format = "2.0" project_hash = "fa6672850323ab23f77b8212aabdf7f033fa4213" @@ -267,7 +267,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" +version = "1.0.5+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -583,10 +583,10 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll"] -git-tree-sha1 = "ff38ba61beff76b8f4acad8ab0c97ef73bb670cb" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.3.9+0" +version = "3.4.0+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] @@ -899,14 +899,9 @@ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" @@ -1093,7 +1088,7 @@ version = "1.1.9" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.2+0" [[deps.Measures]] git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" @@ -1141,7 +1136,7 @@ version = "0.3.4" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2022.10.11" [[deps.MultivariateStats]] deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] @@ -1213,12 +1208,12 @@ version = "0.2.5" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.21+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" +version = "0.8.1+0" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] @@ -1264,7 +1259,7 @@ version = "1.6.3" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" +version = "10.42.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1328,7 +1323,7 @@ version = "0.43.4+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.9.2" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1434,7 +1429,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA"] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.RealDot]] @@ -1560,7 +1555,6 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" [[deps.SparseInverseSubset]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1627,7 +1621,7 @@ version = "3.4.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -1703,9 +1697,9 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "5.10.1+6" [[deps.TOML]] deps = ["Dates"] @@ -2063,7 +2057,7 @@ version = "0.10.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" +version = "1.2.13+0" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2132,7 +2126,7 @@ version = "0.15.1+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.8.0+0" [[deps.libevdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2178,7 +2172,7 @@ version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" +version = "17.4.0+0" [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] From b7e040d307fa8c12b2326334442645abbd75ac6f Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 15:30:45 +0200 Subject: [PATCH 30/32] bumped version because zero major versions cause compat headaches with other packages in Taija ecosystem --- CHANGELOG.md | 3 ++- Project.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 881e9a5..04ff43d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.2.1]. -## Version [0.3.0] - 2024-07-01 +## Version [1.0.0] - 2024-07-01 ### Changed +- Moving straight to `1.0.0` now for package, because zero major versions cause compat headaches with other packages in Taija ecosystem. [#39] - Removed support for `v1.7`, now `v1.9` as lower bound. This is because we are now overloading the `MLJFlux.train` and `MLJFlux.train_epoch` functions, which were added in version `v0.5.0` of that package, which is lower-bounded at `v1.9`. [#39] - Updated codecov workflow in CI.yml. [#39] - fixed test functions [#39] diff --git a/Project.toml b/Project.toml index 43d50cf..70f50d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "0.3.0" +version = "1.0.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From c97422f32e9f917133b37483e3734ec40eda00ab Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 1 Jul 2024 15:45:23 +0200 Subject: [PATCH 31/32] changelog --- CHANGELOG.md | 1 + src/baselaplace/predicting.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04ff43d..73ab7e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Changed +- Calling a Laplace object on an array, `(la::AbstractLaplace)(X::AbstractArray)` now simply calls the underlying neural network on data. In other words, it returns the generic predictions, not LA predictions. This was implemented to facilitate better interplay with `MLJFlux`. [#39] - Moving straight to `1.0.0` now for package, because zero major versions cause compat headaches with other packages in Taija ecosystem. [#39] - Removed support for `v1.7`, now `v1.9` as lower bound. This is because we are now overloading the `MLJFlux.train` and `MLJFlux.train_epoch` functions, which were added in version `v0.5.0` of that package, which is lower-bounded at `v1.9`. [#39] - Updated codecov workflow in CI.yml. [#39] diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index c2cf347..b421460 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -131,10 +131,10 @@ function probit(fμ::AbstractArray, fvar::AbstractArray) end """ - (la::AbstractLaplace)(X::AbstractArray; kwrgs...) + (la::AbstractLaplace)(X::AbstractArray) Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the `predict` function. """ -function (la::AbstractLaplace)(X::AbstractArray; kwrgs...) +function (la::AbstractLaplace)(X::AbstractArray) return la.model(X) end From 8523eed43be869ae224943856daed8a0de97a363 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 1 Jul 2024 17:18:54 +0200 Subject: [PATCH 32/32] small stuff: fixed docstrings and fitresult(regression) function. committed just to avoid stashing --- src/mlj_flux.jl | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index d263ba6..30923e7 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -166,17 +166,11 @@ Computes the fit result for a Laplace Regression model, returning the model chai - `y`: The target data, typically a vector of class labels. # Returns -- A tuple containing: - - The model. + A tuple containing: - The trained Flux chain. - - The number of unique classes in the target data `y`. + - a deepcopy of the laplace model. """ function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - if y isa AbstractArray - target_column_names = nothing - else - target_column_names = Tables.schema(y).names - end return (chain, deepcopy(model)) end @@ -194,11 +188,11 @@ Fit the LaplaceRegression model using Flux.jl. - `X`: The input data for training. - `y`: The target labels for training. -# Returns (fitresult, cache, report ) +# Returns (la, optimiser_state, history ) where -- `fitresult`: is the output of MLJFlux.fitresult. -- `cache`: an empty tuple. -- `report`: a named tuple that contain the field training_losses. +- `la`: the fitted Laplace model. +- `optimiser_state`: the state of the optimiser. +- `history`: the training loss history. """ function MLJFlux.train( model::LaplaceRegression, @@ -244,8 +238,7 @@ function MLJFlux.train( # initiate history: loss = model.loss - n_batches = length(y) - losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) + losses = (loss(chain(X[i]), y[i]) for i in 1:length(y)) history = [mean(losses)] for i in 1:epochs @@ -344,9 +337,10 @@ Computes the fit result for a Laplace classification model, returning the model - `y`: The target data, typically a vector of class labels. # Returns -- A tuple containing: - - The model. - - The number of unique classes in the target data `y`. +# Returns + A tuple containing: + - The trained Flux chain. + - a deepcopy of the laplace model. """ function MLJFlux.fitresult(model::LaplaceClassification, chain, y) return (chain, deepcopy(model)) @@ -368,9 +362,9 @@ Fit the LaplaceRegression model using Flux.jl. # Returns (fitresult, cache, report ) where -- `fitresult`: is the output of MLJFlux.fitresult. -- `cache`: an empty tuple. -- `report`: a named tuple that contain the field training_losses. +- `la`: the fitted Laplace model. +- `optimiser_state`: the state of the optimiser. +- `history`: the training loss history. """ function MLJFlux.train( model::LaplaceClassification,