Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert to immutable for CUDA tests #151

Closed
wants to merge 12 commits into from
8 changes: 5 additions & 3 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size, SizedVector, SizedMatrix
end

const DI = DifferentiationInterface
Expand Down Expand Up @@ -60,6 +61,7 @@ include("bracketing/itp.jl")

# AD
include("ad.jl")
include("immutable_nonlinear_problem.jl")

## Default algorithm

Expand Down
68 changes: 68 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""

Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""

Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""

Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end

staticarray_itize(x) = x
staticarray_itize(x::Vector) = SVector{length(x)}(x)
staticarray_itize(x::SizedVector) = SVector{length(x)}(x)
staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x)
staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x)

function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem(prob.f,
m-bossart marked this conversation as resolved.
Show resolved Hide resolved
staticarray_itize(prob.u0),
staticarray_itize(prob.p),
prob.problem_type;
prob.kwargs...)
end
1 change: 1 addition & 0 deletions test/gpu/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end
end

prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])
prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(),
Expand Down