Skip to content

Commit

Permalink
Fix the dispatch on polyalg
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 24, 2024
1 parent 7f41f1c commit 638babf
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
6 changes: 3 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ version = "0.1.6"

[[deps.GenericSchur]]
deps = ["LinearAlgebra", "Printf"]
git-tree-sha1 = "fb69b2a645fa69ba5f474af09221b9308b160ce6"
git-tree-sha1 = "af49a0851f8113fcfae2ef5027c6d49d0acec39b"
uuid = "c145ed77-6b09-5dd9-b285-bf645a82121e"
version = "0.5.3"
version = "0.5.4"

[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
Expand Down Expand Up @@ -608,7 +608,7 @@ version = "1.2.0"
[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"]
git-tree-sha1 = "0e464ca0e5d44a88c91f394c3f9a9448523e378b"
repo-rev = "ap/tstable_findmin"
repo-rev = "master"
repo-url = "https://github.com/SciML/NonlinearSolve.jl.git"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.8.2"
Expand Down
4 changes: 2 additions & 2 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function __solve_nlproblem!(

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
Expand Down Expand Up @@ -188,7 +188,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
Expand Down
2 changes: 1 addition & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
nlf = __unsafe_nonlinearfunction{iip}(
loss_fn; jac_prototype, resid_prototype, jac = jac_fn)
nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, alg.nlsolve)
nlsol = __solve(nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...)

# There is no way to reinit with the same cache with different cache. But not saving
Expand Down
6 changes: 2 additions & 4 deletions test/shooting/nlls_tests.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# FIXME: The nonlinear solve polyalgorithm for NLLS is currently broken because of Bastin
# Jv & Jᵀv computation with the cached ODE solve
@testitem "Overconstrained BVP" begin
using LinearAlgebra, JET

SOLVERS = [
# Shooting(Tsit5()),
Shooting(Tsit5()),
Shooting(
Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))), Shooting(
Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
Shooting(Tsit5(), GaussNewton(; autodiff = AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(), GaussNewton(; autodiff = AutoFiniteDiff())),
Shooting(Tsit5(), TrustRegion(; autodiff = AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(), TrustRegion(; autodiff = AutoFiniteDiff())),
# MultipleShooting(10, Tsit5()),
MultipleShooting(10, Tsit5()),
MultipleShooting(
10, Tsit5(), LevenbergMarquardt(; autodiff = AutoForwardDiff(; chunksize = 2))), MultipleShooting(
10, Tsit5(), LevenbergMarquardt(; autodiff = AutoFiniteDiff())),
Expand Down

0 comments on commit 638babf

Please sign in to comment.