Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: draft PR for adding Laplace approximation and Expectation Propagation #58

Closed
wants to merge 63 commits into from
Closed
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
3b8a075
initial commit of Laplace approximation & EP
st-- Sep 17, 2021
c73e9ce
Laplace demo
st-- Sep 17, 2021
93ec795
bugfix
st-- Sep 17, 2021
d62b2dc
bugfix demo
st-- Sep 17, 2021
76afc55
WIP: cleanup; pass around dist_y_given_f explicitly
st-- Sep 17, 2021
46009e4
use ForwardDiff elementwise (much faster)
st-- Sep 17, 2021
0c1a329
format
st-- Sep 17, 2021
88410cc
WIP: gradient
st-- Sep 17, 2021
4853b5f
*Very* WIP - but works
st-- Sep 20, 2021
3c91a22
frule bugfix
st-- Sep 21, 2021
513d34d
clean up comments
st-- Sep 21, 2021
33af436
intermediate cleanup (callbacks not tested)
st-- Sep 21, 2021
30b45ac
fix callback support
st-- Sep 21, 2021
09e5e82
Laplace initial tests
st-- Sep 21, 2021
2591555
make argument order consistent
st-- Sep 21, 2021
5d58f5d
more cleanup
st-- Sep 21, 2021
6c4e56d
cleanup
st-- Sep 21, 2021
ead526f
cleanup
st-- Sep 21, 2021
fe76626
@info -> @debug and ChainRulesTestUtil workaround
st-- Sep 21, 2021
f7439d6
chainrule tests
st-- Sep 21, 2021
d71ca20
add @info for res_cold/res_warm
st-- Sep 22, 2021
ecb29d9
cleanup
st-- Sep 22, 2021
b275332
explicit FiniteDifferences gradient test on laplace_lml
st-- Sep 22, 2021
8135c40
Merge branch 'master' of github.com:JuliaGaussianProcesses/Approximat…
st-- Sep 22, 2021
0d8eb02
format
st-- Sep 22, 2021
af07036
format
st-- Sep 22, 2021
6f08e73
remove Zygote dependency - part 1
st-- Sep 23, 2021
f84ed86
remove Zygote dependency - part 2
st-- Sep 23, 2021
e4aab7e
pkg bugfix
st-- Sep 23, 2021
d19003f
update example manifests
st-- Sep 23, 2021
2d43113
fix chainrule test by evaluating frule/rrule on newton_inner_loop bas…
st-- Sep 23, 2021
bbd24ec
add compat
st-- Sep 23, 2021
ac6486d
clean up test
st-- Sep 23, 2021
96957e3
add missing CRC dependency
st-- Sep 23, 2021
5af23fa
remove workaround
st-- Sep 23, 2021
54acbe3
use more of AbstractGPs API
st-- Sep 23, 2021
df88117
clean up laplace_steps
st-- Sep 23, 2021
8d20a65
add laplace example
st-- Sep 23, 2021
30ba663
cleanup
st-- Sep 23, 2021
8209442
cleanup2
st-- Sep 23, 2021
def61e7
format
st-- Sep 23, 2021
4566c17
remove demo script
st-- Sep 23, 2021
4095df1
bugfix
st-- Sep 23, 2021
91d0b34
update notebook
st-- Sep 23, 2021
bdc3618
bugfix 2
st-- Sep 23, 2021
12655f8
bugfiiiix
st-- Sep 23, 2021
b1c0f80
also plot mean
st-- Sep 23, 2021
3c97c97
improve plotting
st-- Sep 23, 2021
81ceb93
Apply suggestions from code review
st-- Sep 23, 2021
82d4695
remove `@ref` that does not work
st-- Sep 23, 2021
aca90b1
cleanup
st-- Sep 23, 2021
02b528b
make use of closure fields
st-- Sep 23, 2021
cbfbc95
Merge branch 'st/laplace_and_ep' of github.com:JuliaGaussianProcesses…
st-- Sep 23, 2021
e788f4c
yaf
st-- Sep 23, 2021
ca68222
improved type stability
st-- Sep 24, 2021
9144ed3
replace QuadGK with Gauss-Hermite
st-- Sep 24, 2021
771ef3e
Apply suggestions from code review
st-- Sep 24, 2021
b52ef55
more type stability cleanup
st-- Sep 24, 2021
4f45ff6
Merge branch 'st/laplace_and_ep' of github.com:JuliaGaussianProcesses…
st-- Sep 24, 2021
4a694d9
fix test
st-- Sep 24, 2021
62840ad
add missing test file
st-- Sep 24, 2021
6f7a5ba
more explanation on the example script
st-- Sep 24, 2021
5445a95
fix test seed
st-- Sep 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Copyright (c) 2021

