From 5afc4ba60f829d4714eef309310bc908aae04de1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 Feb 2024 22:20:18 -0500 Subject: [PATCH 1/3] Test the new NonlinearSolveBase.jl --- Manifest.toml | 574 ++++++++++++++++++ Project.toml | 14 +- ...linearSolveChainRulesCoreDiffEqBaseExt.jl} | 2 +- src/SimpleNonlinearSolve.jl | 56 +- src/bracketing/bisection.jl | 2 +- src/bracketing/brent.jl | 2 +- src/bracketing/falsi.jl | 2 +- src/bracketing/itp.jl | 2 +- src/bracketing/ridder.jl | 2 +- src/linesearch.jl | 6 +- src/nlsolve/lbroyden.jl | 2 +- src/utils.jl | 72 +-- test/core/23_test_problems_tests.jl | 2 +- test/core/rootfind_tests.jl | 5 +- 14 files changed, 644 insertions(+), 99 deletions(-) create mode 100644 Manifest.toml rename ext/{SimpleNonlinearSolveChainRulesCoreExt.jl => SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl} (94%) diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..0fa996c --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,574 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.1" +manifest_format = "2.0" +project_hash = "fd7dd831d23cbd09082d4fbd447d505786d32a50" + +[[deps.ADTypes]] +git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "0.2.6" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "0fb305e0253fd4e833d486914367a2ee2c2e78d0" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.1" + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + + [deps.Adapt.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "c5aeb516a84459e0318a02507d2261edad97eb75" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.7.1" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.CommonSolve]] +git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.4" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.0+0" + +[[deps.ConcreteStructs]] +git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" +uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +version = "0.2.3" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.4" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FastClosures]] +git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" +uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +version = "0.3.2" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random"] +git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.9.3" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "73d1214fec245096717847c62d389a5d2ac86504" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.22.0" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + + [deps.ForwardDiff.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.FunctionWrappersWrappers]] +deps = ["FunctionWrappers"] +git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" +uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" +version = "0.1.3" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.27" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MaybeInplace]] +deps = ["ArrayInterface", "LinearAlgebra", "MacroTools", "SparseArrays"] +git-tree-sha1 = "a85c6a98c9e5a2a7046bc1bb89f28a3241e1de4d" +uuid = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" +version = "0.1.1" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.NonlinearSolveBase]] +deps = ["ArrayInterface", "ConcreteStructs", "FastClosures", "LinearAlgebra", "Markdown", "PrecompileTools", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"] +git-tree-sha1 = "44c1ccea6b6cc6126fb83b83424ae6cbf61cc63f" +repo-rev = "master" +repo-url = "https://github.com/SciML/NonlinearSolveBase.jl" +uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" +version = "1.0.0" +weakdeps = ["SparseArrays"] + + [deps.NonlinearSolveBase.extensions] + NonlinearSolveBaseSparseArraysExt = "SparseArrays" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.1" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "1bbc4bb050165cc57ca2876cd53cc23395948650" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "3.10.0" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" + RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "6aacc5eefe8415f47b3e34214c1d79d2674a0ba2" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.12" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SciMLBase]] +deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"] +git-tree-sha1 = "375256db2d99fc730d2d134cca17939324d284d1" +uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +version = "2.28.0" + + [deps.SciMLBase.extensions] + SciMLBaseChainRulesCoreExt = "ChainRulesCore" + SciMLBaseMakieExt = "Makie" + SciMLBasePartialFunctionsExt = "PartialFunctions" + SciMLBasePyCallExt = "PyCall" + SciMLBasePythonCallExt = "PythonCall" + SciMLBaseRCallExt = "RCall" + SciMLBaseZygoteExt = "Zygote" + + [deps.SciMLBase.weakdeps] + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" + PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" + PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + RCall = "6f49c342-dc21-5d91-9882-a32aef131414" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.SciMLOperators]] +deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools", "Setfield", "SparseArrays", "StaticArraysCore"] +git-tree-sha1 = "10499f619ef6e890f3f4a38914481cc868689cd5" +uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +version = "0.3.8" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.3.1" + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + + [deps.SpecialFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.SymbolicIndexingInterface]] +git-tree-sha1 = "251bb311585143931a306175c3b7ced220300578" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.3.8" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.1" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TruncatedStacktraces]] +deps = ["InteractiveUtils", "MacroTools", "Preferences"] +git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" +uuid = "781d530d-4396-4725-bb49-402e4bee1e77" +version = "1.4.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index a73771c..f415c0f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,19 +1,19 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.5.0" +version = "1.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" +NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -21,12 +21,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" +SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt = ["ChainRulesCore", "DiffEqBase"] SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff" SimpleNonlinearSolveStaticArraysExt = "StaticArrays" SimpleNonlinearSolveZygoteExt = "Zygote" @@ -34,8 +35,8 @@ SimpleNonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2.6" AllocCheck = "0.1.1" -ArrayInterface = "7.7" Aqua = "0.8" +ArrayInterface = "7.7" CUDA = "5.2" ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" @@ -48,6 +49,7 @@ LinearAlgebra = "1.10" LinearSolve = "2.25" MaybeInplace = "0.1.1" NonlinearProblemLibrary = "0.1.2" +NonlinearSolveBase = "1" Pkg = "1.10" PolyesterForwardDiff = "0.1.1" PrecompileTools = "1.2" @@ -66,12 +68,12 @@ julia = "1.10" AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141" +NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -83,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"] +test = ["Aqua", "AllocCheck", "NonlinearSolveBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"] diff --git a/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl similarity index 94% rename from ext/SimpleNonlinearSolveChainRulesCoreExt.jl rename to ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl index 23cd25f..96a89f0 100644 --- a/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl @@ -1,4 +1,4 @@ -module SimpleNonlinearSolveChainRulesCoreExt +module SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 710f2d6..28d5af8 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -3,22 +3,23 @@ module SimpleNonlinearSolve import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations @recompile_invalidations begin - using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff, - ForwardDiff, Reexport, LinearAlgebra, SciMLBase - - import DiffEqBase: AbstractNonlinearTerminationMode, - AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, - NonlinearSafeTerminationReturnCode, get_termination_mode, - NONLINEARSOLVE_DEFAULT_NORM + using ADTypes, ArrayInterface, FiniteDiff, ForwardDiff, NonlinearSolveBase, Reexport, + LinearAlgebra, SciMLBase + + import ConcreteStructs: @concrete import DiffResults + import FastClosures: @closure import ForwardDiff: Dual import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex + import NonlinearSolveBase: AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode, + get_termination_mode, NONLINEARSOLVE_DEFAULT_NORM import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size end -@reexport using ADTypes, SciMLBase +@reexport using ADTypes, SciMLBase # TODO: Reexport NonlinearSolveBase after the situation with NonlinearSolve.jl is resolved abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @@ -58,23 +59,28 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; end # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms -function SciMLBase.solve( - prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, - args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) - if sensealg === nothing && haskey(prob.kwargs, :sensealg) - sensealg = prob.kwargs[:sensealg] - end - new_u0 = u0 !== nothing ? u0 : prob.u0 - new_p = p !== nothing ? p : prob.p - return __internal_solve_up( - prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing, - alg, args...; prob.kwargs..., kwargs...) -end +# Using eval to prevent ambiguity +for pType in (NonlinearProblem, NonlinearLeastSquaresProblem) + @eval begin + function SciMLBase.solve( + prob::$(pType), alg::AbstractSimpleNonlinearSolveAlgorithm, args...; + sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + new_u0 = u0 !== nothing ? u0 : prob.u0 + new_p = p !== nothing ? p : prob.p + return __internal_solve_up( + prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing, + alg, args...; prob.kwargs..., kwargs...) + end -function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p, - p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) - prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob - return SciMLBase.__solve(prob, alg, args...; kwargs...) + function __internal_solve_up(_prob::$(pType), sensealg, u0, u0_changed, p, + p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) + prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob + return SciMLBase.__solve(prob, alg, args...; kwargs...) + end + end end @setup_workload begin diff --git a/src/bracketing/bisection.jl b/src/bracketing/bisection.jl index acadf6a..ec7fb28 100644 --- a/src/bracketing/bisection.jl +++ b/src/bracketing/bisection.jl @@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args... left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, + abstol = NonlinearSolveBase.get_tolerance(nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) diff --git a/src/bracketing/brent.jl b/src/bracketing/brent.jl index 89b2e60..53c4f93 100644 --- a/src/bracketing/brent.jl +++ b/src/bracketing/brent.jl @@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...; fl, fr = f(left), f(right) ϵ = eps(convert(typeof(fl), 1)) - abstol = __get_tolerance(nothing, abstol, + abstol = NonlinearSolveBase.get_tolerance(nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) diff --git a/src/bracketing/falsi.jl b/src/bracketing/falsi.jl index 896e073..902ed92 100644 --- a/src/bracketing/falsi.jl +++ b/src/bracketing/falsi.jl @@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, + abstol = NonlinearSolveBase.get_tolerance(nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) diff --git a/src/bracketing/itp.jl b/src/bracketing/itp.jl index 3f2069b..940d65c 100644 --- a/src/bracketing/itp.jl +++ b/src/bracketing/itp.jl @@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, + abstol = NonlinearSolveBase.get_tolerance(nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) diff --git a/src/bracketing/ridder.jl b/src/bracketing/ridder.jl index 3b23f42..6d4e2d7 100644 --- a/src/bracketing/ridder.jl +++ b/src/bracketing/ridder.jl @@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; left, right = prob.tspan fl, fr = f(left), f(right) - abstol = __get_tolerance(nothing, abstol, + abstol = NonlinearSolveBase.get_tolerance(nothing, abstol, promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan)))) if iszero(fl) diff --git a/src/linesearch.jl b/src/linesearch.jl index c33253f..82866ee 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -73,7 +73,7 @@ function (cache::LiFukushimaLineSearchCache)(u, δu) fx_norm = ϕ(T(0)) # Non-Blocking exit if the norm is NaN or Inf - DiffEqBase.NAN_CHECK(fx_norm) && return cache.α + NonlinearSolveBase.NAN_CHECK(fx_norm) && return cache.α # Early Terminate based on Eq. 2.7 du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu) @@ -84,12 +84,12 @@ function (cache::LiFukushimaLineSearchCache)(u, δu) fxλp_norm = ϕ(λ₂) if cache.nan_maxiters !== nothing - if DiffEqBase.NAN_CHECK(fxλp_norm) + if NonlinearSolveBase.NAN_CHECK(fxλp_norm) nan_converged = false for _ in 1:(cache.nan_maxiters) λ₁, λ₂ = λ₂, cache.β * λ₂ fxλp_norm = ϕ(λ₂) - nan_converged = DiffEqBase.NAN_CHECK(fxλp_norm)::Bool + nan_converged = NonlinearSolveBase.NAN_CHECK(fxλp_norm)::Bool nan_converged && break end nan_converged || return cache.α diff --git a/src/nlsolve/lbroyden.jl b/src/nlsolve/lbroyden.jl index 145a546..b1ba03f 100644 --- a/src/nlsolve/lbroyden.jl +++ b/src/nlsolve/lbroyden.jl @@ -121,7 +121,7 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo U, Vᵀ = __init_low_rank_jacobian(vec(x), vec(fx), threshold) - abstol = __get_tolerance(x, abstol, eltype(x)) + abstol = NonlinearSolveBase.get_tolerance(x, abstol, eltype(x)) xo, δx, fo, δf = x, -fx, fx, fx diff --git a/src/utils.jl b/src/utils.jl index 2876ced..4e32f20 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -77,7 +77,7 @@ except `cache` (& `J` if not nothing) are mutated. function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, X} if isinplace(f) _f = (du, u) -> f(du, u, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) f.jac(J, x, p) _f(y, x) return y, J @@ -97,7 +97,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, end else _f = Base.Fix2(f, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return _f(x), f.jac(x, p) elseif ad isa AutoForwardDiff if ArrayInterface.can_setindex(x) @@ -124,7 +124,7 @@ end function __polyester_forwarddiff_jacobian! end function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F} - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return f(x, p), f.jac(x, p) elseif ad isa AutoForwardDiff T = typeof(__standard_tag(ad.tag, x)) @@ -152,7 +152,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray} if isinplace(f) _f = (du, u) -> f(du, u, p) J = similar(y, length(y), length(x)) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return J, nothing elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff return J, __get_jacobian_config(ad, _f, y, x) @@ -163,7 +163,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray} end else _f = Base.Fix2(f, p) - if DiffEqBase.has_jac(f) + if SciMLBase.has_jac(f) return nothing, nothing elseif ad isa AutoForwardDiff J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing @@ -292,58 +292,27 @@ function init_termination_cache(abstol, reltol, du, u, ::Nothing) return init_termination_cache(abstol, reltol, du, u, AbsNormTerminationMode()) end function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) - T = promote_type(eltype(du), eltype(u)) - abstol = __get_tolerance(u, abstol, T) - reltol = __get_tolerance(u, reltol, T) tc_cache = init(du, u, tc; abstol, reltol) - return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache + return (NonlinearSolveBase.get_abstol(tc_cache), + NonlinearSolveBase.get_reltol(tc_cache), tc_cache) end function check_termination(tc_cache, fx, x, xo, prob, alg) return check_termination(tc_cache, fx, x, xo, prob, alg, - DiffEqBase.get_termination_mode(tc_cache)) -end -function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractNonlinearTerminationMode) - if Bool(tc_cache(fx, x, xo)) - return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - end - return nothing -end -function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractSafeNonlinearTerminationMode) - if Bool(tc_cache(fx, x, xo)) - if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success - retcode = ReturnCode.Success - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination - retcode = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination - retcode = ReturnCode.Unstable - else - error("Unknown termination code: $(tc_cache.retcode)") - end - return build_solution(prob, alg, x, fx; retcode) - end - return nothing + NonlinearSolveBase.get_termination_mode(tc_cache)) end + function check_termination(tc_cache, fx, x, xo, prob, alg, - ::AbstractSafeBestNonlinearTerminationMode) + mode::AbstractNonlinearTerminationMode) if Bool(tc_cache(fx, x, xo)) - if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success - retcode = ReturnCode.Success - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination - retcode = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination - retcode = ReturnCode.Unstable - else - error("Unknown termination code: $(tc_cache.retcode)") - end - if isinplace(prob) - prob.f(fx, x, prob.p) - else - fx = prob.f(x, prob.p) + if mode isa AbstractSafeBestNonlinearTerminationMode + if isinplace(prob) + prob.f(fx, x, prob.p) + else + fx = prob.f(x, prob.p) + end end - return build_solution(prob, alg, tc_cache.u, fx; retcode) + return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode) end return nothing end @@ -382,12 +351,5 @@ end @inline __reshape(x::Number, args...) = x @inline __reshape(x::AbstractArray, args...) = reshape(x, args...) -# Override cases which might be used in a kernel launch -__get_tolerance(x, η, ::Type{T}) where {T} = DiffEqBase._get_tolerance(η, T) -function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {T} - η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8)) - return T(η) -end - # Extension function __zygote_compute_nlls_vjp end diff --git a/test/core/23_test_problems_tests.jl b/test/core/23_test_problems_tests.jl index 8b8f239..ad530c2 100644 --- a/test/core/23_test_problems_tests.jl +++ b/test/core/23_test_problems_tests.jl @@ -1,5 +1,5 @@ @testsetup module RobustnessTesting -using LinearAlgebra, NonlinearProblemLibrary, DiffEqBase, Test +using LinearAlgebra, NonlinearProblemLibrary, NonlinearSolveBase, SciMLBase, Test problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 726a6dd..848bf6c 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -1,7 +1,8 @@ @testsetup module RootfindingTesting using Reexport @reexport using AllocCheck, - LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase + LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, + NonlinearSolveBase import PolyesterForwardDiff quadratic_f(u, p) = u .* u .- p @@ -89,7 +90,7 @@ end end end -@testitem "Derivative Free Metods" setup=[RootfindingTesting] begin +@testitem "Derivative Free Methods" setup=[RootfindingTesting] begin @testset "$(nameof(typeof(alg)))" for alg in [SimpleBroyden(), SimpleKlement(), SimpleDFSane(), SimpleLimitedMemoryBroyden(), SimpleBroyden(; linesearch = Val(true)), From 7649c804c05f6344a304b70fd3a3fb14f800674b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Feb 2024 12:02:52 -0500 Subject: [PATCH 2/3] Move ForwardDiff and FiniteDiff into extensions --- Manifest.toml | 207 +--------------- Project.toml | 10 +- ...nlinearSolveChainRulesCoreDiffEqBaseExt.jl | 11 +- ...NonlinearSolveDiffResultsForwardDiffExt.jl | 26 ++ .../forward_ad.jl | 62 +++-- .../hessian.jl | 22 ++ .../jacobian.jl | 95 ++++++++ ext/SimpleNonlinearSolveFiniteDiffExt.jl | 59 +++++ ...leNonlinearSolvePolyesterForwardDiffExt.jl | 22 +- ext/SimpleNonlinearSolveStaticArraysExt.jl | 2 +- src/SimpleNonlinearSolve.jl | 30 ++- src/utils.jl | 230 ++++-------------- test/core/23_test_problems_tests.jl | 1 + test/core/least_squares_tests.jl | 2 +- test/core/matrix_resizing_tests.jl | 2 + test/core/rootfind_tests.jl | 5 +- test/gpu/cuda_tests.jl | 4 +- 17 files changed, 346 insertions(+), 444 deletions(-) create mode 100644 ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl rename src/ad.jl => ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/forward_ad.jl (76%) create mode 100644 ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/hessian.jl create mode 100644 ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl create mode 100644 ext/SimpleNonlinearSolveFiniteDiffExt.jl diff --git a/Manifest.toml b/Manifest.toml index 0fa996c..7b15964 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.1" manifest_format = "2.0" -project_hash = "fd7dd831d23cbd09082d4fbd447d505786d32a50" +project_hash = "0f8351b22e508389e9e69531474eb9c04e478a60" [[deps.ADTypes]] git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245" @@ -21,10 +21,6 @@ version = "4.0.1" [deps.Adapt.weakdeps] StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] git-tree-sha1 = "c5aeb516a84459e0318a02507d2261edad97eb75" @@ -58,12 +54,6 @@ git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" version = "0.2.4" -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" @@ -102,18 +92,6 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -124,11 +102,6 @@ git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - [[deps.EnumX]] git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" @@ -144,53 +117,6 @@ git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" version = "0.3.2" -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.9.3" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] -git-tree-sha1 = "73d1214fec245096717847c62d389a5d2ac86504" -uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.22.0" - - [deps.FiniteDiff.extensions] - FiniteDiffBandedMatricesExt = "BandedMatrices" - FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" - FiniteDiffStaticArraysExt = "StaticArrays" - - [deps.FiniteDiff.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - - [deps.ForwardDiff.weakdeps] - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - [[deps.FunctionWrappers]] git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -212,36 +138,11 @@ git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -263,22 +164,6 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.27" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -303,58 +188,36 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+1" -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" [[deps.NonlinearSolveBase]] deps = ["ArrayInterface", "ConcreteStructs", "FastClosures", "LinearAlgebra", "Markdown", "PrecompileTools", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"] -git-tree-sha1 = "44c1ccea6b6cc6126fb83b83424ae6cbf61cc63f" +git-tree-sha1 = "4b21852907478b4039cf6bed35e070b3428279b5" repo-rev = "master" repo-url = "https://github.com/SciML/NonlinearSolveBase.jl" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" version = "1.0.0" -weakdeps = ["SparseArrays"] [deps.NonlinearSolveBase.extensions] + NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" + [deps.NonlinearSolveBase.weakdeps] + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" version = "0.3.23+4" -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.3" -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" @@ -371,10 +234,6 @@ version = "1.4.1" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -387,9 +246,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "1bbc4bb050165cc57ca2876cd53cc23395948650" +git-tree-sha1 = "dc428bb59c20dafd1ec500c3432b9e3d7e78e7f3" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.10.0" +version = "3.10.1" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -431,10 +290,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" [[deps.SciMLBase]] -deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"] -git-tree-sha1 = "375256db2d99fc730d2d134cca17939324d284d1" +deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "3a281a9fce9cd62b849d7f16e412933a5fe755cb" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.28.0" +version = "2.29.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -478,18 +337,6 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" version = "1.10.0" -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - - [deps.SpecialFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -531,21 +378,6 @@ git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.11.1" -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TruncatedStacktraces]] -deps = ["InteractiveUtils", "MacroTools", "Preferences"] -git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" -uuid = "781d530d-4396-4725-bb49-402e4bee1e77" -version = "1.4.0" - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -553,22 +385,7 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index f415c0f..710cce3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,13 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.6.0" +version = "1.6.0" # Bump to 2.0.0 before releasing [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" @@ -22,11 +19,16 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +SimpleNonlinearSolveFiniteDiffExt = "FiniteDiff" +SimpleNonlinearSolveDiffResultsForwardDiffExt = ["DiffResults", "ForwardDiff"] SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt = ["ChainRulesCore", "DiffEqBase"] SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff" SimpleNonlinearSolveStaticArraysExt = "StaticArrays" diff --git a/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl b/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl index 96a89f0..434c839 100644 --- a/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl +++ b/ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl @@ -1,18 +1,21 @@ module SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt -using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve +using SciMLBase +import DiffEqBase, SimpleNonlinearSolve +import ChainRulesCore as CRC # The expectation here is that no-one is using this directly inside a GPU kernel. We can # eventually lift this requirement using a custom adjoint -function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up), - prob::NonlinearProblem, +function CRC.rrule( + ::typeof(SimpleNonlinearSolve.__internal_solve_up), prob::NonlinearProblem, sensealg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...) function ∇__internal_solve_up(Δ) ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ) - return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), ∂originator, + return ( + ∂f, ∂prob, ∂sensealg, ∂u0, CRC.NoTangent(), ∂p, CRC.NoTangent(), ∂originator, ∂args...) end return out, ∇__internal_solve_up diff --git a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl new file mode 100644 index 0000000..9a2d904 --- /dev/null +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl @@ -0,0 +1,26 @@ +module SimpleNonlinearSolveDiffResultsForwardDiffExt + +import ADTypes: AutoForwardDiff, AutoPolyesterForwardDiff +import ArrayInterface, SciMLBase, SimpleNonlinearSolve, DiffResults, ForwardDiff +import FastClosures: @closure +import LinearAlgebra: mul! +import SciMLBase: IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, + solve +import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm, __nlsolve_ad, + __nlsolve_dual_soln, __nlsolve_∂f_∂p, __nlsolve_∂f_∂u, + Bisection, Brent, Alefeld, Falsi, ITP, Ridder +import StaticArraysCore: StaticArray, SArray, Size + +@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:ForwardDiff}) = true + +@inline SimpleNonlinearSolve.__can_dual(x) = ForwardDiff.can_dual(x) + +@inline SimpleNonlinearSolve.value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +@inline SimpleNonlinearSolve.value(x::AbstractArray{<:ForwardDiff.Dual}) = map( + ForwardDiff.value, x) + +include("jacobian.jl") +include("hessian.jl") +include("forward_ad.jl") + +end diff --git a/src/ad.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/forward_ad.jl similarity index 76% rename from src/ad.jl rename to ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/forward_ad.jl index d4e091c..13d33b3 100644 --- a/src/ad.jl +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/forward_ad.jl @@ -1,6 +1,7 @@ -function SciMLBase.solve( +function solve( prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, - iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + iip, <:Union{ + <:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) @@ -8,9 +9,9 @@ function SciMLBase.solve( prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) end -function SciMLBase.solve( +function solve( prob::NonlinearLeastSquaresProblem{<:AbstractArray, - iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}}, + iip, <:Union{<:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) @@ -22,25 +23,27 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder) @eval begin function SciMLBase.solve( prob::IntervalNonlinearProblem{uType, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + <:Union{<:ForwardDiff.Dual{T, V, P}, + <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip} sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, - sol.stats, sol.original, left = Dual{T, V, P}(sol.left, partials), - right = Dual{T, V, P}(sol.right, partials)) + sol.stats, sol.original, left = ForwardDiff.Dual{T, V, P}( + sol.left, partials), + right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) end end end function __nlsolve_ad( prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...) - p = value(prob.p) + p = SimpleNonlinearSolve.value(prob.p) if prob isa IntervalNonlinearProblem - tspan = value.(prob.tspan) + tspan = SimpleNonlinearSolve.value.(prob.tspan) newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...) else - u0 = value(prob.u0) + u0 = SimpleNonlinearSolve.value(prob.u0) newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...) end @@ -66,8 +69,8 @@ function __nlsolve_ad( end function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...) - p = value(prob.p) - u0 = value(prob.u0) + p = SimpleNonlinearSolve.value(prob.p) + u0 = SimpleNonlinearSolve.value(prob.u0) newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...) sol = solve(newprob, alg, args...; kwargs...) @@ -77,7 +80,7 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use # nested autodiff as the last resort if SciMLBase.has_vjp(prob.f) - if isinplace(prob) + if SciMLBase.isinplace(prob) _F = @closure (du, u, p) -> begin resid = similar(du, length(sol.resid)) prob.f(resid, u, p) @@ -92,7 +95,7 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end end elseif SciMLBase.has_jac(prob.f) - if isinplace(prob) + if SciMLBase.isinplace(prob) _F = @closure (du, u, p) -> begin J = similar(du, length(sol.resid), length(u)) prob.f.jac(J, u, p) @@ -107,7 +110,7 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end end else - if isinplace(prob) + if SciMLBase.isinplace(prob) _F = @closure (du, u, p) -> begin resid = similar(du, length(sol.resid)) res = DiffResults.DiffResult( @@ -120,8 +123,10 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end else # For small problems, nesting ForwardDiff is actually quite fast - if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50) - _F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p) + if SimpleNonlinearSolve.__is_extension_loaded(Val(:Zygote)) && + (length(uu) + length(sol.resid) ≥ 50) + _F = @closure (u, p) -> SimpleNonlinearSolve.__zygote_compute_nlls_vjp( + prob.f, u, p) else _F = @closure (u, p) -> begin T = promote_type(eltype(u), eltype(p)) @@ -156,7 +161,7 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end @inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F} - if isinplace(prob) + if SciMLBase.isinplace(prob) __f = p -> begin du = similar(u, promote_type(eltype(u), eltype(p))) f(du, u, p) @@ -166,16 +171,16 @@ end __f = Base.Fix1(f, u) end if p isa Number - return __reshape(ForwardDiff.derivative(__f, p), :, 1) + return SimpleNonlinearSolve.__reshape(ForwardDiff.derivative(__f, p), :, 1) elseif u isa Number - return __reshape(ForwardDiff.gradient(__f, p), 1, :) + return SimpleNonlinearSolve.__reshape(ForwardDiff.gradient(__f, p), 1, :) else return ForwardDiff.jacobian(__f, p) end end @inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F} - if isinplace(prob) + if SciMLBase.isinplace(prob) du = similar(u) __f = (du, u) -> f(du, u, p) ForwardDiff.jacobian(__f, du, u) @@ -190,12 +195,15 @@ end end @inline function __nlsolve_dual_soln(u::Number, partials, - ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} - return Dual{T, V, P}(u, partials) + ::Union{<:AbstractArray{<:ForwardDiff.Dual{T, V, P}}, ForwardDiff.Dual{T, V, P}}) where { + T, V, P} + return ForwardDiff.Dual{T, V, P}(u, partials) end -@inline function __nlsolve_dual_soln(u::AbstractArray, partials, - ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} - _partials = _restructure(u, partials) - return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) +@inline function __nlsolve_dual_soln(u::AbstractArray, + partials, + ::Union{<:AbstractArray{<:ForwardDiff.Dual{T, V, P}}, ForwardDiff.Dual{T, V, P}}) where { + T, V, P} + _partials = SimpleNonlinearSolve._restructure(u, partials) + return map(((uᵢ, pᵢ),) -> ForwardDiff.Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) end diff --git a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/hessian.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/hessian.jl new file mode 100644 index 0000000..6a71b02 --- /dev/null +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/hessian.jl @@ -0,0 +1,22 @@ +function SimpleNonlinearSolve.compute_jacobian_and_hessian( + ad::AutoForwardDiff, prob, _, x::Number) + fx = prob.f(x, prob.p) + J_fn = Base.Fix1(ForwardDiff.derivative, Base.Fix2(prob.f, prob.p)) + dfx = J_fn(x) + d2fx = ForwardDiff.derivative(J_fn, x) + return fx, dfx, d2fx +end + +function SimpleNonlinearSolve.compute_jacobian_and_hessian( + ad::AutoForwardDiff, prob, fx, x) + if SciMLBase.isinplace(prob) + error("Inplace version for Nested ForwardDiff Not Implemented Yet!") + else + f = Base.Fix2(prob.f, prob.p) + fx = f(x) + J_fn = Base.Fix1(ForwardDiff.jacobian, f) + dfx = J_fn(x) + d2fx = ForwardDiff.jacobian(J_fn, x) + return fx, dfx, d2fx + end +end diff --git a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl new file mode 100644 index 0000000..286076f --- /dev/null +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl @@ -0,0 +1,95 @@ + +struct SimpleNonlinearSolveTag end + +function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SimpleNonlinearSolveTag, <:T}}, + f::F, x::AbstractArray{T}) where {T, F} + return true +end + +@inline __standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x)) +@inline __standard_tag(tag::ForwardDiff.Tag, _) = tag +@inline __standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x)) + +function __pick_forwarddiff_chunk( + ad::Union{AutoForwardDiff{CS}, AutoPolyesterForwardDiff{CS}}, x) where {CS} + (CS === nothing || CS ≤ 0) && return __pick_forwarddiff_chunk(x) + return ForwardDiff.Chunk{CS}() +end +__pick_forwarddiff_chunk(x) = ForwardDiff.Chunk(length(x)) +function __pick_forwarddiff_chunk(x::StaticArray) + L = prod(Size(x)) + if L ≤ ForwardDiff.DEFAULT_CHUNK_THRESHOLD + return ForwardDiff.Chunk{L}() + else + return ForwardDiff.Chunk{ForwardDiff.DEFAULT_CHUNK_THRESHOLD}() + end +end + +# Jacobian +function __forwarddiff_jacobian_config(f::F, x, ck::ForwardDiff.Chunk, tag) where {F} + return ForwardDiff.JacobianConfig(f, x, ck, tag) +end +function __forwarddiff_jacobian_config( + f::F, x::SArray, ck::ForwardDiff.Chunk{N}, tag) where {F, N} + seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, eltype(x)}) + duals = ForwardDiff.Dual{typeof(tag), eltype(x), N}.(x) + return ForwardDiff.JacobianConfig{typeof(tag), eltype(x), N, typeof(duals)}(seeds, + duals) +end + +function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS} + return __forwarddiff_jacobian_config( + f, x, __pick_forwarddiff_chunk(ad, x), __standard_tag(ad.tag, x)) +end + +function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS} + return ForwardDiff.JacobianConfig( + f!, y, x, __pick_forwarddiff_chunk(ad, x), __standard_tag(ad.tag, x)) +end + +function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS} + return __pick_forwarddiff_chunk(ad, last(args)) +end + +function SimpleNonlinearSolve.__jacobian_cache( + ::Val{iip}, ad::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, y, + x) where {iip, F} + if iip + J = similar(y, promote_type(eltype(x), eltype(y)), length(y), length(x)) + return J, __get_jacobian_config(ad, f, y, x) + end + if ad isa AutoPolyesterForwardDiff + @assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable \ + inputs. Use AutoForwardDiff instead." + end + J = ArrayInterface.can_setindex(x) ? + similar(y, promote_type(eltype(x), eltype(y)), length(y), length(x)) : nothing + return J, __get_jacobian_config(ad, f, x) +end + +function SimpleNonlinearSolve.__value_and_jacobian!( + ::Val{iip}, ad::AutoForwardDiff, J, f::F, y, x::AbstractArray, cache) where {iip, F} + if iip + res = DiffResults.DiffResult(y, J) + ForwardDiff.jacobian!(res, f, y, x, cache) + return DiffResults.value(res), DiffResults.jacobian(res) + end + if ArrayInterface.can_setindex(x) + res = DiffResults.DiffResult(y, J) + ForwardDiff.jacobian!(res, f, x, cache) + return DiffResults.value(res), DiffResults.jacobian(res) + end + return f(x), ForwardDiff.jacobian(f, x, cache) +end + +function SimpleNonlinearSolve.__value_and_jacobian!( + ::Val, ad::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, + J, f::F, y, x::Number, cache) where {F} + if hasfield(typeof(ad), :tag) + T = typeof(__standard_tag(ad.tag, x)) + else + T = typeof(__standard_tag(nothing, x)) + end + out = f(ForwardDiff.Dual{T}(x, one(x))) + return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) +end \ No newline at end of file diff --git a/ext/SimpleNonlinearSolveFiniteDiffExt.jl b/ext/SimpleNonlinearSolveFiniteDiffExt.jl new file mode 100644 index 0000000..3cefa38 --- /dev/null +++ b/ext/SimpleNonlinearSolveFiniteDiffExt.jl @@ -0,0 +1,59 @@ +module SimpleNonlinearSolveFiniteDiffExt + +import ADTypes: AutoFiniteDiff +import SciMLBase, SimpleNonlinearSolve, FiniteDiff +import StaticArraysCore: SArray + +@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:FiniteDiff}) = true + +# Jacobian +function SimpleNonlinearSolve.__jacobian_cache(::Val{iip}, ad::AutoFiniteDiff, f::F, y, + x) where {iip, F} + cache = FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype) + J = iip ? similar(y, promote_type(eltype(x), eltype(y)), length(y), length(x)) : + nothing + return J, cache +end +function SimpleNonlinearSolve.__jacobian_cache(::Val, ad::AutoFiniteDiff, f::F, y, + x::SArray) where {F} + return nothing, nothing +end + +function SimpleNonlinearSolve.__value_and_jacobian!( + ::Val{iip}, ad::AutoFiniteDiff, J, f::F, y, x, cache) where {iip, F} + x isa Number && return (f(x), FiniteDiff.finite_difference_derivative(f, x, ad.fdtype)) + if iip + FiniteDiff.finite_difference_jacobian!(J, f, x, cache) + f(y, x) + return y, J + end + cache === nothing && return f(x), FiniteDiff.finite_difference_jacobian(f, x) + return f(x), FiniteDiff.finite_difference_jacobian(f, x, cache) +end + +# Hessian +function SimpleNonlinearSolve.compute_jacobian_and_hessian( + ad::AutoFiniteDiff, prob, _, x::Number) + fx = prob.f(x, prob.p) + J_fn = x -> FiniteDiff.finite_difference_derivative(Base.Fix2(prob.f, prob.p), x, + ad.fdtype) + dfx = J_fn(x) + d2fx = FiniteDiff.finite_difference_derivative(J_fn, x, ad.fdtype) + return fx, dfx, d2fx +end + +function SimpleNonlinearSolve.compute_jacobian_and_hessian( + ad::AutoFiniteDiff, prob, fx, x) + if SciMLBase.isinplace(prob) + error("Inplace version for Nested FiniteDiff Not Implemented Yet!") + else + f = Base.Fix2(prob.f, prob.p) + fx = f(x) + J_fn = x -> FiniteDiff.finite_difference_jacobian(f, x, ad.fdtype) + dfx = J_fn(x) + d2fx = FiniteDiff.finite_difference_jacobian(J_fn, x, ad.fdtype) + return fx, dfx, d2fx + end +end + +end diff --git a/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl b/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl index 81cee48..8b31f0c 100644 --- a/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl +++ b/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl @@ -1,19 +1,19 @@ module SimpleNonlinearSolvePolyesterForwardDiffExt -using SimpleNonlinearSolve, PolyesterForwardDiff +import ADTypes: AutoPolyesterForwardDiff +import SimpleNonlinearSolve, PolyesterForwardDiff @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true -@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x, - chunksize) where {F} - PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize) - return J -end - -@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x, - chunksize) where {F} - PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize) - return J +function SimpleNonlinearSolve.__value_and_jacobian!( + ::Val{iip}, ad::AutoPolyesterForwardDiff, J, f::F, y, + x::AbstractArray, cache) where {iip, F} + if iip + PolyesterForwardDiff.threaded_jacobian!(f, y, J, x, cache) + f(y, x) + return y, J + end + return f(x), PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache) end end diff --git a/ext/SimpleNonlinearSolveStaticArraysExt.jl b/ext/SimpleNonlinearSolveStaticArraysExt.jl index 90318a8..a913efc 100644 --- a/ext/SimpleNonlinearSolveStaticArraysExt.jl +++ b/ext/SimpleNonlinearSolveStaticArraysExt.jl @@ -1,6 +1,6 @@ module SimpleNonlinearSolveStaticArraysExt -using SimpleNonlinearSolve +import SimpleNonlinearSolve @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 28d5af8..0c6b59d 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -3,13 +3,10 @@ module SimpleNonlinearSolve import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations @recompile_invalidations begin - using ADTypes, ArrayInterface, FiniteDiff, ForwardDiff, NonlinearSolveBase, Reexport, - LinearAlgebra, SciMLBase + using ADTypes, ArrayInterface, NonlinearSolveBase, Reexport, LinearAlgebra, SciMLBase import ConcreteStructs: @concrete - import DiffResults import FastClosures: @closure - import ForwardDiff: Dual import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex import NonlinearSolveBase: AbstractNonlinearTerminationMode, AbstractSafeNonlinearTerminationMode, @@ -47,8 +44,12 @@ include("bracketing/brent.jl") include("bracketing/alefeld.jl") include("bracketing/itp.jl") -# AD -include("ad.jl") +# AD: Defined in Extension +## DONT REMOVE THESE: They are used in NonlinearSolve.jl +function __nlsolve_ad end +function __nlsolve_∂f_∂p end +function __nlsolve_∂f_∂u end +function __nlsolve_dual_soln end ## Default algorithm @@ -85,16 +86,19 @@ end @setup_workload begin for T in (Float32, Float64) + # FIXME Move this precompilation into the forwarddiff & finitediff extensions prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, T.([1.0, 1.0, 1.0]), T(2)) prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, T.([1.0, 1.0, 1.0]), T(2)) - algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(), - SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2)] + algs = [SimpleBroyden(), SimpleKlement(), SimpleDFSane(), + SimpleLimitedMemoryBroyden(; threshold = 2)] - algs_no_iip = [SimpleHalley()] + # algs = [SimpleNewtonRaphson(), SimpleTrustRegion()] + + # algs_no_iip = [SimpleHalley()] @compile_workload begin for alg in algs @@ -103,10 +107,10 @@ end solve(prob_no_brack_oop, alg, abstol = T(1e-2)) end - for alg in algs_no_iip - solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) - solve(prob_no_brack_oop, alg, abstol = T(1e-2)) - end + # for alg in algs_no_iip + # solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) + # solve(prob_no_brack_oop, alg, abstol = T(1e-2)) + # end end prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, diff --git a/src/utils.jl b/src/utils.jl index 4e32f20..fbe3e95 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,10 +1,3 @@ -struct SimpleNonlinearSolveTag end - -function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SimpleNonlinearSolveTag, <:T}}, - f::F, x::AbstractArray{T}) where {T, F} - return true -end - """ __prevfloat_tdir(x, x0, x1) @@ -26,47 +19,7 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum. """ __max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b)) -__standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x)) -__standard_tag(tag::ForwardDiff.Tag, _) = tag -__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x)) - -__pick_forwarddiff_chunk(x) = ForwardDiff.Chunk(length(x)) -function __pick_forwarddiff_chunk(x::StaticArray) - L = prod(Size(x)) - if L ≤ ForwardDiff.DEFAULT_CHUNK_THRESHOLD - return ForwardDiff.Chunk{L}() - else - return ForwardDiff.Chunk{ForwardDiff.DEFAULT_CHUNK_THRESHOLD}() - end -end - -function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS} - ck = (CS === nothing || CS ≤ 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}() - tag = __standard_tag(ad.tag, x) - return __forwarddiff_jacobian_config(f, x, ck, tag) -end -function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS} - ck = (CS === nothing || CS ≤ 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}() - tag = __standard_tag(ad.tag, x) - return ForwardDiff.JacobianConfig(f!, y, x, ck, tag) -end - -function __forwarddiff_jacobian_config(f::F, x, ck::ForwardDiff.Chunk, tag) where {F} - return ForwardDiff.JacobianConfig(f, x, ck, tag) -end -function __forwarddiff_jacobian_config( - f::F, x::SArray, ck::ForwardDiff.Chunk{N}, tag) where {F, N} - seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, eltype(x)}) - duals = ForwardDiff.Dual{typeof(tag), eltype(x), N}.(x) - return ForwardDiff.JacobianConfig{typeof(tag), eltype(x), N, typeof(duals)}(seeds, - duals) -end - -function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS} - x = last(args) - return (CS === nothing || CS ≤ 0) ? __pick_forwarddiff_chunk(x) : - ForwardDiff.Chunk{CS}() -end +function __value_and_jacobian! end """ value_and_jacobian(ad, f, y, x, p, cache; J = nothing) @@ -74,72 +27,17 @@ end Compute `f(x), d/dx f(x)` in the most efficient way based on `ad`. None of the arguments except `cache` (& `J` if not nothing) are mutated. """ -function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, X} +function value_and_jacobian(ad, f::F, y, x, p, cache; J = nothing) where {F} if isinplace(f) - _f = (du, u) -> f(du, u, p) if SciMLBase.has_jac(f) f.jac(J, x, p) - _f(y, x) - return y, J - elseif ad isa AutoForwardDiff - res = DiffResults.DiffResult(y, J) - ForwardDiff.jacobian!(res, _f, y, x, cache) - return DiffResults.value(res), DiffResults.jacobian(res) - elseif ad isa AutoFiniteDiff - FiniteDiff.finite_difference_jacobian!(J, _f, x, cache) - _f(y, x) - return y, J - elseif ad isa AutoPolyesterForwardDiff - __polyester_forwarddiff_jacobian!(_f, y, J, x, cache) + f(y, x, p) return y, J - else - throw(ArgumentError("Unsupported AD method: $(ad)")) end + __value_and_jacobian!(Val(true), ad, J, @closure((du, u)->f(du, u, p)), y, x, cache) else - _f = Base.Fix2(f, p) - if SciMLBase.has_jac(f) - return _f(x), f.jac(x, p) - elseif ad isa AutoForwardDiff - if ArrayInterface.can_setindex(x) - res = DiffResults.DiffResult(y, J) - ForwardDiff.jacobian!(res, _f, x, cache) - return DiffResults.value(res), DiffResults.jacobian(res) - else - J_fd = ForwardDiff.jacobian(_f, x, cache) - return _f(x), J_fd - end - elseif ad isa AutoFiniteDiff - J_fd = FiniteDiff.finite_difference_jacobian(_f, x, cache) - return _f(x), J_fd - elseif ad isa AutoPolyesterForwardDiff - __polyester_forwarddiff_jacobian!(_f, J, x, cache) - return _f(x), J - else - throw(ArgumentError("Unsupported AD method: $(ad)")) - end - end -end - -# Declare functions -function __polyester_forwarddiff_jacobian! end - -function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F} - if SciMLBase.has_jac(f) - return f(x, p), f.jac(x, p) - elseif ad isa AutoForwardDiff - T = typeof(__standard_tag(ad.tag, x)) - out = f(ForwardDiff.Dual{T}(x, one(x)), p) - return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) - elseif ad isa AutoPolyesterForwardDiff - # Just use ForwardDiff - T = typeof(__standard_tag(nothing, x)) - out = f(ForwardDiff.Dual{T}(x, one(x)), p) - return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) - elseif ad isa AutoFiniteDiff - _f = Base.Fix2(f, p) - return _f(x), FiniteDiff.finite_difference_derivative(_f, x, ad.fdtype) - else - throw(ArgumentError("Unsupported AD method: $(ad)")) + SciMLBase.has_jac(f) && return f(x, p), f.jac(x, p) + __value_and_jacobian!(Val(false), ad, J, Base.Fix2(f, p), y, x, cache) end end @@ -150,79 +48,24 @@ Returns a Jacobian Matrix and a cache for the Jacobian computation. """ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray} if isinplace(f) - _f = (du, u) -> f(du, u, p) - J = similar(y, length(y), length(x)) if SciMLBase.has_jac(f) - return J, nothing - elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff - return J, __get_jacobian_config(ad, _f, y, x) - elseif ad isa AutoFiniteDiff - return J, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype) - else - throw(ArgumentError("Unsupported AD method: $(ad)")) + return (similar(y, promote_type(eltype(x), eltype(y)), length(y), length(x)), + nothing) end + return __jacobian_cache(Val(true), ad, @closure((du, u)->f(du, u, p)), y, x) else - _f = Base.Fix2(f, p) - if SciMLBase.has_jac(f) - return nothing, nothing - elseif ad isa AutoForwardDiff - J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing - return J, __get_jacobian_config(ad, _f, x) - elseif ad isa AutoPolyesterForwardDiff - @assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs. Use AutoForwardDiff instead." - J = similar(y, length(y), length(x)) - return J, __get_jacobian_config(ad, _f, x) - elseif ad isa AutoFiniteDiff - return nothing, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype) - else - throw(ArgumentError("Unsupported AD method: $(ad)")) - end + SciMLBase.has_jac(f) && return nothing, nothing + return __jacobian_cache(Val(false), ad, Base.Fix2(f, p), y, x) end end jacobian_cache(ad, f::F, y, x::Number, p) where {F} = nothing, nothing -function compute_jacobian_and_hessian(ad::AutoForwardDiff, prob, _, x::Number) - fx = prob.f(x, prob.p) - J_fn = Base.Fix1(ForwardDiff.derivative, Base.Fix2(prob.f, prob.p)) - dfx = J_fn(x) - d2fx = ForwardDiff.derivative(J_fn, x) - return fx, dfx, d2fx -end - -function compute_jacobian_and_hessian(ad::AutoForwardDiff, prob, fx, x) - if isinplace(prob) - error("Inplace version for Nested ForwardDiff Not Implemented Yet!") - else - f = Base.Fix2(prob.f, prob.p) - fx = f(x) - J_fn = Base.Fix1(ForwardDiff.jacobian, f) - dfx = J_fn(x) - d2fx = ForwardDiff.jacobian(J_fn, x) - return fx, dfx, d2fx - end -end - -function compute_jacobian_and_hessian(ad::AutoFiniteDiff, prob, _, x::Number) - fx = prob.f(x, prob.p) - J_fn = x -> FiniteDiff.finite_difference_derivative(Base.Fix2(prob.f, prob.p), x, - ad.fdtype) - dfx = J_fn(x) - d2fx = FiniteDiff.finite_difference_derivative(J_fn, x, ad.fdtype) - return fx, dfx, d2fx -end +__jacobian_cache(::Val, ad, f::F, y, x) where {F} = __test_loaded_backend(ad, eltype(x)) -function compute_jacobian_and_hessian(ad::AutoFiniteDiff, prob, fx, x) - if isinplace(prob) - error("Inplace version for Nested FiniteDiff Not Implemented Yet!") - else - f = Base.Fix2(prob.f, prob.p) - fx = f(x) - J_fn = x -> FiniteDiff.finite_difference_jacobian(f, x, ad.fdtype) - dfx = J_fn(x) - d2fx = FiniteDiff.finite_difference_jacobian(J_fn, x, ad.fdtype) - return fx, dfx, d2fx - end +function compute_jacobian_and_hessian(ad, prob, fx, x) + __test_loaded_backend(ad, x) + error("`compute_jacobian_and_hessian` not implemented for $(ad).") end __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α) @@ -318,8 +161,6 @@ function check_termination(tc_cache, fx, x, xo, prob, alg, end @inline value(x) = x -@inline value(x::Dual) = ForwardDiff.value(x) -@inline value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) @inline __eval_f(prob, fx, x) = isinplace(prob) ? (prob.f(fx, x, prob.p); fx) : prob.f(x, prob.p) @@ -333,21 +174,44 @@ end end # Decide which AD backend to use -@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad +@inline function __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) + return __test_loaded_backend(ad, prob.u0) +end @inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true), kwargs...) where {P} - if ForwardDiff.can_dual(eltype(prob.u0)) - if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) && - !(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0) - return AutoPolyesterForwardDiff() - else - return AutoForwardDiff() - end - else + if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) && + __can_dual(eltype(prob.u0)) && !(prob.u0 isa Number) && + ArrayInterface.can_setindex(prob.u0) + return AutoPolyesterForwardDiff() + elseif __is_extension_loaded(Val(:ForwardDiff)) && __can_dual(eltype(prob.u0)) + return AutoForwardDiff() + elseif __is_extension_loaded(Val(:FiniteDiff)) return AutoFiniteDiff() + else + error("No AD Package is Loaded: Please install and load `PolyesterForwardDiff.jl`, \ + `ForwardDiff.jl`, or `FiniteDiff.jl`.") end end +for backend in (:PolyesterForwardDiff, :ForwardDiff, :FiniteDiff, :Zygote) + adtype = Symbol(:Auto, backend) + msg1 = "ADType: `$(adtype)` is not compatible with !(ForwardDiff.can_dual(eltype(x)))." + msg2 = "ADType: `$(adtype)` requires the `$(backend).jl` package to be loaded." + @eval begin + function __test_loaded_backend(ad::$(adtype), x) + if __is_extension_loaded($(Val(backend))) + __compatible_ad_with_eltype(ad, x) && return ad + error($(msg1)) + end + error($(msg2)) + end + end +end + +function __can_dual end +@inline __compatible_ad_with_eltype(::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, x) = __can_dual(eltype(x)) +@inline __compatible_ad_with_eltype(::ADTypes.AbstractADType, x) = true + @inline __reshape(x::Number, args...) = x @inline __reshape(x::AbstractArray, args...) = reshape(x, args...) diff --git a/test/core/23_test_problems_tests.jl b/test/core/23_test_problems_tests.jl index ad530c2..9d6d5b8 100644 --- a/test/core/23_test_problems_tests.jl +++ b/test/core/23_test_problems_tests.jl @@ -1,5 +1,6 @@ @testsetup module RobustnessTesting using LinearAlgebra, NonlinearProblemLibrary, NonlinearSolveBase, SciMLBase, Test +using FiniteDiff, ForwardDiff problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts diff --git a/test/core/least_squares_tests.jl b/test/core/least_squares_tests.jl index 840a4f2..98a232f 100644 --- a/test/core/least_squares_tests.jl +++ b/test/core/least_squares_tests.jl @@ -1,5 +1,5 @@ @testitem "Nonlinear Least Squares" begin - using LinearAlgebra + using LinearAlgebra, FiniteDiff, ForwardDiff true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]) diff --git a/test/core/matrix_resizing_tests.jl b/test/core/matrix_resizing_tests.jl index 54cf86b..4baffde 100644 --- a/test/core/matrix_resizing_tests.jl +++ b/test/core/matrix_resizing_tests.jl @@ -1,4 +1,6 @@ @testitem "Matrix Resizing" begin + using FiniteDiff, ForwardDiff + ff(u, p) = u .* u .- p u0 = ones(2, 3) p = 2.0 diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 848bf6c..b2c043b 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -1,8 +1,7 @@ @testsetup module RootfindingTesting using Reexport -@reexport using AllocCheck, - LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, - NonlinearSolveBase +@reexport using AllocCheck, LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, + NonlinearSolveBase, FiniteDiff import PolyesterForwardDiff quadratic_f(u, p) = u .* u .- p diff --git a/test/gpu/cuda_tests.jl b/test/gpu/cuda_tests.jl index 37999da..6fa4130 100644 --- a/test/gpu/cuda_tests.jl +++ b/test/gpu/cuda_tests.jl @@ -1,5 +1,5 @@ @testitem "Solving on GPUs" begin - using StaticArrays, CUDA + using StaticArrays, CUDA, FiniteDiff, ForwardDiff CUDA.allowscalar(false) @@ -37,7 +37,7 @@ end @testitem "CUDA Kernel Launch Test" begin - using StaticArrays, CUDA + using StaticArrays, CUDA, FiniteDiff, ForwardDiff CUDA.allowscalar(false) From 2d21b01b58c65883b9b935248e69590e203dfdf9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Feb 2024 12:12:55 -0500 Subject: [PATCH 3/3] Precompile in extensions --- ...NonlinearSolveDiffResultsForwardDiffExt.jl | 30 ++++++++++++++++++ .../jacobian.jl | 2 +- ext/SimpleNonlinearSolveFiniteDiffExt.jl | 31 +++++++++++++++++++ src/SimpleNonlinearSolve.jl | 10 ------ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl index 9a2d904..c66ded4 100644 --- a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/SimpleNonlinearSolveDiffResultsForwardDiffExt.jl @@ -1,5 +1,7 @@ module SimpleNonlinearSolveDiffResultsForwardDiffExt +import PrecompileTools: @compile_workload, @setup_workload + import ADTypes: AutoForwardDiff, AutoPolyesterForwardDiff import ArrayInterface, SciMLBase, SimpleNonlinearSolve, DiffResults, ForwardDiff import FastClosures: @closure @@ -23,4 +25,32 @@ include("jacobian.jl") include("hessian.jl") include("forward_ad.jl") +@setup_workload begin + for T in (Float32, Float64) + prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) + prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, + T.([1.0, 1.0, 1.0]), T(2)) + prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, + T.([1.0, 1.0, 1.0]), T(2)) + + algs = [SimpleNonlinearSolve.SimpleNewtonRaphson(; autodiff = AutoForwardDiff()), + SimpleNonlinearSolve.SimpleTrustRegion(; autodiff = AutoForwardDiff())] + + algs_no_iip = [SimpleNonlinearSolve.SimpleHalley(; autodiff = AutoForwardDiff())] + + @compile_workload begin + for alg in algs + solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) + solve(prob_no_brack_iip, alg, abstol = T(1e-2)) + solve(prob_no_brack_oop, alg, abstol = T(1e-2)) + end + + for alg in algs_no_iip + solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) + solve(prob_no_brack_oop, alg, abstol = T(1e-2)) + end + end + end +end + end diff --git a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl index 286076f..fcc1130 100644 --- a/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl +++ b/ext/SimpleNonlinearSolveDiffResultsForwardDiffExt/jacobian.jl @@ -92,4 +92,4 @@ function SimpleNonlinearSolve.__value_and_jacobian!( end out = f(ForwardDiff.Dual{T}(x, one(x))) return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) -end \ No newline at end of file +end diff --git a/ext/SimpleNonlinearSolveFiniteDiffExt.jl b/ext/SimpleNonlinearSolveFiniteDiffExt.jl index 3cefa38..a3c02d8 100644 --- a/ext/SimpleNonlinearSolveFiniteDiffExt.jl +++ b/ext/SimpleNonlinearSolveFiniteDiffExt.jl @@ -1,7 +1,10 @@ module SimpleNonlinearSolveFiniteDiffExt +import PrecompileTools: @compile_workload, @setup_workload + import ADTypes: AutoFiniteDiff import SciMLBase, SimpleNonlinearSolve, FiniteDiff +import SciMLBase: NonlinearProblem, NonlinearLeastSquaresProblem, solve import StaticArraysCore: SArray @inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:FiniteDiff}) = true @@ -56,4 +59,32 @@ function SimpleNonlinearSolve.compute_jacobian_and_hessian( end end +@setup_workload begin + for T in (Float32, Float64) + prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) + prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, + T.([1.0, 1.0, 1.0]), T(2)) + prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, + T.([1.0, 1.0, 1.0]), T(2)) + + algs = [SimpleNonlinearSolve.SimpleNewtonRaphson(; autodiff = AutoFiniteDiff()), + SimpleNonlinearSolve.SimpleTrustRegion(; autodiff = AutoFiniteDiff())] + + algs_no_iip = [SimpleNonlinearSolve.SimpleHalley(; autodiff = AutoFiniteDiff())] + + @compile_workload begin + for alg in algs + solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) + solve(prob_no_brack_iip, alg, abstol = T(1e-2)) + solve(prob_no_brack_oop, alg, abstol = T(1e-2)) + end + + for alg in algs_no_iip + solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) + solve(prob_no_brack_oop, alg, abstol = T(1e-2)) + end + end + end +end + end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 0c6b59d..fd6a9d5 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -86,7 +86,6 @@ end @setup_workload begin for T in (Float32, Float64) - # FIXME Move this precompilation into the forwarddiff & finitediff extensions prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, T.([1.0, 1.0, 1.0]), T(2)) @@ -96,21 +95,12 @@ end algs = [SimpleBroyden(), SimpleKlement(), SimpleDFSane(), SimpleLimitedMemoryBroyden(; threshold = 2)] - # algs = [SimpleNewtonRaphson(), SimpleTrustRegion()] - - # algs_no_iip = [SimpleHalley()] - @compile_workload begin for alg in algs solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) solve(prob_no_brack_iip, alg, abstol = T(1e-2)) solve(prob_no_brack_oop, alg, abstol = T(1e-2)) end - - # for alg in algs_no_iip - # solve(prob_no_brack_scalar, alg, abstol = T(1e-2)) - # solve(prob_no_brack_oop, alg, abstol = T(1e-2)) - # end end prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,