Skip to content

Commit

Permalink
Merge pull request #153 from HodgeLab/mb/immutable
Browse files Browse the repository at this point in the history
add ImmutableNonlinearProblem
  • Loading branch information
ChrisRackauckas committed Jul 22, 2024
2 parents 40958d2 + f7c176b commit e63b1a8
Show file tree
Hide file tree
Showing 16 changed files with 132 additions and 40 deletions.
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore: ChainRulesCore, NoTangent
using DiffEqBase: DiffEqBase
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem

# The expectation here is that no-one is using this directly inside a GPU kernel. We can
# eventually lift this requirement using a custom adjoint
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ module SimpleNonlinearSolveReverseDiffExt
using ArrayInterface: ArrayInterface
using DiffEqBase: DiffEqBase
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
import SimpleNonlinearSolve: __internal_solve_up

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
p::TrackedArray, p_changed, alg, args...; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions ext/SimpleNonlinearSolveTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module SimpleNonlinearSolveTrackerExt

using DiffEqBase: DiffEqBase
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
using Tracker: Tracker, TrackedArray

for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
@eval begin
function SimpleNonlinearSolve.__internal_solve_up(
prob::$(pType), sensealg, u0::TrackedArray,
Expand Down
21 changes: 17 additions & 4 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess
norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size

Expand All @@ -35,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("immutable_nonlinear_problem.jl")
include("utils.jl")
include("linesearch.jl")

Expand Down Expand Up @@ -70,6 +71,18 @@ end
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
Expand All @@ -79,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
37 changes: 24 additions & 13 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function SciMLBase.solve(
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
prob = convert(ImmutableNonlinearProblem, prob)
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand All @@ -31,7 +42,7 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
end

function __nlsolve_ad(
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
Expand Down
67 changes: 67 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""
Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""
Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end


function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
prob.u0,
prob.p,
prob.problem_type;
prob.kwargs...)
end

function DiffEqBase.get_concrete_problem(prob::ImmutableNonlinearProblem, isadapt; kwargs...)
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
p = DiffEqBase.get_concrete_p(prob, kwargs)
DiffEqBase.remake(prob; u0 = u0, p = p)
end
2 changes: 1 addition & 1 deletion src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...) where {M}
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; termination_condition = nothing, kwargs...)
if prob.u0 isa SArray
if termination_condition === nothing ||
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end

@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -114,7 +114,7 @@ end
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, maxiters = 1000, kwargs...)
x = prob.u0
fx = _get_fx(prob, x)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ scalar and static array problems.
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
return _get_fx(prob.f, x, prob.p)
end
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline function _get_fx(f::NonlinearFunction, x, p)
if isinplace(f)
if f.resid_prototype !== nothing
Expand All @@ -145,7 +145,7 @@ end
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
# is meant for low overhead solvers, users can opt into the other termination modes but the
# default is to use the least overhead version.
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
end
Expand All @@ -155,14 +155,14 @@ function init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
T = promote_type(eltype(du), eltype(u))
abstol = __get_tolerance(u, abstol, T)
reltol = __get_tolerance(u, reltol, T)
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
internalnorm = ifelse(
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
prob isa ImmutableNonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
else
tc
Expand Down
2 changes: 1 addition & 1 deletion test/core/adjoint_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)

@test ∂p_zygote ∂p_tracker ∂p_reversediff
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff
end
1 change: 1 addition & 0 deletions test/gpu/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end
end

prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])
prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob)

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(),
Expand Down

0 comments on commit e63b1a8

Please sign in to comment.