diff --git a/CHANGELOG.md b/CHANGELOG.md index 255db12..bab39a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.2.1]. + +## Version [1.0.2] - 2024-08-12 +### +- added TaijaPlotting to the docs env +### Changed +- modified the MLJFlux.train function so that it now properly return a trained chain [[#112](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/112)] + ## Version [1.0.0] - 2024-07-22 ### Changed diff --git a/Project.toml b/Project.toml index 5864743..e54b4fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "1.0.1" +version = "1.0.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 25d7d67..a6810a1 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.4" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "07bab9fa5d046478b21247a44464171c6b19ad4c" +project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -71,6 +71,48 @@ version = "2.3.0" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" +[[deps.Arpack]] +deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] +git-tree-sha1 = "9b9b347613394885fd1c8c7729bfc60528faa436" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.5.4" + +[[deps.Arpack_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] +git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" +uuid = "68821587-b530-5797-8361-c406ea357684" +version = "3.5.1+1" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "f54c23a5d304fb87110de62bace7777d59088c34" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.15.0" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceCUDSSExt = "CUDSS" + ArrayInterfaceChainRulesExt = "ChainRules" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceReverseDiffExt = "ReverseDiff" + ArrayInterfaceSparseArraysExt = "SparseArrays" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -316,6 +358,12 @@ git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.4.2" +[[deps.ConformalPrediction]] +deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "InferOpt", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJLinearModels", "MLJModelInterface", "MLUtils", "ProgressMeter", "Random", "StatsBase", "Tables"] +git-tree-sha1 = "c5ddd335cb7557efbaf44da2d2c6d395ea41e18d" +uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" +version = "0.1.13" + [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" @@ -341,6 +389,22 @@ git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.3" +[[deps.CounterfactualExplanations]] +deps = ["CategoricalArrays", "ChainRulesCore", "DataFrames", "Distributions", "Flux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLJBase", "MLJDecisionTreeInterface", "MLUtils", "MultivariateStats", "PackageExtensionCompat", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "TaijaBase", "UUIDs"] +git-tree-sha1 = "8a68385b6852e9357889aea661536059bc8b6158" +uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" +version = "1.1.6" + + [deps.CounterfactualExplanations.extensions] + DecisionTreeExt = "DecisionTree" + LaplaceReduxExt = "LaplaceRedux" + NeuroTreeExt = "NeuroTreeModels" + + [deps.CounterfactualExplanations.weakdeps] + DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" + LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" + NeuroTreeModels = "1db4e0a5-a364-4b0c-897c-2bd5a4a3a1f2" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -372,6 +436,18 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.Dbus_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fc173b380865f70627d7dd1190dc2fce6cc105af" +uuid = "ee1fde0b-3d02-5ea6-8484-8dfef6360eab" +version = "1.14.10+0" + +[[deps.DecisionTree]] +deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] +git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" +uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" +version = "0.12.4" + [[deps.DefineSingletons]] git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" @@ -383,6 +459,12 @@ git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" version = "1.9.1" +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + [[deps.DiffResults]] deps = ["StaticArraysCore"] git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" @@ -395,6 +477,17 @@ git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "1.15.1" +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -404,17 +497,13 @@ deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadG git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" version = "0.25.109" +weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" DistributionsDensityInterfaceExt = "DensityInterface" DistributionsTestExt = "Test" - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -512,6 +601,22 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] FillArraysSparseArraysExt = "SparseArrays" FillArraysStatisticsExt = "Statistics" +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Setfield", "SparseArrays"] +git-tree-sha1 = "f9219347ebf700e77ca1d48ef84e4a82a6701882" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.24.0" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.FixedPointNumbers]] deps = ["Statistics"] git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" @@ -580,7 +685,7 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" version = "3.4.0+0" @@ -680,6 +785,18 @@ git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.14" +[[deps.InferOpt]] +deps = ["ChainRulesCore", "DensityInterface", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] +git-tree-sha1 = "cbe07b2683de4b1dd0c8def5e5f62ce97c60d24c" +uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" +version = "0.6.1" + + [deps.InferOpt.extensions] + InferOptFrankWolfeExt = "DifferentiableFrankWolfe" + + [deps.InferOpt.weakdeps] + DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" + [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" @@ -722,6 +839,12 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" +[[deps.IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "59545b0a2b27208b0650df0a46b8e3019f85055b" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.4" + [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -833,9 +956,9 @@ version = "1.3.1" [[deps.LaplaceRedux]] deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -path = ".." +git-tree-sha1 = "27821766cccfcef9a9d6b9cee6e924796ec845dd" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "1.0.0" +version = "1.0.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] @@ -941,10 +1064,28 @@ git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" version = "2.40.1+0" +[[deps.LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "e4c3be53733db1051cc15ecf573b1042b3a712a1" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.3.0" + [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.LinearMaps]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ee79c3208e55786de58f8dcccca098ced79f743f" +uuid = "7a12625a-238d-50fd-b39a-03d52299707e" +version = "3.11.3" +weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"] + + [deps.LinearMaps.extensions] + LinearMapsChainRulesCoreExt = "ChainRulesCore" + LinearMapsSparseArraysExt = "SparseArrays" + LinearMapsStatisticsExt = "Statistics" + [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" @@ -982,12 +1123,30 @@ version = "1.7.0" [deps.MLJBase.weakdeps] StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" +[[deps.MLJDecisionTreeInterface]] +deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] +git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" +uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" +version = "0.4.2" + +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] +git-tree-sha1 = "84a5be55a364bb6b6dc7780bbd64317ebdd3ad1e" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.4.3" + [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"] git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" version = "0.5.1" +[[deps.MLJLinearModels]] +deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] +git-tree-sha1 = "7f517fd840ca433a8fae673edb31678ff55d969c" +uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" +version = "0.10.0" + [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733" @@ -1072,6 +1231,18 @@ version = "0.7.8" uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" +[[deps.MultivariateStats]] +deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "816620e3aac93e5b5359e4fdaf23ca4525b00ddf" +uuid = "6f286f6a-111f-5878-ab1e-185364afe411" +version = "0.10.3" + +[[deps.NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.3" + [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" @@ -1116,6 +1287,23 @@ git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" version = "0.1.5" +[[deps.NaturalSort]] +git-tree-sha1 = "eda490d06b9f7c00752ee81cfa451efe55521e21" +uuid = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" +version = "1.0.0" + +[[deps.NearestNeighborModels]] +deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" +uuid = "636a865e-7cf4-491e-846c-de09b730eb36" +version = "0.2.3" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.18" + [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" @@ -1160,6 +1348,18 @@ git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.5+0" +[[deps.Optim]] +deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89" +uuid = "429524aa-4258-5aef-a3af-852621145aeb" +version = "1.9.4" + + [deps.Optim.extensions] + OptimMOIExt = "MathOptInterface" + + [deps.Optim.weakdeps] + MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" + [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" @@ -1188,6 +1388,18 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.31" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.Pango_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "cb5a2ab6763464ae0f19c86c56c63d4a2b0f5bda" +uuid = "36c8627f-9965-5494-a995-c6b170f724f3" +version = "1.52.2+0" + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -1260,6 +1472,12 @@ git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" version = "1.4.3" +[[deps.PositiveFactorizations]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" +uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" +version = "0.2.4" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -1389,6 +1607,12 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" +[[deps.Referenceables]] +deps = ["Adapt"] +git-tree-sha1 = "02d31ad62838181c1a3a5fd23a1ce5914a643601" +uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" +version = "0.1.3" + [[deps.RegistryInstances]] deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" @@ -1401,6 +1625,12 @@ git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" version = "1.0.1" +[[deps.RequiredInterfaces]] +deps = ["InteractiveUtils", "Logging", "Test"] +git-tree-sha1 = "c3250333ea2894237ed015baf7d5fcb8a1ea3169" +uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" +version = "0.1.6" + [[deps.Requires]] deps = ["UUIDs"] git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" @@ -1434,6 +1664,12 @@ git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" version = "3.0.0" +[[deps.ScikitLearnBase]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f" +uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e" +version = "0.5.0" + [[deps.Scratch]] deps = ["Dates"] git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" @@ -1513,6 +1749,12 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" +[[deps.StableRNGs]] +deps = ["Random"] +git-tree-sha1 = "83e6cce8324d49dfaf9ef059227f91ed4441a8e5" +uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" +version = "1.0.2" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" @@ -1620,6 +1862,18 @@ git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.12.0" +[[deps.TaijaBase]] +deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] +git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" +uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" +version = "1.2.2" + +[[deps.TaijaPlotting]] +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"] +git-tree-sha1 = "2a4fcdf2abd5533d6d24a97ce5e89327391b2dc1" +uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" +version = "1.1.2" + [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" @@ -1635,6 +1889,12 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.ThreadsX]] +deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"] +git-tree-sha1 = "70bd8244f4834d46c3d68bd09e7792d8f571ef04" +uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" +version = "0.1.12" + [[deps.TimeZones]] deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] git-tree-sha1 = "a6ae8d7a27940c33624f8c7bde5528de21ba730d" @@ -2024,6 +2284,12 @@ deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" version = "5.8.0+1" +[[deps.libdecor_jll]] +deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"] +git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f" +uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f" +version = "0.2.2+0" + [[deps.libevdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22" diff --git a/docs/Project.toml b/docs/Project.toml index ac46955..7118d5f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,9 +9,11 @@ MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240" Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index 081bdbb..920567e 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -219,22 +219,6 @@ function MLJFlux.train( ) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - # Initialize history: history = [] verbose_laplace = false @@ -263,6 +247,22 @@ function MLJFlux.train( push!(history, current_loss) end + if !isa(chain, AbstractLaplace) + la = LaplaceRedux.Laplace( + chain; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + else + la = chain + end + # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) @@ -387,22 +387,6 @@ function MLJFlux.train( ) X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - # Initialize history: history = [] verbose_laplace = false @@ -432,6 +416,21 @@ function MLJFlux.train( push!(history, current_loss) end + if !isa(chain, AbstractLaplace) + la = LaplaceRedux.Laplace( + chain; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=model.σ, + μ₀=model.μ₀, + P₀=model.P₀, + ) + else + la = chain + end # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)