Skip to content

Commit

Permalink
Add JET tests for MIRK
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 22, 2024
1 parent 9192bd0 commit 0effc01
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 69 deletions.
14 changes: 7 additions & 7 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "ad2f92b76844532cf3e0a4abe6a85be7751b1309"
project_hash = "5a9701ab315469eb074470e987a001fa3e379b21"

[[deps.ADTypes]]
git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24"
Expand Down Expand Up @@ -57,9 +57,9 @@ version = "7.9.0"

[[deps.ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "2aeaeaff72cdedaa0b5f30dfb8c1f16aefdac65d"
git-tree-sha1 = "6404a564c24a994814106c374bec893195e19bac"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "1.7.0"
version = "1.8.0"
weakdeps = ["SparseArrays"]

[deps.ArrayLayouts.extensions]
Expand Down Expand Up @@ -265,8 +265,6 @@ version = "0.1.10"
[[deps.FastAlmostBandedMatrices]]
deps = ["ArrayInterface", "ArrayLayouts", "BandedMatrices", "ConcreteStructs", "LazyArrays", "LinearAlgebra", "MatrixFactorizations", "PrecompileTools", "Reexport"]
git-tree-sha1 = "9dc913faf8552fd09b92a0d7fcc25f1d5609d795"
repo-rev = "main"
repo-url = "https://github.com/SciML/FastAlmostBandedMatrices.jl"
uuid = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
version = "0.1.1"

Expand Down Expand Up @@ -609,9 +607,11 @@ version = "1.2.0"

[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"]
git-tree-sha1 = "d52bac2b94358b4b960cbfb896d5193d67f3ff09"
git-tree-sha1 = "0e464ca0e5d44a88c91f394c3f9a9448523e378b"
repo-rev = "ap/tstable_findmin"
repo-url = "https://github.com/SciML/NonlinearSolve.jl.git"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.8.0"
version = "3.8.2"

[deps.NonlinearSolve.extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
Expand Down
41 changes: 21 additions & 20 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,37 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "0.2"
ADTypes = "0.2.6"
Adapt = "4"
Aqua = "0.8"
ArrayInterface = "7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.145"
ArrayInterface = "7.7"
BandedMatrices = "1.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.146"
DiffEqDevTools = "2.44"
FastAlmostBandedMatrices = "0.1"
FastAlmostBandedMatrices = "0.1.1"
FastClosures = "0.3"
ForwardDiff = "0.10"
ForwardDiff = "0.10.36"
JET = "0.8"
LinearAlgebra = "1.9"
LinearSolve = "2.20"
NonlinearSolve = "3.5"
LinearAlgebra = "1.10"
LinearSolve = "2.21"
Logging = "1.10"
NonlinearSolve = "3.8.1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6"
OrdinaryDiffEq = "6.63"
PreallocationTools = "0.4"
PrecompileTools = "1"
Preferences = "1"
Random = "1"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
ReTestItems = "1.23.1"
RecursiveArrayTools = "3"
Reexport = "1.0"
SciMLBase = "2.12"
RecursiveArrayTools = "3.4"
Reexport = "1.2"
SciMLBase = "2.19"
Setfield = "1"
SparseArrays = "1.9"
SparseDiffTools = "2.9"
SparseArrays = "1.10"
SparseDiffTools = "2.14"
StaticArrays = "1.8.1"
Test = "1"
Test = "1.10"
julia = "1.10"

[extras]
Expand Down
9 changes: 5 additions & 4 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
@concrete struct MIRKInterpolation <: AbstractDiffEqInterpolation
t
u
cache
end

Expand All @@ -9,11 +9,12 @@ function DiffEqBase.interp_summary(interp::MIRKInterpolation)
end

function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation(tvals, id, idxs, deriv, p, continuity)
return interpolation(tvals, id, idxs, deriv, p, continuity)
end

function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
return

Check warning on line 17 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L17

Added line #L17 was not covered by tests
end

# FIXME: Fix the interpolation outside the tspan
Expand Down
36 changes: 0 additions & 36 deletions test/mirk/interpolation_tests.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using BoundaryValueDiffEq

for order in (2, 3, 4, 5, 6)
s = Symbol("MIRK$(order)")
@eval mirk_solver(::Val{$order}) = $(s)()
@eval mirk_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

# First order test
Expand Down Expand Up @@ -87,6 +87,22 @@ end
end
end

@testitem "JET: Runtime Dispatches" setup=[MIRKConvergenceTests] begin
using JET

@testset "Problem: $i" for i in 1:8
prob = probArr[i]
@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
solver = mirk_solver(Val(order);
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
@test_opt target_modules=(NonlinearSolve, BoundaryValueDiffEq) solve(
prob, solver; dt = 0.2)
@test_call target_modules=(NonlinearSolve, BoundaryValueDiffEq) solve(
prob, solver; dt = 0.2)
end
end
end

@testitem "Convergence on Linear" setup=[MIRKConvergenceTests] begin
using LinearAlgebra, DiffEqDevTools

Expand Down Expand Up @@ -129,3 +145,40 @@ end
@test_nowarn solve(bvp1, MIRK5(; jac_alg); dt = 0.05)
@test_nowarn solve(bvp1, MIRK6(; jac_alg); dt = 0.05)
end

@testitem "Interpolation" begin
using LinearAlgebra

λ = 1
function prob_bvp_linear_analytic(u, λ, t)
a = 1 / sqrt(λ)
return [(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a)),
(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))]
end

function prob_bvp_linear_f!(du, u, p, t)
du[1] = u[2]
du[2] = 1 / p * u[1]
end
function prob_bvp_linear_bc!(res, u, p, t)
res[1] = u[1][1] - 1
res[2] = u[end][1]
end

prob_bvp_linear_function = ODEFunction(
prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic)
prob_bvp_linear_tspan = (0.0, 1.0)
prob_bvp_linear = BVProblem(
prob_bvp_linear_function, prob_bvp_linear_bc!, [1.0, 0.0], prob_bvp_linear_tspan, λ)
testTol = 1e-6

for order in (2, 3, 4, 5, 6)
s = Symbol("MIRK$(order)")
@eval mirk_solver(::Val{$order}) = $(s)()
end

@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001)
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
end
end
3 changes: 2 additions & 1 deletion test/misc/affine_geodesic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ tspan = (0.0, 1.0)
function bc1!(residual, u, p, t)
mid = div(length(u[1]), 2)
residual[1:mid] = u[1][1:mid] - a1
return residual[(mid + 1):end] = u[end][1:mid] - a2
residual[(mid + 1):end] = u[end][1:mid] - a2
return
end

function chart_log_problem!(du, u, params, t)
Expand Down

0 comments on commit 0effc01

Please sign in to comment.