Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Laplace approximation #59

Merged
merged 30 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
31c2cf3
add Laplace approximation
st-- Sep 24, 2021
e499f4b
version bump
st-- Sep 24, 2021
38e1076
Update examples/c-comparisons/script.jl
st-- Sep 26, 2021
05d40f1
update example tomls
st-- Sep 27, 2021
6a2c7e5
add comments to example
st-- Sep 27, 2021
351cbdb
address review comments
st-- Sep 27, 2021
27f421d
more docstrings
st-- Sep 27, 2021
992814c
remove varargs to get closure-fields working correctly
st-- Sep 27, 2021
91e54bf
back to varargs
st-- Sep 28, 2021
58a4bb1
fix code comments for Literate.jl style
st-- Sep 28, 2021
9fc55fa
remove ignore_ad workaround - now in ChainRulesCore 1.7
st-- Sep 28, 2021
bad4753
use likelihood object
st-- Sep 28, 2021
82e31dd
clean up example notebook
st-- Sep 28, 2021
3282c54
clean up notebook
st-- Sep 28, 2021
d9e9598
improve documentation
st-- Sep 28, 2021
164c52c
reorganize order of laplace.jl, add more comments, add error for deri…
st-- Sep 28, 2021
c1b47d3
moved todo into https://github.com/JuliaGaussianProcesses/Approximate…
st-- Sep 28, 2021
4ac8313
clean up exports
st-- Sep 28, 2021
7db3db1
fix test
st-- Sep 29, 2021
f34c0b0
add abstractgps internal api test
st-- Sep 29, 2021
90b783b
finish tests
st-- Sep 29, 2021
68e7145
Apply suggestions from code review
st-- Sep 29, 2021
eeb8cee
bugfix
st-- Sep 29, 2021
15a886a
posterior(::LaplaceApproximation, ...): call newton_inner_loop and _l…
st-- Sep 29, 2021
c545be0
add test for erroring of _newton_inner_loop on backward pass
st-- Sep 29, 2021
5b8da87
Apply suggestions from code review
st-- Sep 29, 2021
b3d827f
minor cleanup
st-- Sep 29, 2021
4820d9e
Merge branch 'st/LaplaceApproximation' of github.com:JuliaGaussianPro…
st-- Sep 29, 2021
0b65d16
the -> a mode
st-- Sep 29, 2021
bdacea7
Update src/laplace.jl
st-- Sep 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Copyright (c) 2021

