Skip to content

Commit

Permalink
also changed for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Sep 12, 2024
1 parent c9f7e65 commit 70e0985
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 46 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.9'
- '1.10'
os:
- ubuntu-latest
Expand All @@ -29,10 +28,10 @@ jobs:
- x64
include:
- os: windows-latest
version: '1.9'
version: '1'
arch: x64
- os: macOS-latest
version: '1.9'
version: '1'
arch: x64
steps:
- uses: actions/checkout@v2
Expand Down
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LaplaceRedux"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
authors = ["Patrick Altmeyer"]
version = "1.1.0"
version = "1.1.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -36,13 +36,13 @@ MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
ProgressMeter = "1.7.2"
Random = "1.9, 1.10"
Random = "1.10"
Statistics = "1"
Tables = "1.10.1"
Test = "1.9, 1.10"
Test = "1.10"
Tullio = "0.3.5"
Zygote = "0.6"
julia = "1.9, 1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
4 changes: 2 additions & 2 deletions dev/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "7c4391b00ad44ccd083aa89cf817ea4c83643a8f"
project_hash = "6224e923196c94cc69f23bf59d6ea60b7ea2c919"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -717,7 +717,7 @@ version = "1.3.1"
deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
path = ".."
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
version = "1.1.0"
version = "1.1.1"

[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
Expand Down
1 change: 1 addition & 0 deletions dev/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
Luxor = "ae8d54c2-7ccd-5906-9d76-62fc9837b5bc"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PlotThemes = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
5 changes: 3 additions & 2 deletions dev/issues/predict_slow.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using LaplaceRedux
using MLJBase
using Optimisers

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification();
fitresult, _, _ = MLJBase.fit(model, 0, X, y);
model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100);
fitresult, _, _ = MLJBase.fit(model, 2, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;

Expand Down
60 changes: 25 additions & 35 deletions src/mlj_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,21 +247,17 @@ function MLJFlux.train(
push!(history, current_loss)
end

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:regression,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end
la = LaplaceRedux.Laplace(
chain;
likelihood=:regression,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)

# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
Expand All @@ -285,12 +281,9 @@ Predict the output for new input data using a Laplace regression model.
"""
function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
la = fitresult[1]
#convert in a vector of vectors because MLJ ask to do so
X_vec = collect(eachrow(Xnew))
# Predict using Laplace and collect the predictions
yhat = [map(x -> LaplaceRedux.predict(la, x; ret_distr=model.ret_distr), X_vec)...]
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr)
return yhat
end

Expand Down Expand Up @@ -416,21 +409,18 @@ function MLJFlux.train(
push!(history, current_loss)
end

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:classification,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end
la = LaplaceRedux.Laplace(
chain;
likelihood=:classification,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)

# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
Expand Down

0 comments on commit 70e0985

Please sign in to comment.