Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp working of Batched Solvers #68

Merged
merged 13 commits into from
Jul 18, 2023
22 changes: 7 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,46 +1,38 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.16"
version = "0.1.17"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[extensions]
SimpleBatchedNonlinearSolveExt = "NNlib"
SimpleNonlinearSolveNNlibExt = "NNlib"

[compat]
ArrayInterface = "6, 7"
DiffEqBase = "6.123.0"
DiffEqBase = "6.126"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
NNlib = "0.8, 0.9"
PackageExtensionCompat = "1"
PrecompileTools = "1"
Reexport = "0.2, 1"
Requires = "1"
SciMLBase = "1.73"
PrecompileTools = "1"
StaticArraysCore = "1.4"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]
90 changes: 0 additions & 90 deletions ext/SimpleBatchedNonlinearSolveExt.jl

This file was deleted.

81 changes: 81 additions & 0 deletions ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
module SimpleNonlinearSolveNNlibExt

using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace

function __init__()
SimpleNonlinearSolve.NNlibExtLoaded[] = true
return
end

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedBroyden;
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

u, f, reconstruct = _construct_batched_problem_structure(prob)
L, N = size(u)

tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

storage = _get_storage(mode, u)

xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
T = eltype(u)

atol = _get_tolerance(abstol, tc.abstol, T)
rtol = _get_tolerance(reltol, tc.reltol, T)
termination_condition = tc(storage)

𝓙⁻¹ = _init_𝓙(xₙ) # L × L × N
𝓙⁻¹f, xᵀ𝓙⁻¹δf, xᵀ𝓙⁻¹ = similar(𝓙⁻¹, L, N), similar(𝓙⁻¹, 1, N), similar(𝓙⁻¹, 1, L, N)

@maybeinplace iip fₙ₋₁=f(xₙ) u
iip && (fₙ = copy(fₙ₋₁))
for n in 1:maxiters
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(fₙ₋₁, L, 1, N))
xₙ .= xₙ₋₁ .- 𝓙⁻¹f

@maybeinplace iip fₙ=f(xₙ)
δx .= xₙ .- xₙ₋₁
δf .= fₙ .- fₙ₋₁

batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
δxᵀ = reshape(δx, 1, L, N)

batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
return DiffEqBase.build_solution(prob,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode)
end

xₙ₋₁ .= xₙ
fₙ₋₁ .= fₙ
end

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
@maybeinplace iip fₙ=f(xₙ)

Check warning on line 71 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L69-L71

Added lines #L69 - L71 were not covered by tests
end

return DiffEqBase.build_solution(prob,

Check warning on line 74 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L74

Added line #L74 was not covered by tests
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode = ReturnCode.MaxIters)
end

end
22 changes: 13 additions & 9 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,19 @@ using DiffEqBase

@reexport using SciMLBase

if !isdefined(Base, :get_extension)
using Requires
end

using PackageExtensionCompat
function __init__()
@static if !isdefined(Base, :get_extension)
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
end
end
@require_extensions
end

const NNlibExtLoaded = Ref{Bool}(false)

abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractBatchedNonlinearSolveAlgorithm <:
AbstractSimpleNonlinearSolveAlgorithm end

include("utils.jl")
include("bisection.jl")
Expand All @@ -42,6 +39,12 @@ include("ad.jl")
include("halley.jl")
include("alefeld.jl")

# Batched Solver Support
include("batched/utils.jl")
include("batched/raphson.jl")
include("batched/dfsane.jl")
include("batched/broyden.jl")

import PrecompileTools

PrecompileTools.@compile_workload begin
Expand Down Expand Up @@ -74,5 +77,6 @@ end
# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane

end # module
6 changes: 6 additions & 0 deletions src/batched/broyden.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
struct BatchedBroyden{TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
termination_condition::TC
end

# Implementation of solve using Package Extensions
Loading
Loading