Skip to content

Commit

Permalink
LBroyden
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 27, 2023
1 parent 409d36c commit 061ec7d
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 119 deletions.
25 changes: 13 additions & 12 deletions ext/SimpleNonlinearSolveADLinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
module SimpleNonlinearSolveADLinearSolveExt

using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
using AbstractDifferentiation,
ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _result_from_storage, _get_tolerance, @maybeinplace

const AD = AbstractDifferentiation

Expand All @@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
# TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size
ad = SciMLBase._unwrap_val(autodiff) ?

Check warning on line 24 in ext/SimpleNonlinearSolveADLinearSolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveADLinearSolveExt.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
AD.ForwardDiffBackend(; chunksize) :
AD.FiniteDifferencesBackend()
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(
ad,
AD.ForwardDiffBackend(; chunksize) :
AD.FiniteDifferencesBackend()
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad,

Check warning on line 27 in ext/SimpleNonlinearSolveADLinearSolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveADLinearSolveExt.jl#L27

Added line #L27 was not covered by tests
nothing,
termination_condition)
end

function SciMLBase.__solve(prob::NonlinearProblem,

Check warning on line 32 in ext/SimpleNonlinearSolveADLinearSolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveADLinearSolveExt.jl#L32

Added line #L32 was not covered by tests
alg::SimpleBatchedNewtonRaphson;
abstol=nothing,
reltol=nothing,
maxiters=1000,
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)
@assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
Expand All @@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.Success)
retcode = ReturnCode.Success)

solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...)
solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...)
xₙ .-= δx

Check warning on line 64 in ext/SimpleNonlinearSolveADLinearSolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveADLinearSolveExt.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
Expand All @@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.MaxIters)
retcode = ReturnCode.MaxIters)
end

end
122 changes: 110 additions & 12 deletions ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
module SimpleNonlinearSolveNNlibExt

using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace

function __init__()
SimpleNonlinearSolve.NNlibExtLoaded[] = true
return

Check warning on line 9 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L7-L9

Added lines #L7 - L9 were not covered by tests
end

# Broyden's method
@views function SciMLBase.__solve(prob::NonlinearProblem,

Check warning on line 13 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L13

Added line #L13 was not covered by tests
alg::BatchedBroyden;
abstol=nothing,
reltol=nothing,
maxiters=1000,
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

Check warning on line 19 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L19

Added line #L19 was not covered by tests

Expand All @@ -24,7 +26,7 @@ end

storage = _get_storage(mode, u)

Check warning on line 27 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L27

Added line #L27 was not covered by tests

xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
xₙ, xₙ₋₁, δxₙ, δf = ntuple(_ -> copy(u), 4)
T = eltype(u)

Check warning on line 30 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L29-L30

Added lines #L29 - L30 were not covered by tests

atol = _get_tolerance(abstol, tc.abstol, T)
Expand All @@ -41,16 +43,16 @@ end
xₙ .= xₙ₋₁ .- 𝓙⁻¹f

Check warning on line 43 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L39-L43

Added lines #L39 - L43 were not covered by tests

@maybeinplace iip fₙ=f(xₙ)
δx .= xₙ .- xₙ₋₁
δxₙ .= xₙ .- xₙ₋₁
δf .= fₙ .- fₙ₋₁

Check warning on line 47 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L45-L47

Added lines #L45 - L47 were not covered by tests

batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
δxᵀ = reshape(δx, 1, L, N)
δxₙᵀ = reshape(δxₙ, 1, L, N)

Check warning on line 50 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests

batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxₙᵀ, reshape(𝓙⁻¹f, L, 1, N))
batched_mul!(xᵀ𝓙⁻¹, δxₙᵀ, 𝓙⁻¹)
δxₙ .= (δxₙ .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
batched_mul!(𝓙⁻¹, reshape(δxₙ, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))

Check warning on line 55 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L52-L55

Added lines #L52 - L55 were not covered by tests

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
Expand All @@ -74,7 +76,103 @@ end
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.MaxIters)
retcode = ReturnCode.MaxIters)
end

# Limited Memory Broyden's method
@views function SciMLBase.__solve(prob::NonlinearProblem,

Check warning on line 83 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L83

Added line #L83 was not covered by tests
alg::BatchedLBroyden;
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

Check warning on line 89 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L89

Added line #L89 was not covered by tests

u, f, reconstruct = _construct_batched_problem_structure(prob)
L, N = size(u)
T = eltype(u)

Check warning on line 93 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L91-L93

Added lines #L91 - L93 were not covered by tests

tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

Check warning on line 96 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L95-L96

Added lines #L95 - L96 were not covered by tests

storage = _get_storage(mode, u)

Check warning on line 98 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L98

Added line #L98 was not covered by tests

η = min(maxiters, alg.threshold)
U = fill!(similar(u, (η, L, N)), zero(T))
Vᵀ = fill!(similar(u, (L, η, N)), zero(T))

Check warning on line 102 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L100-L102

Added lines #L100 - L102 were not covered by tests

xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(u), 3)

Check warning on line 104 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L104

