Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert to immutable for CUDA tests #151

Closed
wants to merge 12 commits into from
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -63,7 +63,6 @@ SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Tracker = "0.2.33"
Zygote = "0.6.69"
Expand Down
25 changes: 19 additions & 6 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
using StaticArrays: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix
end

const DI = DifferentiationInterface
Expand All @@ -37,7 +38,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("immutable_nonlinear_problem.jl")
include("utils.jl")
include("linesearch.jl")

Expand Down Expand Up @@ -69,9 +70,21 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
return solve(prob, ITP(), args...; prob.kwargs..., kwargs...)
end

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
# Bypass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
Expand All @@ -81,7 +94,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
function __internal_solve_up(_prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
68 changes: 68 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""

Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""

Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""

Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end

staticarray_itize(x) = x
staticarray_itize(x::Vector) = SVector{length(x)}(x)
staticarray_itize(x::SizedVector) = SVector{length(x)}(x)
staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x)
staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x)

function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
staticarray_itize(prob.u0),
staticarray_itize(prob.p),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't going to be type stable, we shouldn't copy this part. Don't automatically make things static array, just immutable the problem.

prob.problem_type;
prob.kwargs...)
end
2 changes: 1 addition & 1 deletion src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleBroyden, args...;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of these unions should be required. If done correctly, the top level should change to an ImmutableNonlinearProblem, and then they should all be that.

abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleDFSane{M}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...) where {M}
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden,
args...; termination_condition = nothing, kwargs...)
if prob.u0 isa SArray
if termination_condition === nothing ||
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end

@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
@views function __generic_solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -114,7 +114,7 @@ end
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
function __static_solve(prob::Union{NonlinearProblem{<:SArray}, ImmutableNonlinearProblem{<:SArray}}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, maxiters = 1000, kwargs...)
x = prob.u0
fx = _get_fx(prob, x)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ scalar and static array problems.
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
function SciMLBase.__solve(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, alg::SimpleTrustRegion, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
return _get_fx(prob.f, x, prob.p)
end
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline _get_fx(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, x) = _get_fx(prob.f, x, prob.p)
@inline function _get_fx(f::NonlinearFunction, x, p)
if isinplace(f)
if f.resid_prototype !== nothing
Expand All @@ -145,7 +145,7 @@ end
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
# is meant for low overhead solvers, users can opt into the other termination modes but the
# default is to use the least overhead version.
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem}, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
end
Expand All @@ -155,7 +155,7 @@ function init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function init_termination_cache(prob::Union{NonlinearProblem, ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
T = promote_type(eltype(du), eltype(u))
abstol = __get_tolerance(u, abstol, T)
Expand Down
1 change: 1 addition & 0 deletions test/gpu/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end
end

prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])
prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(),
Expand Down
Loading