diff --git a/Project.toml b/Project.toml index 225b360..8d279f3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.6.0" +version = "1.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -39,13 +39,13 @@ SimpleNonlinearSolveZygoteExt = "Zygote" ADTypes = "0.2.6" AllocCheck = "0.1.1" Aqua = "0.8" -ArrayInterface = "7.7" +ArrayInterface = "7.8" CUDA = "5.2" ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" -DiffEqBase = "6.146" +DiffEqBase = "6.149" DiffResults = "1.1" -FastClosures = "0.3" +FastClosures = "0.3.2" FiniteDiff = "2.22" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" @@ -59,7 +59,7 @@ Random = "1.10" ReTestItems = "1.23" Reexport = "1.2" ReverseDiff = "1.15" -SciMLBase = "2.26.3" +SciMLBase = "2.28.0" SciMLSensitivity = "7.56" StaticArrays = "1.9" StaticArraysCore = "1.4.2" diff --git a/src/nlsolve/broyden.jl b/src/nlsolve/broyden.jl index 15e5447..6fe1214 100644 --- a/src/nlsolve/broyden.jl +++ b/src/nlsolve/broyden.jl @@ -48,7 +48,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...; @bb δJ⁻¹n = copy(x) @bb δJ⁻¹ = copy(J⁻¹) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) ls_cache = __get_linesearch(alg) === Val(true) ? diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 856e31f..9f09264 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -70,7 +70,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... τ_min = T(alg.τ_min) τ_max = T(alg.τ_max) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) fx_norm = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp diff --git a/src/nlsolve/halley.jl b/src/nlsolve/halley.jl index 4623322..934dc47 100644 --- a/src/nlsolve/halley.jl +++ b/src/nlsolve/halley.jl @@ -34,7 +34,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; T = eltype(x) autodiff = __get_concrete_autodiff(prob, alg.autodiff) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) @bb xo = copy(x) diff --git a/src/nlsolve/klement.jl b/src/nlsolve/klement.jl index 680b9cd..c2c8b44 100644 --- a/src/nlsolve/klement.jl +++ b/src/nlsolve/klement.jl @@ -13,7 +13,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; T = eltype(x) fx = _get_fx(prob, x) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) @bb δx = copy(x) diff --git a/src/nlsolve/lbroyden.jl b/src/nlsolve/lbroyden.jl index 145a546..600892e 100644 --- a/src/nlsolve/lbroyden.jl +++ b/src/nlsolve/lbroyden.jl @@ -61,7 +61,7 @@ end U, Vᵀ = __init_low_rank_jacobian(x, fx, x isa StaticArray ? threshold : Val(η)) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) @bb xo = copy(x) diff --git a/src/nlsolve/raphson.jl b/src/nlsolve/raphson.jl index e84f595..9735d0c 100644 --- a/src/nlsolve/raphson.jl +++ b/src/nlsolve/raphson.jl @@ -32,7 +32,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr @bb xo = copy(x) J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) for i in 1:maxiters diff --git a/src/nlsolve/trustRegion.jl b/src/nlsolve/trustRegion.jl index 03c4692..e6ccf65 100644 --- a/src/nlsolve/trustRegion.jl +++ b/src/nlsolve/trustRegion.jl @@ -88,7 +88,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args. J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p) fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x, + abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x, termination_condition) # Set default trust region radius if not specified by user. diff --git a/src/utils.jl b/src/utils.jl index 38be343..76e91fc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -288,14 +288,30 @@ end # different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve # is meant for low overhead solvers, users can opt into the other termination modes but the # default is to use the least overhead version. -function init_termination_cache(abstol, reltol, du, u, ::Nothing) - return init_termination_cache(abstol, reltol, du, u, AbsNormTerminationMode()) +function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing) + return init_termination_cache(prob, abstol, reltol, du, u, + AbsNormTerminationMode(Base.Fix1(maximum, abs))) end -function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) +function init_termination_cache( + prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing) + return init_termination_cache(prob, abstol, reltol, du, u, + AbsNormTerminationMode(Base.Fix2(norm, 2))) +end + +function init_termination_cache( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode) T = promote_type(eltype(du), eltype(u)) abstol = __get_tolerance(u, abstol, T) reltol = __get_tolerance(u, reltol, T) - tc_cache = init(du, u, tc; abstol, reltol) + tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing + internalnorm = ifelse( + prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2)) + DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm) + else + tc + end + tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false)) return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache end @@ -305,45 +321,25 @@ function check_termination(tc_cache, fx, x, xo, prob, alg) end function check_termination(tc_cache, fx, x, xo, prob, alg, ::AbstractNonlinearTerminationMode) - if Bool(tc_cache(fx, x, xo)) + tc_cache(fx, x, xo) && return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - end return nothing end function check_termination(tc_cache, fx, x, xo, prob, alg, ::AbstractSafeNonlinearTerminationMode) - if Bool(tc_cache(fx, x, xo)) - if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success - retcode = ReturnCode.Success - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination - retcode = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination - retcode = ReturnCode.Unstable - else - error("Unknown termination code: $(tc_cache.retcode)") - end - return build_solution(prob, alg, x, fx; retcode) - end + tc_cache(fx, x, xo) && + return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode) return nothing end function check_termination(tc_cache, fx, x, xo, prob, alg, ::AbstractSafeBestNonlinearTerminationMode) - if Bool(tc_cache(fx, x, xo)) - if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success - retcode = ReturnCode.Success - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination - retcode = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination - retcode = ReturnCode.Unstable - else - error("Unknown termination code: $(tc_cache.retcode)") - end + if tc_cache(fx, x, xo) if isinplace(prob) prob.f(fx, x, prob.p) else fx = prob.f(x, prob.p) end - return build_solution(prob, alg, tc_cache.u, fx; retcode) + return build_solution(prob, alg, tc_cache.u, fx; retcode = tc_cache.retcode) end return nothing end