Skip to content

Commit

Permalink
Fix ODEInterface tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 1, 2024
1 parent eb3a028 commit 87dca5c
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 22 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ jobs:
- Others
version:
- '1'
- '~1.10.0-0'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.5.0]
julia-version: [1]
julia-arch: [x86]
os: [ubuntu-latest]
steps:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "5.6.3"
version = "5.7.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -51,7 +51,7 @@ ForwardDiff = "0.10"
JET = "0.8"
LinearAlgebra = "1.9"
LinearSolve = "2.20"
NonlinearSolve = "2.6.1, 3"
NonlinearSolve = "3.5"
ODEInterface = "0.5"
OrdinaryDiffEq = "6"
PreallocationTools = "0.4"
Expand All @@ -70,7 +70,7 @@ Test = "1"
Tricks = "0.1"
TruncatedStacktraces = "1"
UnPack = "1"
julia = "1.9"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
4 changes: 2 additions & 2 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function __solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...

n == -1 && dt 0 && throw(ArgumentError("`dt` must be positive."))

mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n))
mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1))
n = length(mesh) - 1
no_odes = length(u0_)

Expand Down Expand Up @@ -111,7 +111,7 @@ function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3, d

n == -1 && dt 0 && throw(ArgumentError("`dt` must be positive."))
u0 = __flatten_initial_guess(prob.u0)
mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n))
mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1))
if u0 === nothing
# initial_guess function was provided
u0 = mapreduce(@closure(t->vec(__initial_guess(prob.u0, prob.p, t))), hcat, mesh)
Expand Down
4 changes: 2 additions & 2 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true)

solve_internal_odes! = @closure (resid_nodes, us, p, cur_nshoot, nodes,
odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
odecache, nodes, u0_size, N, ensemblealg, tspan)

# This gets all the nshoots except the final SingleShooting case
Expand Down Expand Up @@ -476,4 +476,4 @@ end
end
@assert !(1 in nshoots_vec)
return nshoots_vec
end
end
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ Takes the input initial guess and returns the mesh.
"""
@inline __extract_mesh(u₀, t₀, t₁, n::Int) = collect(range(t₀; stop = t₁, length = n + 1))
@inline __extract_mesh(u₀, t₀, t₁, dt::Number) = collect(t₀:dt:t₁)
@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, n) = u₀.t
@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, ::Int) = u₀.t
@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, ::Number) = u₀.t

"""
__has_initial_guess(u₀) -> Bool
Expand Down
30 changes: 18 additions & 12 deletions test/misc/odeinterface_wrapper.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, RecursiveArrayTools
using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, OrdinaryDiffEq,
RecursiveArrayTools

# Adaptation of https://github.com/luchr/ODEInterface.jl/blob/958b6023d1dabf775033d0b89c5401b33100bca3/examples/BasicExamples/ex7.jl
function ex7_f!(du, u, p, t)
Expand All @@ -25,7 +26,6 @@ tspan = (-π / 2, π / 2)

tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))
sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)

@testset "BVPM2" begin
@info "Testing BVPM2"
Expand All @@ -38,23 +38,28 @@ sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)
@test norm(resid_f, Inf) < 1e-6
end

# Just generate a solution for bvpsol
sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson());
dt = π / 20, abstol = 1e-5, maxiters = 1000,
odesolve_kwargs = (; adaptive = false, dt = 0.01, abstol = 1e-6, maxiters = 1000))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
@testset "BVPSOL" begin
@info "Testing BVPSOL"

@info "BVPSOL with Vector{<:AbstractArray}"

initial_u0 = [sol_bvpm2(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan;
initial_u0 = [sol_ms(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

@info "BVPSOL with VectorOfArray"

initial_u0 = VectorOfArray([sol_bvpm2(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]])
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan;
initial_u0 = VectorOfArray([sol_ms(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]])
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
Expand All @@ -63,18 +68,19 @@ end
@info "BVPSOL with DiffEqArray"

ts = collect(tspan[1]:/ 20):tspan[2])
initial_u0 = DiffEqArray([sol_bvpm2(t) .+ rand() for t in ts], ts)
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan;
initial_u0 = DiffEqArray([sol_ms(t) .+ rand() for t in ts], ts)
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))

sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

@info "BVPSOL with initial guess function"

initial_u0 = (p, t) -> sol_bvpm2(t) .+ rand()
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)
initial_u0 = (p, t) -> sol_ms(t) .+ rand()
# FIXME: Upstream fix
# tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
# bcresid_prototype = (zeros(1), zeros(1)))
# sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)
end

#=
Expand Down

0 comments on commit 87dca5c

Please sign in to comment.