Skip to content

Commit

Permalink
Handle Dynamic Dispatches in Multiple Shooting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 1, 2024
1 parent d52eac3 commit eb3a028
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 233 deletions.
4 changes: 2 additions & 2 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function __solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...
bvpm2_destroy(obj)
bvpm2_destroy(sol)

return SciMLBase.build_solution(prob, ivpsol)
return SciMLBase.build_solution(prob, ivpsol, nothing)
end

#-------
Expand Down Expand Up @@ -180,7 +180,7 @@ function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3, d
map(x -> reshape(convert(Vector{eltype(u0_)}, x), u0_size), eachcol(sol_x));
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure, stats)

return SciMLBase.build_solution(prob, ivpsol)
return SciMLBase.build_solution(prob, ivpsol, nothing)
end

#-------
Expand Down
10 changes: 0 additions & 10 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ Significantly more stable than Single Shooting.
grid_coarsening
end

# function Base.show(io::IO, alg::MultipleShooting)
# print(io, "MultipleShooting(")
# modifiers = String[]
# alg.nlsolve !== nothing && push!(modifiers, "nlsolve = $(alg.nlsolve)")
# alg.jac_alg !== nothing && push!(modifiers, "jac_alg = $(alg.jac_alg)")
# alg.ode_alg !== nothing && push!(modifiers, "ode_alg = $(__nameof(alg.ode_alg))()")
# print(io, join(modifiers, ", "))
# print(io, ")")
# end

function concretize_jacobian_algorithm(alg::MultipleShooting, prob)
jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots,
Expand Down
62 changes: 30 additions & 32 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
@unpack f, tspan = prob
(; f, tspan) = prob

@assert (ensemblealg isa EnsembleSerial)||(ensemblealg isa EnsembleThreads) "Currently MultipleShooting only supports `EnsembleSerial` and `EnsembleThreads`!"
if !(ensemblealg isa EnsembleSerial) && !(ensemblealg isa EnsembleThreads)
throw(ArgumentError("Currently MultipleShooting only supports `EnsembleSerial` and \
`EnsembleThreads`!"))
end

ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1)
has_initial_guess = _unwrap_val(ig)
Expand Down Expand Up @@ -30,11 +33,9 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),

internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true)

function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int,
nodes::T4, odecache::C) where {T1, T2, T3, T4, C}
return __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache, nodes, u0_size, N, ensemblealg)
end
solve_internal_odes! = @closure (resid_nodes, us, p, cur_nshoot, nodes,
odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache, nodes, u0_size, N, ensemblealg, tspan)

# This gets all the nshoots except the final SingleShooting case
all_nshoots = __get_all_nshoots(alg.grid_coarsening, nshoots)
Expand Down Expand Up @@ -96,7 +97,7 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
resid_prototype = vcat(bcresid_prototype[1],
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])

loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
loss_fn = @closure (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
ode_cache_loss_fn)

Expand All @@ -112,19 +113,18 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
jac_cache, alg.jac_alg.diffmode, alg.ode_alg, cur_nshoot, u0;
internal_ode_kwargs...)

loss_fnₚ = (du, u) -> __multiple_shooting_2point_loss!(du, u, prob.p, cur_nshoot,
nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
loss_fnₚ = @closure (du, u) -> __multiple_shooting_2point_loss!(du, u, prob.p,
cur_nshoot, nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
ode_cache_jac_fn)

jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
jac_fn = @closure (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
loss_fnₚ, resid_prototype_cached, alg)

loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
loss_function! = __unsafe_nonlinearfunction{true}(loss_fn; resid_prototype,
jac = jac_fn, jac_prototype)

# NOTE: u_at_nodes is updated inplace
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)

return nothing
Expand All @@ -144,7 +144,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)

loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
loss_fn = @closure (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan,
alg.ode_alg, u0, ode_cache_loss_fn)

Expand All @@ -169,22 +169,21 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))

# Define the functions now
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes,
ode_fn = @closure (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes,
ode_cache_ode_jac_fn)
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p, cur_nshoot, nodes,
prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0,
ode_cache_bc_jac_fn)
bc_fn = @closure (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
cur_nshoot, nodes, prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan,
alg.ode_alg, u0, ode_cache_bc_jac_fn)

jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
jac_fn = @closure (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
ode_fn, bc_fn, alg, N, M)

loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
jac_prototype)
loss_function! = __unsafe_nonlinearfunction{true}(loss_fn; resid_prototype,
jac_prototype, jac = jac_fn)

# NOTE: u_at_nodes is updated inplace
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
u_at_nodes, prob.p)
nlprob = __internal_nlsolve_problem(prob, M, N, loss_function!, u_at_nodes, prob.p)
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)

return nothing
Expand Down Expand Up @@ -224,7 +223,7 @@ end

# Not using `EnsembleProblem` since it is hard to initialize the cache and stuff
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::Int,
odecache, nodes, u0_size, N::Int, ::EnsembleSerial)
odecache, nodes, u0_size, N::Int, ::EnsembleSerial, tspan)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)

Expand All @@ -242,7 +241,7 @@ function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::
end

function __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoots::Int,
odecache::Vector, nodes, u0_size, N::Int, ::EnsembleThreads)
odecache::Vector, nodes, u0_size, N::Int, ::EnsembleThreads, tspan)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)

Expand Down Expand Up @@ -364,7 +363,7 @@ end

# NOTE: We don't check `u0 isa Function` since `u0` in-principle can be a callable
# struct
u0_ = u0 isa AbstractArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes]
u0_ = u0 isa VectorOfArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes]

N = length(first(u0_))
u_at_nodes = similar(first(u0_), (nshoots + 1) * N)
Expand Down Expand Up @@ -399,14 +398,13 @@ end
sol = solve!(odecache)

if SciMLBase.successful_retcode(sol)
res = sol(nodes).u
for i in 1:length(nodes)
u_at_nodes[(i - 1) * N .+ (1:N)] .= vec(res[i])
u_at_nodes[(i - 1) * N .+ (1:N)] .= vec(sol(nodes[i]))
end
else
@warn "Initialization using odesolve failed. Initializing using 0s. It is \
recommended to provide an initial guess function via \
`u0 = <function>(p, t)` or `u0 = <function>(t)` in this case."
`u0 = <function>(p, t)` in this case."
fill!(u_at_nodes, 0)
end

Expand Down Expand Up @@ -478,4 +476,4 @@ end
end
@assert !(1 in nshoots_vec)
return nshoots_vec
end
end
Loading

0 comments on commit eb3a028

Please sign in to comment.