Ross Viljoen
The JuliaGaussianProcess Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["Ross Viljoen <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -23,6 +24,7 @@ ChainRulesCore = "1"
Distributions = "0.25"
FastGaussQuadrature = "0.4"
FillArrays = "0.12"
ForwardDiff = "0.10"
GPLikelihoods = "0.1, 0.2"
KLDivergences = "0.2.1"
Reexport = "1"
Expand Down
84 changes: 44 additions & 40 deletions examples/a-regression/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.1"

[[ApproximateGPs]]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "QuadGK", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
path = "../.."
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
version = "0.1.0"
version = "0.1.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d84c956c4c0548b4caf0e4e96cf5b6494b5b1529"
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "b8d49c34c3da35f220e7295659cd0bab8e739fed"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.32"
version = "3.1.33"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand Down Expand Up @@ -81,9 +81,9 @@ version = "1.11.5"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "4ce9393e871aca86cc457d9f66976c3da6902ea7"
git-tree-sha1 = "bd4afa1fdeec0c8b89dad3c6e92bc6e3b0fec9ce"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.4.0"
version = "1.6.0"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand Down Expand Up @@ -117,9 +117,9 @@ version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "4866e381721b30fac8dda4c8cb1d9db45c8d2994"
git-tree-sha1 = "1a90210acd935f222ea19657f143004d2c2a1117"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.37.0"
version = "3.38.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -137,9 +137,9 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
version = "0.5.7"

[[DataAPI]]
git-tree-sha1 = "bec2532f8adb82005476c141ec23e921fc20971b"
git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.8.0"
version = "1.9.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -168,9 +168,9 @@ version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "3ed8fa7178a10d1cd0f1ca524f249ba6937490c0"
git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.3.0"
version = "1.3.1"

[[Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
Expand Down Expand Up @@ -235,9 +235,9 @@ version = "0.4.7"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "caf289224e622f518c9dbfe832cdafa17d7c80a6"
git-tree-sha1 = "7f6ad1a7f4621b4ab8e554133dade99ebc6e7221"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.4"
version = "0.12.5"

[[FixedPointNumbers]]
deps = ["Statistics"]
Expand Down Expand Up @@ -300,9 +300,9 @@ version = "0.2.0"

[[GPUArrays]]
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "8fac1cf7d6ce0f2249c7acaf25d22e1e85c4a07f"
git-tree-sha1 = "7c39d767a9c55fafd01f7bc8b3fd0adf175fbc97"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.0.2"
version = "8.1.0"

[[GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
Expand All @@ -312,9 +312,9 @@ version = "0.12.9"

[[GR]]
deps = ["Base64", "DelimitedFiles", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Printf", "Random", "Serialization", "Sockets", "Test", "UUIDs"]
git-tree-sha1 = "182da592436e287758ded5be6e32c406de3a2e47"
git-tree-sha1 = "c2178cfbc0a5a552e16d097fae508f2024de61a3"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
version = "0.58.1"
version = "0.59.0"

[[GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Pkg", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
Expand Down Expand Up @@ -449,9 +449,9 @@ version = "3.100.1+0"

[[LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "8fb1a675d1b51885a78bc980fbf1944279880f97"
git-tree-sha1 = "36d95ecdfbc3240d728f68d73064d5b097fbf2ef"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.5.1"
version = "4.5.2"

[[LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -635,6 +635,10 @@ git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f"
uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051"
version = "1.3.5+0"

[[OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"

[[OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "15003dcb7d8db3c6c857fda14891a539a8f2705a"
Expand Down Expand Up @@ -672,9 +676,9 @@ version = "0.11.1"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779"
git-tree-sha1 = "9d8c00ef7a8d110787ff6f170579846f776133a9"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.0.3"
version = "2.0.4"

[[Pixman_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand All @@ -700,9 +704,9 @@ version = "1.0.14"

[[Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"]
git-tree-sha1 = "2dbafeadadcf7dadff20cd60046bba416b4912be"
git-tree-sha1 = "457b13497a3ea4deb33d273a6a5ea15c25c0ebd9"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.21.3"
version = "1.22.2"

[[Preferences]]
deps = ["TOML"]
Expand All @@ -726,9 +730,9 @@ version = "5.15.3+0"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b"
git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.4.1"
version = "2.4.2"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
Expand Down Expand Up @@ -757,9 +761,9 @@ version = "1.1.2"

[[RecipesPipeline]]
deps = ["Dates", "NaNMath", "PlotUtils", "RecipesBase"]
git-tree-sha1 = "d4491becdc53580c6dadb0f6249f90caae888554"
git-tree-sha1 = "7ad0dfa8d03b7bcf8c597f59f5292801730c55b8"
uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c"
version = "0.4.0"
version = "0.4.1"

[[Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
Expand Down Expand Up @@ -820,10 +824,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7"
deps = ["ChainRulesCore", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "ad42c30a6204c74d264692e633133dcea0e8b14e"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.6.1"
version = "1.6.2"

[[Static]]
deps = ["IfElse"]
Expand Down Expand Up @@ -860,9 +864,9 @@ version = "0.9.10"

[[StructArrays]]
deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"]
git-tree-sha1 = "f41020e84127781af49fc12b7e92becd7f5dd0ba"
git-tree-sha1 = "2ce41e0d042c60ecd131e9fb7154a3bfadbf50d3"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.2"
version = "0.6.3"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down Expand Up @@ -900,9 +904,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.12"
version = "0.5.13"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down Expand Up @@ -1074,9 +1078,9 @@ version = "1.4.0+3"

[[ZipFile]]
deps = ["Libdl", "Printf", "Zlib_jll"]
git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7"
git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.9.3"
version = "0.9.4"

[[Zlib_jll]]
deps = ["Libdl"]
Expand All @@ -1090,9 +1094,9 @@ version = "1.5.0+0"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "ffbf36ba9cd8476347486a013c93590b910a4855"
git-tree-sha1 = "4b799addc63aa77ad4112cede8086564d9068511"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.21"
version = "0.6.22"

[[ZygoteRules]]
deps = ["MacroTools"]
Expand Down
Loading