Ross Viljoen
The JuliaGaussianProcess Contributors
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't check how it's done in all the other packages, but particularly when putting packages into an org I find it much more inviting to future contributors if it just says "Copyright the people who contributed to it"... let me know your thoughts on this, happy to discuss it more in depth

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that for the license we should just have "The JuliaGaussianProcesses organization" and for the author field in the Project.toml have `JuliaGaussianProcesses and contributors"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My inclination has normally been to go with "Person who created the package + the org", so "Ross Viljoen and the JuliaGaussianProcess org", so that the person who did the initial work retains some credit -- in this case, it would be good to keep Ross' name on it imho. Just a suggestion though.

I've just taken a look through the other packages in the org, and it seems like we're pretty inconsistent. AbstractGPs just mentions JuliaGPs, KF mentions Theo and Turing. We should probably agree to a consistent approach and stick to it, but this PR isn't the place for that.

Consequently, could I suggest that we leave the license as-is in this PR, and update it elsewhere once we've made an org-wide decision?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this pr isn't gonna get merged anyways: ) just used the opportunity to spark the discussion.

I'd prefer just "the org"; I think it makes more sense from an overall point of view to use this package, for example, to bundle SVGP, Laplace, EP - and I believe we incentivise contributions by new people by not giving the feeling that you're just increasing someone else's credit with your contribution.

From my personal point of view, as much as I want to get myself not to care about attribution as much, I can feel it inside myself that I'd rather create github.com/st--/LaplaceGPs.jl or something like that that I can then put my name on than to add a large set of functionality to "someone else's project". Whereas if it's just listed as "the org" then it feels like by contributing more code I deserve more to be part of the org and take my share of the credit from there. Which is why I changed the text in my branch before pushing it.


Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Laplace needs derivatives for the inner loop. Hopefully, ForwardDiff is a sufficiently lightweight dependency.

GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
st-- marked this conversation as resolved.
Show resolved Hide resolved
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -23,8 +25,10 @@ ChainRulesCore = "1"
Distributions = "0.25"
FastGaussQuadrature = "0.4"
FillArrays = "0.12"
ForwardDiff = "0.10"
GPLikelihoods = "0.1, 0.2"
KLDivergences = "0.2.1"
QuadGK = "2"
Reexport = "1"
SpecialFunctions = "1"
StatsBase = "0.33"
Expand Down
84 changes: 44 additions & 40 deletions examples/a-regression/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.1"

[[ApproximateGPs]]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "QuadGK", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
path = "../.."
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
version = "0.1.0"
version = "0.1.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d84c956c4c0548b4caf0e4e96cf5b6494b5b1529"
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "b8d49c34c3da35f220e7295659cd0bab8e739fed"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.32"
version = "3.1.33"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand Down Expand Up @@ -81,9 +81,9 @@ version = "1.11.5"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "4ce9393e871aca86cc457d9f66976c3da6902ea7"
git-tree-sha1 = "bd4afa1fdeec0c8b89dad3c6e92bc6e3b0fec9ce"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.4.0"
version = "1.6.0"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand Down Expand Up @@ -117,9 +117,9 @@ version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "4866e381721b30fac8dda4c8cb1d9db45c8d2994"
git-tree-sha1 = "1a90210acd935f222ea19657f143004d2c2a1117"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.37.0"
version = "3.38.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -137,9 +137,9 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
version = "0.5.7"

[[DataAPI]]
git-tree-sha1 = "bec2532f8adb82005476c141ec23e921fc20971b"
git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.8.0"
version = "1.9.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -168,9 +168,9 @@ version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "3ed8fa7178a10d1cd0f1ca524f249ba6937490c0"
git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.3.0"
version = "1.3.1"

[[Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
Expand Down Expand Up @@ -235,9 +235,9 @@ version = "0.4.7"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "caf289224e622f518c9dbfe832cdafa17d7c80a6"
git-tree-sha1 = "7f6ad1a7f4621b4ab8e554133dade99ebc6e7221"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.4"
version = "0.12.5"

[[FixedPointNumbers]]
deps = ["Statistics"]
Expand Down Expand Up @@ -300,9 +300,9 @@ version = "0.2.0"

[[GPUArrays]]
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "8fac1cf7d6ce0f2249c7acaf25d22e1e85c4a07f"
git-tree-sha1 = "7c39d767a9c55fafd01f7bc8b3fd0adf175fbc97"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.0.2"
version = "8.1.0"

[[GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
Expand All @@ -312,9 +312,9 @@ version = "0.12.9"

[[GR]]
deps = ["Base64", "DelimitedFiles", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Printf", "Random", "Serialization", "Sockets", "Test", "UUIDs"]
git-tree-sha1 = "182da592436e287758ded5be6e32c406de3a2e47"
git-tree-sha1 = "c2178cfbc0a5a552e16d097fae508f2024de61a3"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
version = "0.58.1"
version = "0.59.0"

[[GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Pkg", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
Expand Down Expand Up @@ -449,9 +449,9 @@ version = "3.100.1+0"

[[LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "8fb1a675d1b51885a78bc980fbf1944279880f97"
git-tree-sha1 = "36d95ecdfbc3240d728f68d73064d5b097fbf2ef"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.5.1"
version = "4.5.2"

[[LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -635,6 +635,10 @@ git-tree-sha1 = "7937eda4681660b4d6aeeecc2f7e1c81c8ee4e2f"
uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051"
version = "1.3.5+0"

[[OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"

[[OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "15003dcb7d8db3c6c857fda14891a539a8f2705a"
Expand Down Expand Up @@ -672,9 +676,9 @@ version = "0.11.1"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779"
git-tree-sha1 = "9d8c00ef7a8d110787ff6f170579846f776133a9"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.0.3"
version = "2.0.4"

[[Pixman_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand All @@ -700,9 +704,9 @@ version = "1.0.14"

[[Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs"]
git-tree-sha1 = "2dbafeadadcf7dadff20cd60046bba416b4912be"
git-tree-sha1 = "457b13497a3ea4deb33d273a6a5ea15c25c0ebd9"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.21.3"
version = "1.22.2"

[[Preferences]]
deps = ["TOML"]
Expand All @@ -726,9 +730,9 @@ version = "5.15.3+0"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b"
git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.4.1"
version = "2.4.2"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
Expand Down Expand Up @@ -757,9 +761,9 @@ version = "1.1.2"

[[RecipesPipeline]]
deps = ["Dates", "NaNMath", "PlotUtils", "RecipesBase"]
git-tree-sha1 = "d4491becdc53580c6dadb0f6249f90caae888554"
git-tree-sha1 = "7ad0dfa8d03b7bcf8c597f59f5292801730c55b8"
uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c"
version = "0.4.0"
version = "0.4.1"

[[Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
Expand Down Expand Up @@ -820,10 +824,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "a322a9493e49c5f3a10b50df3aedaf1cdb3244b7"
deps = ["ChainRulesCore", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "ad42c30a6204c74d264692e633133dcea0e8b14e"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.6.1"
version = "1.6.2"

[[Static]]
deps = ["IfElse"]
Expand Down Expand Up @@ -860,9 +864,9 @@ version = "0.9.10"

[[StructArrays]]
deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"]
git-tree-sha1 = "f41020e84127781af49fc12b7e92becd7f5dd0ba"
git-tree-sha1 = "2ce41e0d042c60ecd131e9fb7154a3bfadbf50d3"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.2"
version = "0.6.3"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down Expand Up @@ -900,9 +904,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.12"
version = "0.5.13"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down Expand Up @@ -1074,9 +1078,9 @@ version = "1.4.0+3"

[[ZipFile]]
deps = ["Libdl", "Printf", "Zlib_jll"]
git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7"
git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.9.3"
version = "0.9.4"

[[Zlib_jll]]
deps = ["Libdl"]
Expand All @@ -1090,9 +1094,9 @@ version = "1.5.0+0"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "ffbf36ba9cd8476347486a013c93590b910a4855"
git-tree-sha1 = "4b799addc63aa77ad4112cede8086564d9068511"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.21"
version = "0.6.22"

[[ZygoteRules]]
deps = ["MacroTools"]
Expand Down
Loading