Skip to content

Commit

Permalink
Merge pull request #36 from navimakarov/10-mlj-interfacing
Browse files Browse the repository at this point in the history
Mlj interfacing
  • Loading branch information
MarkArdman authored Jun 21, 2023
2 parents 3f64270 + 17a28bb commit 2f59450
Show file tree
Hide file tree
Showing 10 changed files with 1,726 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ jobs:
${{ runner.os }}-
- name: Run benchmark
run: |
cd test/benchmarks
julia -e 'import Pkg; Pkg.add("LaplaceRedux")'
julia -e 'import Pkg; Pkg.add("Flux")'
julia -e 'import Pkg; Pkg.add("Plots")'
Expand All @@ -131,10 +132,9 @@ jobs:
julia -e 'import Pkg; Pkg.add("Printf")'
julia -e 'import Pkg; Pkg.add("BenchmarkTools")'
julia -e 'import Pkg; Pkg.add("Tullio")'
cd test/benchmarks
julia --project --color=yes -e '
using Pkg;
Pkg.instantiate();
Pkg.resolve();
include("BenchmarkFit.jl")'
- name: Store benchmark result
Expand Down
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@ version = "0.1.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
1,242 changes: 1,242 additions & 0 deletions dev/notebooks/mlj-interfacing/mlj.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/LaplaceRedux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export Laplace,
posterior_covariance,
posterior_precision

include("mlj_flux.jl")
export LaplaceApproximation

include("plotting.jl")

end
2 changes: 2 additions & 0 deletions src/curvature/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ function jacobians(curvature::CurvatureInterface, X::AbstractArray)
nn = curvature.model
# Output:
= nn(X)
# Convert ŷ to a vector
= vec(ŷ)
# Jacobian:
# Differentiate f with regards to the model parameters
𝐉 = jacobian(() -> nn(X), Flux.params(nn))
Expand Down
2 changes: 2 additions & 0 deletions src/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ function validate_subnetwork_indices(
subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}, params
)
@assert (subnetwork_indices !== nothing) "If `subset_of_weights` is `:subnetwork`, then `subnetwork_indices` should be a vector of vectors of integers."
# Check if subnetwork_indices is a vector containing an empty vector
@assert !(subnetwork_indices == [[]]) "If `subset_of_weights` is `:subnetwork`, then `subnetwork_indices` should be a vector of vectors of integers."
# Initialise a set of vectors
selected = Set{Vector{Int}}()
for (i, index) in enumerate(subnetwork_indices)
Expand Down
Loading

0 comments on commit 2f59450

Please sign in to comment.