Skip to content

Commit

Permalink
Continuation test
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 24, 2024
1 parent 61dda98 commit 20c4004
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ functionality should check out [DifferentialEquations.jl](https://github.com/Jul

## API

BoundaryValueDiffEq.jl is part of the JuliaDiffEq common interface, but can be used independently of DifferentialEquations.jl. The only requirement is that the user passes a BoundaryValueDiffEq.jl algorithm to solve. For example, we can solve the [BVP tutorial from the documentation](https://docs.sciml.ai/DiffEqDocs/stable/tutorials/bvp_example/) using the `MIRK4()` algorithm:
BoundaryValueDiffEq.jl is part of the SciML common interface, but can be used independently of DifferentialEquations.jl. The only requirement is that the user passes a BoundaryValueDiffEq.jl algorithm to solve. For example, we can solve the [BVP tutorial from the documentation](https://docs.sciml.ai/DiffEqDocs/stable/tutorials/bvp_example/) using the `MIRK4()` algorithm:

```julia
using BoundaryValueDiffEq
Expand Down
17 changes: 10 additions & 7 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ end
u0 = [5.0, -3.5]
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))

probs = [BVProblem(f1!, bc1!, u0, tspan; nlls=Val(false)),
BVProblem(f1, bc1, u0, tspan; nlls=Val(false)),
TwoPointBVProblem(f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype, nlls=Val(false)),
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype, nlls=Val(false))]
probs = [BVProblem(f1!, bc1!, u0, tspan; nlls = Val(false)),
BVProblem(f1, bc1, u0, tspan; nlls = Val(false)),
TwoPointBVProblem(
f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype, nlls = Val(false)),
TwoPointBVProblem(
f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype, nlls = Val(false))]

algs = []

Expand Down Expand Up @@ -130,8 +132,8 @@ end
u0, tspan, nlls = Val(true)),
BVProblem(BVPFunction(f1_nlls, bc1_nlls; bcresid_prototype = bcresid_prototype1),
u0, tspan, nlls = Val(true)),
TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0,
tspan; bcresid_prototype = bcresid_prototype2, nlls = Val(true)),
TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0, tspan;
bcresid_prototype = bcresid_prototype2, nlls = Val(true)),
TwoPointBVProblem(f1_nlls, (bc1_nlls_a, bc1_nlls_b), u0, tspan;
bcresid_prototype = bcresid_prototype2, nlls = Val(true))]

Expand Down Expand Up @@ -240,7 +242,8 @@ end
if @load_preference("PrecompileShootingNLLS", true)
append!(algs,
[
Shooting(Tsit5(); nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
Shooting(
Tsit5(); nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(); nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))])
Expand Down
3 changes: 1 addition & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ end

function concrete_jacobian_algorithm(
jac_alg::BVPJacobianAlgorithm, prob_type, prob::BVProblem, alg)
u0 = prob.u0 isa AbstractArray ? prob.u0 :
__initial_guess(prob.u0, prob.p, first(prob.tspan))
u0 = __extract_u0(prob.u0, prob.p, first(prob.tspan))
diffmode = jac_alg.diffmode === nothing ? __default_sparse_ad(u0) : jac_alg.diffmode
bc_diffmode = jac_alg.bc_diffmode === nothing ?
(prob_type isa TwoPointBVProblem ? __default_sparse_ad :
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ function __extract_problem_details(prob, u0::AbstractVector{<:AbstractArray}; kw
_u0 = first(u0)
return Val(true), eltype(_u0), length(_u0), (length(u0) - 1), _u0
end
function __extract_problem_details(
prob, u0::RecursiveArrayTools.AbstractVectorOfArray; kwargs...)
# Problem has Initial Guess
_u0 = first(u0.u)
return Val(true), eltype(_u0), length(_u0), (length(u0.u) - 1), _u0
end
function __extract_problem_details(
prob, u0::AbstractArray; dt = 0.0, check_positive_dt::Bool = false)
# Problem does not have Initial Guess
Expand Down
37 changes: 37 additions & 0 deletions test/mirk/mirk_basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,40 @@ end

@test_nowarn solve(prob, MIRK4(); dt = 0.01)
end

@testitem "Solve using Continuation" begin
using RecursiveArrayTools

g = 9.81
L = 1.0
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
= u[2]
du[1] =
du[2] = -(g / L) * sin(θ)
end

function bc2a!(resid_a, u_a, p) # u_a is at the beginning of the time span
x0 = p
resid_a[1] = u_a[1] - x0 # the solution at the beginning of the time span should be -pi/2
end
function bc2b!(resid_b, u_b, p) # u_b is at the ending of the time span
x0 = p
resid_b[1] = u_b[1] - pi / 2 # the solution at the end of the time span should be pi/2
end

bvp3 = TwoPointBVProblem(
simplependulum!, (bc2a!, bc2b!), [pi / 2, pi / 2], (pi / 4, pi / 2),
-pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
sol3 = solve(bvp3, MIRK4(), dt = 0.05)

# Needs a SciMLBase fix
bvp4 = TwoPointBVProblem(simplependulum!, (bc2a!, bc2b!), sol3, (0, pi / 2),
pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
@test_broken solve(bvp4, MIRK4(), dt = 0.05) isa SciMLBase.ODESolution

bvp5 = TwoPointBVProblem(simplependulum!, (bc2a!, bc2b!), DiffEqArray(sol3.u, sol3.t),
(0, pi / 2), pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
@test SciMLBase.successful_retcode(solve(bvp5, MIRK4(), dt = 0.05).retcode)
end
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
using ReTestItems

ReTestItems.runtests(@__DIR__)
ReTestItems.runtests(joinpath(@__DIR__, "mirk/"))
ReTestItems.runtests(joinpath(@__DIR__, "misc/"))
ReTestItems.runtests(joinpath(@__DIR__, "shooting/"))

# Wrappers like ODEInterface don't support parallel testing
ReTestItems.runtests(joinpath(@__DIR__, "wrappers/"); nworkers = 0)
21 changes: 12 additions & 9 deletions test/wrappers/odeinterface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,15 @@ u0 = [0.5, 1.0]
p = [0.1]
tspan = (-π / 2, π / 2)

tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan,
p; bcresid_prototype = (zeros(1), zeros(1)))

# Just generate a solution for bvpsol
sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson());
dt = π / 20, abstol = 1e-5, maxiters = 1000, adaptive = false)

export ex7_f!, ex7_2pbc1!, ex7_2pbc2!, u0, p, tspan, tpprob, sol_ms
export ex7_f!, ex7_2pbc1!, ex7_2pbc2!, u0, p, tspan

end

@testitem "BVPM2" setup=[ODEInterfaceWrapperTestSetup] begin
using ODEInterface, RecursiveArrayTools
using ODEInterface, RecursiveArrayTools, LinearAlgebra

tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan,
p; bcresid_prototype = (zeros(1), zeros(1)))

sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)
@test SciMLBase.successful_retcode(sol_bvpm2)
Expand All @@ -51,6 +47,13 @@ end
@testitem "BVPSOL" setup=[ODEInterfaceWrapperTestSetup] begin
using ODEInterface, RecursiveArrayTools

tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan,
p; bcresid_prototype = (zeros(1), zeros(1)))

# Just generate a solution for bvpsol
sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson());
dt = π / 20, abstol = 1e-5, maxiters = 1000, adaptive = false)

initial_u0 = [sol_ms(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0,
tspan, p; bcresid_prototype = (zeros(1), zeros(1)))
Expand Down

0 comments on commit 20c4004

Please sign in to comment.