Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 26, 2023
1 parent 1033c15 commit b8cc83d
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
8 changes: 7 additions & 1 deletion ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ using SimpleNonlinearSolve, PolyesterForwardDiff
return J
end

end
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
return J
end

end
9 changes: 5 additions & 4 deletions src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ A low-overhead implementation of Halley's Method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Hessian. Defaults to
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing`. Valid
choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
!!! warning
Inplace Problems are currently not supported by this method.
"""
@kwdef @concrete struct SimpleHalley <: AbstractNewtonAlgorithm
autodiff = AutoForwardDiff()
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
Expand All @@ -33,6 +33,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
fx = _get_fx(prob, x)
T = eltype(x)

autodiff = __get_concrete_autodiff(prob, alg.autodiff; polyester = Val(false))
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
termination_condition)

Expand All @@ -50,7 +51,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;

for i in 1:maxiters
# Hessian Computation is unfortunately type unstable
fx, dfx, d2fx = compute_jacobian_and_hessian(alg.autodiff, prob, fx, x)
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, fx, x)
setindex_trait(x) === CannotSetindex() && (A = dfx)

# Factorize Once and Reuse
Expand Down
18 changes: 12 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where
T = typeof(__standard_tag(ad.tag, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoPolyesterForwardDiff
# Just use ForwardDiff
T = typeof(__standard_tag(nothing, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoFiniteDiff
_f = Base.Fix2(f, p)
return _f(x), FiniteDiff.finite_difference_derivative(_f, x, ad.fdtype)
Expand Down Expand Up @@ -153,7 +158,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
return J, __get_jacobian_config(ad, _f, x)
elseif ad isa AutoPolyesterForwardDiff
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs."
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs. Use AutoForwardDiff instead."
J = similar(y, length(y), length(x))
return J, __get_jacobian_config(ad, _f, x)
elseif ad isa AutoFiniteDiff
Expand Down Expand Up @@ -362,16 +367,17 @@ end
end

# Decide which AD backend to use
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType) = ad
@inline function __get_concrete_autodiff(prob, ::Nothing)
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad
@inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true),
kwargs...) where {P}
if ForwardDiff.can_dual(eltype(prob.u0))
if __is_extension_loaded(Val(:PolyesterForwardDiff)) && !(prob.u0 isa Number) &&
ArrayInterface.can_setindex(prob.u0)
if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) &&
!(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0)
return AutoPolyesterForwardDiff()
else
return AutoForwardDiff()
end
else
return AutoFiniteDiff()

Check warning on line 381 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L381

Added line #L381 was not covered by tests
end
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
10 changes: 6 additions & 4 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AllocCheck, BenchmarkTools, LinearSolve, SimpleNonlinearSolve, StaticArrays, Random,
LinearAlgebra, Test, ForwardDiff, DiffEqBase
import PolyesterForwardDiff

_nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))

Expand Down Expand Up @@ -29,20 +30,21 @@ const TERMINATION_CONDITIONS = [
@testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion)
# Eval else the alg is type unstable
@eval begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
end
end

@testset "AutoDiff: $(_nameof(autodiff))" for autodiff in (AutoFiniteDiff(),
AutoForwardDiff())
AutoForwardDiff(), AutoPolyesterForwardDiff())
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
u0 isa SVector && autodiff isa AutoPolyesterForwardDiff && continue
sol = benchmark_nlsolve_oop(quadratic_f, u0; autodiff)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
Expand Down Expand Up @@ -103,7 +105,7 @@ end
# --- SimpleHalley tests ---

@testset "SimpleHalley" begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, SimpleHalley(; autodiff), abstol = 1e-9)
end
Expand Down

0 comments on commit b8cc83d

Please sign in to comment.