Skip to content

Commit

Permalink
Better handling of PolyAlgs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 22, 2024
1 parent 0effc01 commit 7f41f1c
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 249 deletions.
2 changes: 2 additions & 0 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ include("sparse_jacobians.jl")
include("adaptivity.jl")
include("interpolation.jl")

include("default_nlsolve.jl")

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
Expand Down
49 changes: 49 additions & 0 deletions src/default_nlsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Currently there are some problems with the default NonlinearSolver selection for
# BoundaryValueDiffEq
# See https://github.com/SciML/BoundaryValueDiffEq.jl/issues/175
# and https://github.com/SciML/BoundaryValueDiffEq.jl/issues/163
# These are not meant to be user facing and we should delete these once those issues are
# resolved
function __FastShortcutBVPCompatibleNLLSPolyalg(
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
precs = NonlinearSolve.DEFAULT_PRECS, autodiff = nothing, kwargs...) where {T}
if NonlinearSolve.__is_complex(T)
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
LevenbergMarquardt(;
linsolve, precs, autodiff, disable_geodesic = Val(true), kwargs...),
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
else
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
LevenbergMarquardt(;
linsolve, precs, disable_geodesic = Val(true), autodiff, kwargs...),
TrustRegion(; concrete_jac, linsolve, precs, autodiff, kwargs...),
GaussNewton(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()),
autodiff, kwargs...),
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
end
return NonlinearSolvePolyAlgorithm(algs, Val(:NLLS))
end

function __FastShortcutBVPCompatibleNonlinearPolyalg(
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
precs = NonlinearSolve.DEFAULT_PRECS, autodiff = nothing) where {T}
if NonlinearSolve.__is_complex(T)
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
else
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
NewtonRaphson(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff))
end
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
end

@inline __concrete_nonlinearsolve_algorithm(prob, alg) = alg
@inline function __concrete_nonlinearsolve_algorithm(prob, ::Nothing)
if prob isa NonlinearLeastSquaresProblem
return __FastShortcutBVPCompatibleNLLSPolyalg(eltype(prob.u0))
else
return __FastShortcutBVPCompatibleNonlinearPolyalg(eltype(prob.u0))
end
end
95 changes: 55 additions & 40 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,58 +138,73 @@ function __split_mirk_kwargs(;
end

function SciMLBase.solve!(cache::MIRKCache)
(defect_threshold, MxNsub, abstol, adaptive, _), kwargs = __split_mirk_kwargs(;
cache.kwargs...)
(; y, y₀, prob, alg, mesh, mesh_dt, TU, ITU) = cache
(_, _, abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success
defect_norm = 2 * abstol

while SciMLBase.successful_retcode(info) && defect_norm > abstol
nlprob = __construct_nlproblem(cache, recursive_flatten(y₀))
sol_nlprob = __solve(nlprob, alg.nlsolve; abstol, kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)
# We do the first iteration outside the loop to preserve type-stability of the
# `original` field of the solution
sol_nlprob, info, defect_norm = __perform_mirk_iteration(
cache, abstol, adaptive; kwargs...)

info = sol_nlprob.retcode
if adaptive
while SciMLBase.successful_retcode(info) && defect_norm > abstol
sol_nlprob, info, defect_norm = __perform_mirk_iteration(
cache, abstol, adaptive; kwargs...)
end
end

!adaptive && break
u = [reshape(y, cache.in_size) for y in cache.y₀]

if info == ReturnCode.Success
defect_norm = defect_estimate!(cache)
# The defect is greater than 10%, the solution is not acceptable
defect_norm > defect_threshold && (info = ReturnCode.Failure)
end
odesol = DiffEqBase.build_solution(cache.prob, cache.alg, cache.mesh, u;
interp = MIRKInterpolation(cache.mesh, u, cache), retcode = info)
return __build_solution(cache.prob, odesol, sol_nlprob)
end

function __perform_mirk_iteration(cache::MIRKCache, abstol, adaptive; kwargs...)
nlprob = __construct_nlproblem(cache, recursive_flatten(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(nlprob, nlsolve_alg; abstol, kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)

defect_norm = 2 * abstol

if info == ReturnCode.Success
if defect_norm > abstol
# We construct a new mesh to equidistribute the defect
mesh, mesh_dt, _, info = mesh_selector!(cache)
if info == ReturnCode.Success
__append_similar!(cache.y₀, length(cache.mesh), cache.M)
for (i, m) in enumerate(cache.mesh)
interp_eval!(cache.y₀[i], cache, m, mesh, mesh_dt)
end
__expand_cache!(cache)
# Early terminate if non-adaptive
adaptive || return sol_nlprob, sol_nlprob.retcode, defect_norm

info = sol_nlprob.retcode

if info == ReturnCode.Success # Nonlinear Solve was successful
defect_norm = defect_estimate!(cache)
# The defect is greater than 10%, the solution is not acceptable
defect_norm > cache.alg.defect_threshold && (info = ReturnCode.Failure)
end

if info == ReturnCode.Success # Nonlinear Solve Successful and defect norm is acceptable
if defect_norm > abstol
# We construct a new mesh to equidistribute the defect
mesh, mesh_dt, _, info = mesh_selector!(cache)
if info == ReturnCode.Success
__append_similar!(cache.y₀, length(cache.mesh), cache.M)
for (i, m) in enumerate(cache.mesh)
interp_eval!(cache.y₀[i], cache, m, mesh, mesh_dt)
end
end
else
# We cannot obtain a solution for the current mesh
if 2 * (length(cache.mesh) - 1) > MxNsub
# New mesh would be too large
info = ReturnCode.Failure
else
half_mesh!(cache)
__expand_cache!(cache)
recursive_fill!(cache.y₀, 0)
info = ReturnCode.Success # Force a restart
defect_norm = 2 * abstol
end
end
else # Something bad happened
# We cannot obtain a solution for the current mesh
if 2 * (length(cache.mesh) - 1) > cache.alg.max_num_subintervals
# New mesh would be too large
info = ReturnCode.Failure
else
half_mesh!(cache)
__expand_cache!(cache)
recursive_fill!(cache.y₀, 0)
info = ReturnCode.Success # Force a restart
end
end

u = [reshape(y, cache.in_size) for y in cache.y₀]
# TODO: Return `nlsol` as original
return DiffEqBase.build_solution(prob, alg, cache.mesh, u;
interp = MIRKInterpolation(cache.mesh, u, cache), retcode = info)
return sol_nlprob, info, defect_norm
end

# Constructing the Nonlinear Problem
Expand Down
9 changes: 6 additions & 3 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1)
has_initial_guess = _unwrap_val(ig)

@assert u0 isa AbstractVector "Non-Vector Inputs for Multiple-Shooting hasn't been implemented yet!"

bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)

Expand Down Expand Up @@ -126,7 +128,8 @@ function __solve_nlproblem!(

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
end
Expand Down Expand Up @@ -185,8 +188,8 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_

# NOTE: u_at_nodes is updated inplace
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
# FIXME: This is not safe for polyalgorithms
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
__solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)

return nothing
end
Expand Down
3 changes: 2 additions & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
nlf = __unsafe_nonlinearfunction{iip}(
loss_fn; jac_prototype, resid_prototype, jac = jac_fn)
nlprob = __internal_nlsolve_problem(prob, resid_prototype, u0, nlf, vec(u0), prob.p)
nlsol = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
nlsolve_alg = __concrete_nonlinearsolve_algorithm(prob, alg.nlsolve)
nlsol = __solve(nlprob, nlsolve_alg; nlsolve_kwargs..., verbose, kwargs...)

# There is no way to reinit with the same cache with different cache. But not saving
# the internal values gives a significant speedup. So we just create a new cache
Expand Down
53 changes: 42 additions & 11 deletions test/mirk/mirk_basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,18 @@ bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))
tspan = (0.0, 5.0)
u0 = [5.0, -3.5]

