Skip to content

Commit

Permalink
Merge pull request #113 from avik-pal/ap/polyester_mode
Browse files Browse the repository at this point in the history
Add Polyester ForwardDiff support
  • Loading branch information
ChrisRackauckas committed Dec 26, 2023
2 parents 45f8d73 + b8cc83d commit da36df6
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 33 deletions.
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.0.4"
version = "1.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -17,8 +17,14 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"

[weakdeps]
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"

[compat]
ADTypes = "0.2"
ADTypes = "0.2.6"
ArrayInterface = "7"
ConcreteStructs = "0.2"
DiffEqBase = "6.126"
Expand Down
19 changes: 19 additions & 0 deletions ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module SimpleNonlinearSolvePolyesterForwardDiffExt

using SimpleNonlinearSolve, PolyesterForwardDiff

@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:PolyesterForwardDiff}) = true

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f!::F, y, J, x,
chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f!, y, J, x, chunksize)
return J
end

@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
chunksize) where {F}
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
return J
end

end
4 changes: 3 additions & 1 deletion src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
using FiniteDiff, ForwardDiff
import ForwardDiff: Dual
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
Expand All @@ -23,6 +23,8 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorith
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("utils.jl")

## Nonlinear Solvers
Expand Down
16 changes: 10 additions & 6 deletions src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ A low-overhead implementation of Halley's Method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Hessian. Defaults to
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing`. Valid
choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
!!! warning
Inplace Problems are currently not supported by this method.
"""
@kwdef @concrete struct SimpleHalley <: AbstractNewtonAlgorithm
autodiff = AutoForwardDiff()
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
Expand All @@ -33,6 +33,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
fx = _get_fx(prob, x)
T = eltype(x)

autodiff = __get_concrete_autodiff(prob, alg.autodiff; polyester = Val(false))
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
termination_condition)

Expand All @@ -50,17 +51,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;

for i in 1:maxiters
# Hessian Computation is unfortunately type unstable
fx, dfx, d2fx = compute_jacobian_and_hessian(alg.autodiff, prob, fx, x)
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, fx, x)
setindex_trait(x) === CannotSetindex() && (A = dfx)

aᵢ = dfx \ _vec(fx)
# Factorize Once and Reuse
dfx_fact = factorize(dfx)

aᵢ = dfx_fact \ _vec(fx)
A_ = _vec(A)
@bb A_ = d2fx × aᵢ
A = _restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
bᵢ = dfx \ Aaᵢ
bᵢ = dfx_fact \ Aaᵢ

cᵢ_ = _vec(cᵢ)
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
Expand Down
12 changes: 7 additions & 5 deletions src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
SimpleNewtonRaphson(autodiff)
SimpleNewtonRaphson(; autodiff = AutoForwardDiff())
SimpleNewtonRaphson(; autodiff = nothing)
A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar
and static array problems.
Expand All @@ -14,10 +14,11 @@ and static array problems.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Defaults to
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
`AutoFiniteDiff()`.
"""
@kwdef @concrete struct SimpleNewtonRaphson <: AbstractNewtonAlgorithm
autodiff = AutoForwardDiff()
autodiff = nothing
end

const SimpleGaussNewton = SimpleNewtonRaphson
Expand All @@ -27,14 +28,15 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
fx = _get_fx(prob, x)
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
@bb xo = copy(x)
J, jac_cache = jacobian_cache(alg.autodiff, prob.f, fx, x, prob.p)
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
termination_condition)

for i in 1:maxiters
fx, dfx = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)

if i == 1
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
Expand Down
12 changes: 7 additions & 5 deletions src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ scalar and static array problems.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Defaults to
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
`AutoFiniteDiff()`.
- `max_trust_radius`: the maximum radius of the trust region. Defaults to
`max(norm(f(u0)), maximum(u0) - minimum(u0))`.
- `initial_trust_radius`: the initial trust region radius. Defaults to
Expand All @@ -37,7 +38,7 @@ scalar and static array problems.
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
"""
@kwdef @concrete struct SimpleTrustRegion <: AbstractNewtonAlgorithm
autodiff = AutoForwardDiff()
autodiff = nothing
max_trust_radius = 0.0
initial_trust_radius = 0.0
step_threshold = 0.0001
Expand All @@ -61,11 +62,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
t₁ = T(alg.shrink_factor)
t₂ = T(alg.expand_factor)
max_shrink_times = alg.max_shrink_times
autodiff = __get_concrete_autodiff(prob, alg.autodiff)

fx = _get_fx(prob, x)
@bb xo = copy(x)
J, jac_cache = jacobian_cache(alg.autodiff, prob.f, fx, x, prob.p)
fx, ∇f = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
termination_condition)
Expand Down Expand Up @@ -116,7 +118,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
# Take the step.
@bb @. xo = x

fx, ∇f = value_and_jacobian(alg.autodiff, prob.f, fx, x, prob.p, jac_cache; J)
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)

# Update the trust region radius.
(r > η₃) && (norm(δ) Δ) &&= min(t₂ * Δ, Δₘₐₓ))
Expand Down
51 changes: 41 additions & 10 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
"""
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))

__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
__cvt_real(::Type{T}, x) where {T} = real(T(x))

_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η)
function _get_tolerance(::Nothing, ::Type{T}) where {T}
η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
return _get_tolerance(η, T)
end

__standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x))
__standard_tag(tag::ForwardDiff.Tag, _) = tag
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))
Expand All @@ -60,6 +51,12 @@ function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!, y, x) where {CS}
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
end

function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS}
x = last(args)
return (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) :
ForwardDiff.Chunk{CS}()
end

"""
value_and_jacobian(ad, f, y, x, p, cache; J = nothing)
Expand All @@ -81,6 +78,9 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
FiniteDiff.finite_difference_jacobian!(J, _f, x, cache)
_f(y, x)
return y, J
elseif ad isa AutoPolyesterForwardDiff
__polyester_forwarddiff_jacobian!(_f, y, J, x, cache)
return y, J
else
throw(ArgumentError("Unsupported AD method: $(ad)"))
end
Expand All @@ -100,19 +100,30 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
elseif ad isa AutoFiniteDiff
J_fd = FiniteDiff.finite_difference_jacobian(_f, x, cache)
return _f(x), J_fd
elseif ad isa AutoPolyesterForwardDiff
__polyester_forwarddiff_jacobian!(_f, J, x, cache)
return _f(x), J
else
throw(ArgumentError("Unsupported AD method: $(ad)"))
end
end
end

# Declare functions
function __polyester_forwarddiff_jacobian! end

function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F}
if DiffEqBase.has_jac(f)
return f(x, p), f.jac(x, p)
elseif ad isa AutoForwardDiff
T = typeof(__standard_tag(ad.tag, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoPolyesterForwardDiff
# Just use ForwardDiff
T = typeof(__standard_tag(nothing, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoFiniteDiff
_f = Base.Fix2(f, p)
return _f(x), FiniteDiff.finite_difference_derivative(_f, x, ad.fdtype)
Expand All @@ -132,7 +143,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
J = similar(y, length(y), length(x))
if DiffEqBase.has_jac(f)
return J, nothing
elseif ad isa AutoForwardDiff
elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
return J, __get_jacobian_config(ad, _f, y, x)
elseif ad isa AutoFiniteDiff
return J, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype)
Expand All @@ -146,6 +157,10 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
elseif ad isa AutoForwardDiff
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
return J, __get_jacobian_config(ad, _f, x)
elseif ad isa AutoPolyesterForwardDiff
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs. Use AutoForwardDiff instead."
J = similar(y, length(y), length(x))
return J, __get_jacobian_config(ad, _f, x)
elseif ad isa AutoFiniteDiff
return nothing, FiniteDiff.JacobianCache(copy(x), copy(y), copy(y), ad.fdtype)
else
Expand Down Expand Up @@ -350,3 +365,19 @@ end
(alias || !ArrayInterface.can_setindex(typeof(x))) && return x
return deepcopy(x)
end

# Decide which AD backend to use
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad
@inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true),
kwargs...) where {P}
if ForwardDiff.can_dual(eltype(prob.u0))
if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) &&
!(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0)
return AutoPolyesterForwardDiff()
else
return AutoForwardDiff()
end
else
return AutoFiniteDiff()
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
10 changes: 6 additions & 4 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AllocCheck, BenchmarkTools, LinearSolve, SimpleNonlinearSolve, StaticArrays, Random,
LinearAlgebra, Test, ForwardDiff, DiffEqBase
import PolyesterForwardDiff

_nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))

Expand Down Expand Up @@ -29,20 +30,21 @@ const TERMINATION_CONDITIONS = [
@testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion)
# Eval else the alg is type unstable
@eval begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
end
end

@testset "AutoDiff: $(_nameof(autodiff))" for autodiff in (AutoFiniteDiff(),
AutoForwardDiff())
AutoForwardDiff(), AutoPolyesterForwardDiff())
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
u0 isa SVector && autodiff isa AutoPolyesterForwardDiff && continue
sol = benchmark_nlsolve_oop(quadratic_f, u0; autodiff)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
Expand Down Expand Up @@ -103,7 +105,7 @@ end
# --- SimpleHalley tests ---

@testset "SimpleHalley" begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, SimpleHalley(; autodiff), abstol = 1e-9)
end
Expand Down

2 comments on commit da36df6

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/97767

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" da36df6d9d001b4f3a882b936b75c39647da67a4
git push origin v1.1.0

Please sign in to comment.