Added line #L104 was not covered by tests

atol = _get_tolerance(abstol, tc.abstol, T)
rtol = _get_tolerance(reltol, tc.reltol, T)
termination_condition = tc(storage)

Check warning on line 108 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L106-L108

Added lines #L106 - L108 were not covered by tests

@maybeinplace iip fₙ₋₁=f(xₙ) u
iip && (fₙ = copy(fₙ₋₁))
δxₙ = -copy(fₙ₋₁)
ηNx = similar(xₙ, η, N)

Check warning on line 113 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L110-L113

Added lines #L110 - L113 were not covered by tests

for i in 1:maxiters
@. xₙ = xₙ₋₁ - δxₙ
@maybeinplace iip fₙ=f(xₙ)
@. δxₙ = xₙ - xₙ₋₁
@. δfₙ = fₙ - fₙ₋₁

Check warning on line 119 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L115-L119

Added lines #L115 - L119 were not covered by tests

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
return DiffEqBase.build_solution(prob,

Check warning on line 123 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L121-L123

Added lines #L121 - L123 were not covered by tests
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode)
end

_L = min(i, η)
_U = U[1:_L, :, :]
_Vᵀ = Vᵀ[:, 1:_L, :]

Check warning on line 132 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L130-L132

Added lines #L130 - L132 were not covered by tests

idx = mod1(i, η)

Check warning on line 134 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L134

Added line #L134 was not covered by tests

if i > 1
partial_ηNx = ηNx[1:_L, :]

Check warning on line 137 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L136-L137

Added lines #L136 - L137 were not covered by tests

_ηNx = reshape(partial_ηNx, 1, :, N)
batched_mul!(_ηNx, reshape(δxₙ, 1, L, N), _Vᵀ)
batched_mul!(Vᵀ[:, idx:idx, :], _ηNx, _U)
Vᵀ[:, idx, :] .-= δxₙ

Check warning on line 142 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L139-L142

Added lines #L139 - L142 were not covered by tests

_ηNx = reshape(partial_ηNx, :, 1, N)
batched_mul!(_ηNx, _U, reshape(δfₙ, L, 1, N))
batched_mul!(U[idx:idx, :, :], _Vᵀ, _ηNx)
U[idx, :, :] .-= δfₙ

Check warning on line 147 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L144-L147

Added lines #L144 - L147 were not covered by tests
else
Vᵀ[:, idx, :] .= -δxₙ
U[idx, :, :] .= -δfₙ

Check warning on line 150 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
end

U[idx, :, :] .= (δxₙ .- U[idx, :, :]) ./

Check warning on line 153 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L153

Added line #L153 was not covered by tests
(sum(Vᵀ[:, idx, :] .* δfₙ; dims = 1) .+
convert(T, 1e-5))

_L = min(i + 1, η)
_ηNx = reshape(ηNx[1:_L, :], :, 1, N)
batched_mul!(_ηNx, U[1:_L, :, :], reshape(δfₙ, L, 1, N))
batched_mul!(reshape(δxₙ, L, 1, N), Vᵀ[:, 1:_L, :], _ηNx)

Check warning on line 160 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L157-L160

Added lines #L157 - L160 were not covered by tests

xₙ₋₁ .= xₙ
fₙ₋₁ .= fₙ
end

Check warning on line 164 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L162-L164

Added lines #L162 - L164 were not covered by tests

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
@maybeinplace iip fₙ=f(xₙ)

Check warning on line 168 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L166-L168

Added lines #L166 - L168 were not covered by tests
end

return DiffEqBase.build_solution(prob,

Check warning on line 171 in ext/SimpleNonlinearSolveNNlibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveNNlibExt.jl#L171

Added line #L171 was not covered by tests
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode = ReturnCode.MaxIters)
end

end
2 changes: 1 addition & 1 deletion src/batched/dfsane.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
σₘᵢₙ::T = 1.0f-10
σₘₐₓ::T = 1.0f+10
Expand Down
7 changes: 7 additions & 0 deletions src/batched/lbroyden.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
termination_condition::TC
threshold::Int
end

# Implementation of solve using Package Extensions
10 changes: 4 additions & 6 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;

Check warning on line 31 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L31

Added line #L31 was not covered by tests
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
if SciMLBase.isinplace(prob)
error("Broyden currently only supports out-of-place nonlinear problems")

Check warning on line 34 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
Expand All @@ -39,19 +42,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
T = eltype(x)
J⁻¹ = init_J(x)

if SciMLBase.isinplace(prob)
error("Broyden currently only supports out-of-place nonlinear problems")
end

atol = _get_tolerance(abstol, tc.abstol, T)
rtol = _get_tolerance(reltol, tc.reltol, T)

Check warning on line 46 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L45-L46

Added lines #L45 - L46 were not covered by tests

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("Broyden currently doesn't support SAFE_BEST termination modes")
end

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
storage = _get_storage(mode, x)

Check warning on line 52 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L52

Added line #L52 was not covered by tests
termination_condition = tc(storage)

xₙ = x
Expand Down
Loading

0 comments on commit 061ec7d

Please sign in to comment.