probArr = [BVProblem(odef1!, boundary!, u0, tspan),
BVProblem(odef1, boundary, u0, tspan),
BVProblem(odef2!, boundary!, u0, tspan),
BVProblem(odef2, boundary, u0, tspan),
probArr = [BVProblem(odef1!, boundary!, u0, tspan, nlls = Val(false)),
BVProblem(odef1, boundary, u0, tspan, nlls = Val(false)),
BVProblem(odef2!, boundary!, u0, tspan, nlls = Val(false)),
BVProblem(odef2, boundary, u0, tspan, nlls = Val(false)),
TwoPointBVProblem(odef1!, (boundary_two_point_a!, boundary_two_point_b!),
u0, tspan; bcresid_prototype),
TwoPointBVProblem(
odef1, (boundary_two_point_a, boundary_two_point_b), u0, tspan; bcresid_prototype),
u0, tspan; bcresid_prototype, nlls = Val(false)),
TwoPointBVProblem(odef1, (boundary_two_point_a, boundary_two_point_b),
u0, tspan; bcresid_prototype, nlls = Val(false)),
TwoPointBVProblem(odef2!, (boundary_two_point_a!, boundary_two_point_b!),
u0, tspan; bcresid_prototype),
TwoPointBVProblem(
odef2, (boundary_two_point_a, boundary_two_point_b), u0, tspan; bcresid_prototype)]
u0, tspan; bcresid_prototype, nlls = Val(false)),
TwoPointBVProblem(odef2, (boundary_two_point_a, boundary_two_point_b),
u0, tspan; bcresid_prototype, nlls = Val(false))]

testTol = 0.2
affineTol = 1e-2
Expand Down Expand Up @@ -93,7 +93,7 @@ end
@testset "Problem: $i" for i in 1:8
prob = probArr[i]
@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
solver = mirk_solver(Val(order);
solver = mirk_solver(Val(order); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
@test_opt target_modules=(NonlinearSolve, BoundaryValueDiffEq) solve(
prob, solver; dt = 0.2)
Expand Down Expand Up @@ -182,3 +182,34 @@ end
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
end
end

@testitem "Swirling Flow III" begin
# Reported in https://github.com/SciML/BoundaryValueDiffEq.jl/issues/153
eps = 0.01
function swirling_flow!(du, u, p, t)
eps = p
du[1] = u[2]
du[2] = (u[1] * u[4] - u[3] * u[2]) / eps
du[3] = u[4]
du[4] = u[5]
du[5] = u[6]
du[6] = (-u[3] * u[6] - u[1] * u[2]) / eps
return
end

function swirling_flow_bc!(res, u, p, t)
res[1] = u[1][1] + 1.0
res[2] = u[1][3]
res[3] = u[1][4]
res[4] = u[end][1] - 1.0
res[5] = u[end][3]
res[6] = u[end][4]
return
end

tspan = (0.0, 1.0)
u0 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
prob = BVProblem(swirling_flow!, swirling_flow_bc!, u0, tspan, eps)

@test_nowarn solve(prob, MIRK4(); dt = 0.01)
end
52 changes: 0 additions & 52 deletions test/misc/affine_geodesic.jl

This file was deleted.

Loading

0 comments on commit 7f41f1c

Please sign in to comment.