diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl index 96ba916..d0a97eb 100644 --- a/ext/SimpleNonlinearSolveADLinearSolveExt.jl +++ b/ext/SimpleNonlinearSolveADLinearSolveExt.jl @@ -1,8 +1,10 @@ module SimpleNonlinearSolveADLinearSolveExt -using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, +using AbstractDifferentiation, + ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _result_from_storage, _get_tolerance, @maybeinplace const AD = AbstractDifferentiation @@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}() # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size ad = SciMLBase._unwrap_val(autodiff) ? - AD.ForwardDiffBackend(; chunksize) : - AD.FiniteDifferencesBackend() - return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}( - ad, + AD.ForwardDiffBackend(; chunksize) : + AD.FiniteDifferencesBackend() + return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad, nothing, termination_condition) end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBatchedNewtonRaphson; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) iip = isinplace(prob) @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." 
@@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xₙ), reconstruct(fₙ); - retcode=ReturnCode.Success) + retcode = ReturnCode.Success) - solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...) + solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...) xₙ .-= δx if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) @@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xₙ), reconstruct(fₙ); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) end end diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index e62e2bd..c0faefd 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -1,18 +1,20 @@ module SimpleNonlinearSolveNNlibExt using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace function __init__() SimpleNonlinearSolve.NNlibExtLoaded[] = true return end +# Broyden's method @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) 
iip = isinplace(prob) @@ -24,7 +26,7 @@ end storage = _get_storage(mode, u) - xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4) + xₙ, xₙ₋₁, δxₙ, δf = ntuple(_ -> copy(u), 4) T = eltype(u) atol = _get_tolerance(abstol, tc.abstol, T) @@ -41,16 +43,16 @@ end xₙ .= xₙ₋₁ .- 𝓙⁻¹f @maybeinplace iip fₙ=f(xₙ) - δx .= xₙ .- xₙ₋₁ + δxₙ .= xₙ .- xₙ₋₁ δf .= fₙ .- fₙ₋₁ batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N)) - δxᵀ = reshape(δx, 1, L, N) + δxₙᵀ = reshape(δxₙ, 1, L, N) - batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N)) - batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹) - δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5)) - batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T)) + batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxₙᵀ, reshape(𝓙⁻¹f, L, 1, N)) + batched_mul!(xᵀ𝓙⁻¹, δxₙᵀ, 𝓙⁻¹) + δxₙ .= (δxₙ .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5)) + batched_mul!(𝓙⁻¹, reshape(δxₙ, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T)) if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) @@ -74,7 +76,103 @@ end alg, reconstruct(xₙ), reconstruct(fₙ); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) +end + +# Limited Memory Broyden's method +@views function SciMLBase.__solve(prob::NonlinearProblem, + alg::BatchedLBroyden; + abstol = nothing, + reltol = nothing, + maxiters = 1000, + kwargs...) 
+ iip = isinplace(prob) + + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) + T = eltype(u) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + η = min(maxiters, alg.threshold) + U = fill!(similar(u, (η, L, N)), zero(T)) + Vᵀ = fill!(similar(u, (L, η, N)), zero(T)) + + xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(u), 3) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + @maybeinplace iip fₙ₋₁=f(xₙ) u + iip && (fₙ = copy(fₙ₋₁)) + δxₙ = -copy(fₙ₋₁) + ηNx = similar(xₙ, η, N) + + for i in 1:maxiters + @. xₙ = xₙ₋₁ - δxₙ + @maybeinplace iip fₙ=f(xₙ) + @. δxₙ = xₙ - xₙ₋₁ + @. δfₙ = fₙ - fₙ₋₁ + + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode) + end + + _L = min(i, η) + _U = U[1:_L, :, :] + _Vᵀ = Vᵀ[:, 1:_L, :] + + idx = mod1(i, η) + + if i > 1 + partial_ηNx = ηNx[1:_L, :] + + _ηNx = reshape(partial_ηNx, 1, :, N) + batched_mul!(_ηNx, reshape(δxₙ, 1, L, N), _Vᵀ) + batched_mul!(Vᵀ[:, idx:idx, :], _ηNx, _U) + Vᵀ[:, idx, :] .-= δxₙ + + _ηNx = reshape(partial_ηNx, :, 1, N) + batched_mul!(_ηNx, _U, reshape(δfₙ, L, 1, N)) + batched_mul!(U[idx:idx, :, :], _Vᵀ, _ηNx) + U[idx, :, :] .-= δfₙ + else + Vᵀ[:, idx, :] .= -δxₙ + U[idx, :, :] .= -δfₙ + end + + U[idx, :, :] .= (δxₙ .- U[idx, :, :]) ./ + (sum(Vᵀ[:, idx, :] .* δfₙ; dims = 1) .+ + convert(T, 1e-5)) + + _L = min(i + 1, η) + _ηNx = reshape(ηNx[1:_L, :], :, 1, N) + batched_mul!(_ηNx, U[1:_L, :, :], reshape(δfₙ, L, 1, N)) + batched_mul!(reshape(δxₙ, L, 1, N), Vᵀ[:, 1:_L, :], _ηNx) + + xₙ₋₁ .= xₙ + fₙ₋₁ .= fₙ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xₙ = storage.u + @maybeinplace iip fₙ=f(xₙ) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + 
reconstruct(fₙ); + retcode = ReturnCode.MaxIters) end end diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index fe7cbcd..88f02eb 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,4 +1,4 @@ -@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: +Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm σₘᵢₙ::T = 1.0f-10 σₘₐₓ::T = 1.0f+10 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl index e69de29..5934c8f 100644 --- a/src/batched/lbroyden.jl +++ b/src/batched/lbroyden.jl @@ -0,0 +1,7 @@ +struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + termination_condition::TC + threshold::Int +end + +# Implementation of solve using Package Extensions \ No newline at end of file diff --git a/src/broyden.jl b/src/broyden.jl index 6c5c3ce..adf94b0 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -30,6 +30,9 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + if SciMLBase.isinplace(prob) + error("Broyden currently only supports out-of-place nonlinear problems") + end tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) @@ -39,10 +42,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; T = eltype(x) J⁻¹ = init_J(x) - if SciMLBase.isinplace(prob) - error("Broyden currently only supports out-of-place nonlinear problems") - end - atol = _get_tolerance(abstol, tc.abstol, T) rtol = _get_tolerance(reltol, tc.reltol, T) @@ -50,8 +49,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; error("Broyden currently doesn't support SAFE_BEST termination modes") end - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? 
NLSolveSafeTerminationResult() : - nothing + storage = _get_storage(mode, x) termination_condition = tc(storage) xₙ = x diff --git a/src/lbroyden.jl b/src/lbroyden.jl index fc2b51a..95ec389 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -11,134 +11,121 @@ Broyden's method. This method is not very stable and can diverge even for very simple problems. This has mostly been tested for neural networks in DeepEquilibriumNetworks.jl. + +!!! note + + To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or + `import NNlib` must be present in your code. """ struct LBroyden{batched, TC <: NLSolveTerminationCondition} <: AbstractSimpleNonlinearSolveAlgorithm termination_condition::TC threshold::Int +end - function LBroyden(; batched = false, threshold::Int = 27, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - return new{batched, typeof(termination_condition)}(termination_condition, threshold) +function LBroyden(; batched = false, threshold::Int = 27, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + if batched + @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + return BatchedLBroyden(termination_condition, threshold) end + return LBroyden{false, typeof(termination_condition)}(termination_condition, threshold) end -@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden{batched}, args...; +@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) where {batched} + kwargs...) 
+ if SciMLBase.isinplace(prob) + error("LBroyden currently only supports out-of-place nonlinear problems") + end tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) - threshold = min(maxiters, alg.threshold) + η = min(maxiters, alg.threshold) x = float(prob.u0) - batched && @assert ndims(x)==2 "Batched LBroyden only supports 2D arrays" - + # FIXME: The scalar case currently is very inefficient if x isa Number restore_scalar = true x = [x] - f = u -> prob.f(u[], prob.p) + f = u -> [prob.f(u[], prob.p)] else f = Base.Fix2(prob.f, prob.p) restore_scalar = false end - fₙ = f(x) + L = length(x) T = eltype(x) - if SciMLBase.isinplace(prob) - error("LBroyden currently only supports out-of-place nonlinear problems") - end - - U, Vᵀ = _init_lbroyden_state(batched, x, threshold) + U = fill!(similar(x, (η, L)), zero(T)) + Vᵀ = fill!(similar(x, (L, η)), zero(T)) - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) - - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - error("LBroyden currently doesn't support SAFE_BEST termination modes") - end - - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : - nothing + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + storage = _get_storage(mode, x) termination_condition = tc(storage) - xₙ = x - xₙ₋₁ = x - fₙ₋₁ = fₙ - update = fₙ + xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(x), 3) + fₙ₋₁ = f(x) + δxₙ = -copy(fₙ₋₁) + ηNx = similar(xₙ, η) + for i in 1:maxiters - xₙ = xₙ₋₁ .+ update + @. xₙ = xₙ₋₁ - δxₙ fₙ = f(xₙ) - Δxₙ = xₙ .- xₙ₋₁ - Δfₙ = fₙ .- fₙ₋₁ + @. δxₙ = xₙ - xₙ₋₁ + @. δfₙ = fₙ - fₙ₋₁ - if termination_condition(restore_scalar ? 
[fₙ] : fₙ, xₙ, xₙ₋₁, atol, rtol) + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, Val(false)) xₙ = restore_scalar ? xₙ[] : xₙ - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) + fₙ = restore_scalar ? fₙ[] : fₙ + return DiffEqBase.build_solution(prob, alg, xₙ, fₙ; retcode) end - _U = selectdim(U, 1, 1:min(threshold, i)) - _Vᵀ = selectdim(Vᵀ, 2, 1:min(threshold, i)) - - vᵀ = _rmatvec(_U, _Vᵀ, Δxₙ) - mvec = _matvec(_U, _Vᵀ, Δfₙ) - u = (Δxₙ .- mvec) ./ (sum(vᵀ .* Δfₙ) .+ convert(T, 1e-5)) + _L = min(i, η) + _U = U[1:_L, :] + _Vᵀ = Vᵀ[:, 1:_L] - selectdim(Vᵀ, 2, mod1(i, threshold)) .= vᵀ - selectdim(U, 1, mod1(i, threshold)) .= u + idx = mod1(i, η) - update = -_matvec(selectdim(U, 1, 1:min(threshold, i + 1)), - selectdim(Vᵀ, 2, 1:min(threshold, i + 1)), fₙ) + partial_ηNx = ηNx[1:_L] - xₙ₋₁ = xₙ - fₙ₋₁ = fₙ - end + if i > 1 + _ηNx = reshape(partial_ηNx, 1, :) + mul!(_ηNx, reshape(δxₙ, 1, L), _Vᵀ) + mul!(Vᵀ[:, idx:idx], _ηNx, _U) + Vᵀ[:, idx] .-= δxₙ - xₙ = restore_scalar ? 
xₙ[] : xₙ - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) -end + _ηNx = reshape(partial_ηNx, :, 1) + mul!(_ηNx, _U, reshape(δfₙ, L, 1)) + mul!(U[idx:idx, :], _Vᵀ, _ηNx) + U[idx, :] .-= δfₙ + else + Vᵀ[:, idx] .= -δxₙ + U[idx, :] .= -δfₙ + end -function _init_lbroyden_state(batched::Bool, x, threshold) - T = eltype(x) - if batched - U = fill!(similar(x, (threshold, size(x, 1), size(x, 2))), zero(T)) - Vᵀ = fill!(similar(x, (size(x, 1), threshold, size(x, 2))), zero(T)) - else - U = fill!(similar(x, (threshold, length(x))), zero(T)) - Vᵀ = fill!(similar(x, (length(x), threshold)), zero(T)) - end - return U, Vᵀ -end + U[idx, :] .= (δxₙ .- U[idx, :]) ./ + (sum(Vᵀ[:, idx] .* δfₙ) .+ + convert(T, 1e-5)) -function _rmatvec(U::AbstractMatrix, Vᵀ::AbstractMatrix, - x::Union{<:AbstractVector, <:Number}) - length(U) == 0 && return x - return -x .+ vec((x' * Vᵀ) * U) -end + _L = min(i + 1, η) + _ηNx = reshape(ηNx[1:_L], :, 1) + mul!(_ηNx, U[1:_L, :], reshape(δfₙ, L, 1)) + mul!(reshape(δxₙ, L, 1), Vᵀ[:, 1:_L], _ηNx) -function _rmatvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3}, - x::AbstractMatrix) where {T1, T2} - length(U) == 0 && return x - Vᵀx = sum(Vᵀ .* reshape(x, size(x, 1), 1, size(x, 2)); dims = 1) - return -x .+ _drdims_sum(U .* permutedims(Vᵀx, (2, 1, 3)); dims = 1) -end + xₙ₋₁ .= xₙ + fₙ₋₁ .= fₙ + end -function _matvec(U::AbstractMatrix, Vᵀ::AbstractMatrix, - x::Union{<:AbstractVector, <:Number}) - length(U) == 0 && return x - return -x .+ vec(Vᵀ * (U * x)) -end + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xₙ = storage.u + fₙ = f(xₙ) + end -function _matvec(U::AbstractArray{T1, 3}, Vᵀ::AbstractArray{T2, 3}, - x::AbstractMatrix) where {T1, T2} - length(U) == 0 && return x - xUᵀ = sum(reshape(x, size(x, 1), 1, size(x, 2)) .* permutedims(U, (2, 1, 3)); dims = 1) - return -x .+ _drdims_sum(xUᵀ .* Vᵀ; dims = 2) + xₙ = restore_scalar ? xₙ[] : xₙ + fₙ = restore_scalar ? 
fₙ[] : fₙ + return DiffEqBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) end - -_drdims_sum(args...; dims = :) = dropdims(sum(args...; dims); dims)