From d4936bc5a9a5af27c006de120beea222efc6d49f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 23 Nov 2023 22:51:03 -0500 Subject: [PATCH] Move out the @bb macro into a separate package --- Project.toml | 1 + src/SimpleNonlinearSolve.jl | 10 +-- src/bracketing/ridder.jl | 2 +- src/nlsolve/klement.jl | 8 +- src/rewrite_inplace.jl | 161 ------------------------------------ 5 files changed, 10 insertions(+), 172 deletions(-) delete mode 100644 src/rewrite_inplace.jl diff --git a/Project.toml b/Project.toml index 75af934..8e6b0f5 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index ab7026b..cdfc95b 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -4,28 +4,25 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat @recompile_invalidations begin using ADTypes, - ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, - SciMLBase + ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, SciMLBase import DiffEqBase: AbstractNonlinearTerminationMode, AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode, NonlinearSafeTerminationReturnCode, get_termination_mode using FiniteDiff, ForwardDiff import ForwardDiff: Dual + import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray end @reexport using ADTypes, SciMLBase -# const NNlibExtLoaded = Ref{Bool}(false) - abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end include("utils.jl") -include("rewrite_inplace.jl") # Nonlinear Solvera include("nlsolve/raphson.jl") @@ -50,7 +47,6 @@ include("bracketing/itp.jl") ## Default algorithm # Set the default bracketing method to ITP - function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...) return solve(prob, ITP(); kwargs...) end @@ -60,8 +56,6 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, return solve(prob, ITP(), args...; kwargs...) end -# import PrecompileTools - @setup_workload begin for T in (Float32, Float64) # prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) diff --git a/src/bracketing/ridder.jl b/src/bracketing/ridder.jl index 11b7604..20e0db4 100644 --- a/src/bracketing/ridder.jl +++ b/src/bracketing/ridder.jl @@ -70,7 +70,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...; sol, i, left, right, fl, fr = __bisection(left, right, fl, fr, f; abstol, maxiters = maxiters - i, prob, alg) - sol !== nothing && return sol + sol !== nothing && return sol return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right) diff --git a/src/nlsolve/klement.jl b/src/nlsolve/klement.jl index 7b9a878..56d6ccd 100644 --- a/src/nlsolve/klement.jl +++ b/src/nlsolve/klement.jl @@ -54,7 +54,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; end @bb copyto!(δx, fprev) - δx = __ldiv!!(F_, δx) + if setindex_trait(δx) === CanSetindex() + ldiv!(F_, δx) + else + δx = F_ \ δx + end @bb @. x = xo - δx fx = __eval_f(prob, fx, x) @@ -74,7 +78,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...; @bb δx² = J × vec(δx) @bb @. δf = (δf - δx²) / d - _vδf, _vδx = vec(δf), vec(δx) + _vδf, _vδx = _vec(δf), _vec(δx) @bb J_cache = _vδf × transpose(_vδx) @bb @. J_cache *= J @bb J_cache2 = J_cache × J diff --git a/src/rewrite_inplace.jl b/src/rewrite_inplace.jl deleted file mode 100644 index f0d80af..0000000 --- a/src/rewrite_inplace.jl +++ /dev/null @@ -1,161 +0,0 @@ -# Take a inplace code and rewrite it to be maybe-inplace -# I will take this code out into a separate package because this is useful even in -# NonlinearSolve.jl -function __bangbang(M, expr; depth = 1) - new_expr = nothing - if expr.head == :call - @assert length(expr.args)≥2 "Expected a function call with atleast 1 argument. \ - Got `$(expr)`." - f, a, args... = expr.args - g = get(OP_MAPPING, f, nothing) - if f == :copy && length(args) == 0 - # Special case for copy with single argument - new_expr = :($(g)($(setindex_trait)($(a)), $(a))) - elseif g !== nothing - new_expr = :($(a) = $(g)($(setindex_trait)($(a)), $(a), $(args...))) - end - elseif expr.head == :(=) - a, rhs_expr = expr.args - if rhs_expr.head == :call - f, b, args... = rhs_expr.args - g = get(OP_MAPPING, f, nothing) - if g !== nothing - new_expr = :($(a) = $(g)($(setindex_trait)($(b)), $(b), $(args...))) - elseif f == :× - @debug "Custom operator `×` detected in `$(expr)`." - c, args... = args - @assert length(args)==0 "Expected `×` to have only 2 arguments. \ - Got `$(expr)`." - is_b_vec = b isa Expr && b.head == :call && b.args[1] == :vec - is_c_vec = c isa Expr && c.head == :call && c.args[1] == :vec - a_sym = gensym("a") - if is_b_vec - if is_c_vec - error("2 `vec`s detected with `×` in `$(expr)`. Use `dot` instead.") - else - new_expr = quote - if $(setindex_trait)($(a)) === CanSetindex() - $(a_sym) = $(_vec)($a) - mul!($(a_sym), $(_vec)($(b.args[2])), $(c)) - $(a) = $(_restructure)($a, $(a_sym)) - else - $(a) = $(_restructure)($a, $(_vec)($(b.args[2])) * $(c)) - end - end - end - else - if is_c_vec - new_expr = quote - if $(setindex_trait)($(a)) === CanSetindex() - $(a_sym) = $(_vec)($a) - mul!($(a), $(b), $(_vec)($(c.args[2]))) - $(a) = $(_restructure)($a, $(a_sym)) - else - $(a) = $(_restructure)($a, $(b) * $(_vec)($(c.args[2]))) - end - end - else - new_expr = quote - if $(setindex_trait)($(a)) === CanSetindex() - mul!($(a), $(b), $(c)) - else - $(a) = $(b) * $(c) - end - end - end - end - end - end - elseif expr.head == :(.=) - a, rhs_expr = expr.args - if rhs_expr isa Expr && rhs_expr.head == :(.) - f, arg_expr = rhs_expr.args - # f_broadcast = :(Base.Broadcast.BroadcastFunction($(f))) - new_expr = quote - if $(setindex_trait)($(a)) === CanSetindex() - broadcast!($(f), $(a), $(arg_expr)...) - else - $(a) = broadcast($(f), $(arg_expr)...) - end - end - end - elseif expr.head == :macrocall - # For @__dot__ there is a easier alternative - if expr.args[1] == Symbol("@__dot__") - main_expr = last(expr.args) - if main_expr isa Expr && main_expr.head == :(=) - a, rhs_expr = main_expr.args - new_expr = quote - if $(setindex_trait)($(a)) === CanSetindex() - @. $(main_expr) - else - $(a) = @. $(rhs_expr) - end - end - end - end - if new_expr === nothing - new_expr = __bangbang(M, Base.macroexpand(M, expr; recursive = true); - depth = depth + 1) - end - else - f = expr.head # Things like :.-=, etc. - a, args... = expr.args - g = get(OP_MAPPING, f, nothing) - if g !== nothing - new_expr = :($(a) = $(g)($(setindex_trait)($(a)), $(a), $(args...))) - end - end - if new_expr !== nothing - if depth == 1 - @debug "Replacing `$(expr)` with `$(new_expr)`" - return esc(new_expr) - else - return new_expr - end - end - error("`$(expr)` cannot be handled. Check the documentation for allowed expressions.") -end - -macro bangbang(expr) - return __bangbang(__module__, expr) -end - -# `bb` is the short form of bang-bang -macro bb(expr) - return __bangbang(__module__, expr) -end - -# Is Mutable or Not? -abstract type AbstractMaybeSetindex end -struct CannotSetindex <: AbstractMaybeSetindex end -struct CanSetindex <: AbstractMaybeSetindex end - -# Common types should overload this via extensions, else it butchers type-inference -setindex_trait(::Union{Number, SArray}) = CannotSetindex() -setindex_trait(::Union{MArray, Array}) = CanSetindex() -setindex_trait(A) = ifelse(ArrayInterface.can_setindex(A), CanSetindex(), CannotSetindex()) - -# Operations -const OP_MAPPING = Dict{Symbol, Symbol}(:copyto! => :__copyto!!, - :.-= => :__sub!!, - :.+= => :__add!!, - :.*= => :__mul!!, - :./= => :__div!!, - :copy => :__copy) - -@inline __copyto!!(::CannotSetindex, x, y) = y -@inline __copyto!!(::CanSetindex, x, y) = (copyto!(x, y); x) - -@inline __broadcast!!(::CannotSetindex, op, x, args...) = broadcast(op, args...) -@inline __broadcast!!(::CanSetindex, op, x, args...) = (broadcast!(op, x, args...); x) - -@inline __sub!!(S, x, args...) = __broadcast!!(S, -, x, x, args...) -@inline __add!!(S, x, args...) = __broadcast!!(S, +, x, x, args...) -@inline __mul!!(S, x, args...) = __broadcast!!(S, *, x, x, args...) -@inline __div!!(S, x, args...) = __broadcast!!(S, /, x, x, args...) - -@inline __copy(::CannotSetindex, x) = x -@inline __copy(::CanSetindex, x) = copy(x) -@inline __copy(::CannotSetindex, x, y) = y -@inline __copy(::CanSetindex, x, y) = copy(y)