diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7a2b866..75786cb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,6 @@ jobs: fail-fast: false matrix: version: - - '1.9' - '1.10' os: - ubuntu-latest @@ -29,10 +28,10 @@ jobs: - x64 include: - os: windows-latest - version: '1.9' + version: '1' arch: x64 - os: macOS-latest - version: '1.9' + version: '1' arch: x64 steps: - uses: actions/checkout@v2 diff --git a/Project.toml b/Project.toml index 3949c1e..6039591 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "1.1.0" +version = "1.1.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -36,13 +36,13 @@ MLJModelInterface = "1.8.0" MLUtils = "0.4" Optimisers = "0.2, 0.3" ProgressMeter = "1.7.2" -Random = "1.9, 1.10" +Random = "1.10" Statistics = "1" Tables = "1.10.1" -Test = "1.9, 1.10" +Test = "1.10" Tullio = "0.3.5" Zygote = "0.6" -julia = "1.9, 1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/dev/Manifest.toml b/dev/Manifest.toml index c2aa14e..f50e025 100644 --- a/dev/Manifest.toml +++ b/dev/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "7c4391b00ad44ccd083aa89cf817ea4c83643a8f" +project_hash = "6224e923196c94cc69f23bf59d6ea60b7ea2c919" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -717,7 +717,7 @@ version = "1.3.1" deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] path = ".." uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "1.1.0" +version = "1.1.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] diff --git a/dev/Project.toml b/dev/Project.toml index f68c2fc..3fb161e 100644 --- a/dev/Project.toml +++ b/dev/Project.toml @@ -4,6 +4,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" Luxor = "ae8d54c2-7ccd-5906-9d76-62fc9837b5bc" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PlotThemes = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/dev/issues/predict_slow.jl b/dev/issues/predict_slow.jl index 67ae1c6..ae713cb 100644 --- a/dev/issues/predict_slow.jl +++ b/dev/issues/predict_slow.jl @@ -1,10 +1,11 @@ using LaplaceRedux using MLJBase +using Optimisers X = MLJBase.table(rand(Float32, 100, 3)); y = coerce(rand("abc", 100), Multiclass); -model = LaplaceClassification(); -fitresult, _, _ = MLJBase.fit(model, 0, X, y); +model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100); +fitresult, _, _ = MLJBase.fit(model, 2, X, y); la = fitresult[1]; Xmat = matrix(X) |> permutedims; diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 4df5771..18890b3 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -247,21 +247,17 @@ function MLJFlux.train( push!(history, current_loss) end - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end + la = LaplaceRedux.Laplace( + chain; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) @@ -285,12 +281,9 @@ Predict the output for new input data using a Laplace regression model. """ function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) - Xnew = MLJBase.matrix(Xnew) + Xnew = MLJBase.matrix(Xnew) |> permutedims la = fitresult[1] - #convert in a vector of vectors because MLJ ask to do so - X_vec = collect(eachrow(Xnew)) - # Predict using Laplace and collect the predictions - yhat = [map(x -> LaplaceRedux.predict(la, x; ret_distr=model.ret_distr), X_vec)...] + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr) return yhat end @@ -416,21 +409,18 @@ function MLJFlux.train( push!(history, current_loss) end - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end + la = LaplaceRedux.Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)