From 0effc0139fc902669ec70192d8eb57387af3e6f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Mar 2024 16:25:56 -0400 Subject: [PATCH] Add JET tests for MIRK --- Manifest.toml | 14 ++--- Project.toml | 41 +++++++------- src/interpolation.jl | 9 +-- test/mirk/interpolation_tests.jl | 36 ------------ ...nvergence_tests.jl => mirk_basic_tests.jl} | 55 ++++++++++++++++++- test/misc/affine_geodesic.jl | 3 +- 6 files changed, 89 insertions(+), 69 deletions(-) delete mode 100644 test/mirk/interpolation_tests.jl rename test/mirk/{mirk_convergence_tests.jl => mirk_basic_tests.jl} (68%) diff --git a/Manifest.toml b/Manifest.toml index a6574ebf..ab514172 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "ad2f92b76844532cf3e0a4abe6a85be7751b1309" +project_hash = "5a9701ab315469eb074470e987a001fa3e379b21" [[deps.ADTypes]] git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24" @@ -57,9 +57,9 @@ version = "7.9.0" [[deps.ArrayLayouts]] deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "2aeaeaff72cdedaa0b5f30dfb8c1f16aefdac65d" +git-tree-sha1 = "6404a564c24a994814106c374bec893195e19bac" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "1.7.0" +version = "1.8.0" weakdeps = ["SparseArrays"] [deps.ArrayLayouts.extensions] @@ -265,8 +265,6 @@ version = "0.1.10" [[deps.FastAlmostBandedMatrices]] deps = ["ArrayInterface", "ArrayLayouts", "BandedMatrices", "ConcreteStructs", "LazyArrays", "LinearAlgebra", "MatrixFactorizations", "PrecompileTools", "Reexport"] git-tree-sha1 = "9dc913faf8552fd09b92a0d7fcc25f1d5609d795" -repo-rev = "main" -repo-url = "https://github.com/SciML/FastAlmostBandedMatrices.jl" uuid = "9d29842c-ecb8-4973-b1e9-a27b1157504e" version = "0.1.1" @@ -609,9 +607,11 @@ version = "1.2.0" [[deps.NonlinearSolve]] deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"] -git-tree-sha1 = "d52bac2b94358b4b960cbfb896d5193d67f3ff09" +git-tree-sha1 = "0e464ca0e5d44a88c91f394c3f9a9448523e378b" +repo-rev = "ap/tstable_findmin" +repo-url = "https://github.com/SciML/NonlinearSolve.jl.git" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -version = "3.8.0" +version = "3.8.2" [deps.NonlinearSolve.extensions] NonlinearSolveBandedMatricesExt = "BandedMatrices" diff --git a/Project.toml b/Project.toml index 51e98aaa..e9c3e455 100644 --- a/Project.toml +++ b/Project.toml @@ -34,36 +34,37 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c" BoundaryValueDiffEqODEInterfaceExt = "ODEInterface" [compat] -ADTypes = "0.2" +ADTypes = "0.2.6" Adapt = "4" Aqua = "0.8" -ArrayInterface = "7" -BandedMatrices = "1" -ConcreteStructs = "0.2" -DiffEqBase = "6.145" +ArrayInterface = "7.7" +BandedMatrices = "1.4" +ConcreteStructs = "0.2.3" +DiffEqBase = "6.146" DiffEqDevTools = "2.44" -FastAlmostBandedMatrices = "0.1" +FastAlmostBandedMatrices = "0.1.1" FastClosures = "0.3" -ForwardDiff = "0.10" +ForwardDiff = "0.10.36" JET = "0.8" -LinearAlgebra = "1.9" -LinearSolve = "2.20" -NonlinearSolve = "3.5" +LinearAlgebra = "1.10" +LinearSolve = "2.21" +Logging = "1.10" +NonlinearSolve = "3.8.1" ODEInterface = "0.5" -OrdinaryDiffEq = "6" +OrdinaryDiffEq = "6.63" PreallocationTools = "0.4" -PrecompileTools = "1" -Preferences = "1" -Random = "1" +PrecompileTools = "1.2" +Preferences = "1.4" +Random = "1.10" ReTestItems = "1.23.1" -RecursiveArrayTools = "3" -Reexport = "1.0" -SciMLBase = "2.12" +RecursiveArrayTools = "3.4" +Reexport = "1.2" +SciMLBase = "2.19" Setfield = "1" -SparseArrays = "1.9" -SparseDiffTools = "2.9" +SparseArrays = "1.10" +SparseDiffTools = "2.14" StaticArrays = "1.8.1" -Test = "1" +Test = "1.10" julia = "1.10" [extras] diff --git a/src/interpolation.jl b/src/interpolation.jl index 8ca5abf9..d94b13f5 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -1,6 +1,6 @@ -struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation - t::T1 - u::T2 +@concrete struct MIRKInterpolation <: AbstractDiffEqInterpolation + t + u cache end @@ -9,11 +9,12 @@ function DiffEqBase.interp_summary(interp::MIRKInterpolation) end function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) - interpolation(tvals, id, idxs, deriv, p, continuity) + return interpolation(tvals, id, idxs, deriv, p, continuity) end function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left) interpolation!(val, tvals, id, idxs, deriv, p, continuity) + return end # FIXME: Fix the interpolation outside the tspan diff --git a/test/mirk/interpolation_tests.jl b/test/mirk/interpolation_tests.jl deleted file mode 100644 index a5a8efd6..00000000 --- a/test/mirk/interpolation_tests.jl +++ /dev/null @@ -1,36 +0,0 @@ -@testitem "MIRK Interpolation" begin - using LinearAlgebra - - λ = 1 - function prob_bvp_linear_analytic(u, λ, t) - a = 1 / sqrt(λ) - return [(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a)), - (-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))] - end - - function prob_bvp_linear_f!(du, u, p, t) - du[1] = u[2] - du[2] = 1 / p * u[1] - end - function prob_bvp_linear_bc!(res, u, p, t) - res[1] = u[1][1] - 1 - res[2] = u[end][1] - end - - prob_bvp_linear_function = ODEFunction( - prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic) - prob_bvp_linear_tspan = (0.0, 1.0) - prob_bvp_linear = BVProblem( - prob_bvp_linear_function, prob_bvp_linear_bc!, [1.0, 0.0], prob_bvp_linear_tspan, λ) - testTol = 1e-6 - - for order in (2, 3, 4, 5, 6) - s = Symbol("MIRK$(order)") - @eval mirk_solver(::Val{$order}) = $(s)() - end - - @testset "MIRK$order" for order in (2, 3, 4, 5, 6) - sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001) - @test sol(0.001)≈[0.998687464, -1.312035941] atol=testTol - end -end diff --git a/test/mirk/mirk_convergence_tests.jl b/test/mirk/mirk_basic_tests.jl similarity index 68% rename from test/mirk/mirk_convergence_tests.jl rename to test/mirk/mirk_basic_tests.jl index 1f21ad37..6e153990 100644 --- a/test/mirk/mirk_convergence_tests.jl +++ b/test/mirk/mirk_basic_tests.jl @@ -4,7 +4,7 @@ using BoundaryValueDiffEq for order in (2, 3, 4, 5, 6) s = Symbol("MIRK$(order)") - @eval mirk_solver(::Val{$order}) = $(s)() + @eval mirk_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...) end # First order test @@ -87,6 +87,22 @@ end end end +@testitem "JET: Runtime Dispatches" setup=[MIRKConvergenceTests] begin + using JET + + @testset "Problem: $i" for i in 1:8 + prob = probArr[i] + @testset "MIRK$order" for order in (2, 3, 4, 5, 6) + solver = mirk_solver(Val(order); + jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))) + @test_opt target_modules=(NonlinearSolve, BoundaryValueDiffEq) solve( + prob, solver; dt = 0.2) + @test_call target_modules=(NonlinearSolve, BoundaryValueDiffEq) solve( + prob, solver; dt = 0.2) + end + end +end + @testitem "Convergence on Linear" setup=[MIRKConvergenceTests] begin using LinearAlgebra, DiffEqDevTools @@ -129,3 +145,40 @@ end @test_nowarn solve(bvp1, MIRK5(; jac_alg); dt = 0.05) @test_nowarn solve(bvp1, MIRK6(; jac_alg); dt = 0.05) end + +@testitem "Interpolation" begin + using LinearAlgebra + + λ = 1 + function prob_bvp_linear_analytic(u, λ, t) + a = 1 / sqrt(λ) + return [(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a)), + (-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))] + end + + function prob_bvp_linear_f!(du, u, p, t) + du[1] = u[2] + du[2] = 1 / p * u[1] + end + function prob_bvp_linear_bc!(res, u, p, t) + res[1] = u[1][1] - 1 + res[2] = u[end][1] + end + + prob_bvp_linear_function = ODEFunction( + prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic) + prob_bvp_linear_tspan = (0.0, 1.0) + prob_bvp_linear = BVProblem( + prob_bvp_linear_function, prob_bvp_linear_bc!, [1.0, 0.0], prob_bvp_linear_tspan, λ) + testTol = 1e-6 + + for order in (2, 3, 4, 5, 6) + s = Symbol("MIRK$(order)") + @eval mirk_solver(::Val{$order}) = $(s)() + end + + @testset "MIRK$order" for order in (2, 3, 4, 5, 6) + sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001) + @test sol(0.001)≈[0.998687464, -1.312035941] atol=testTol + end +end diff --git a/test/misc/affine_geodesic.jl b/test/misc/affine_geodesic.jl index 69fcc938..2543e891 100644 --- a/test/misc/affine_geodesic.jl +++ b/test/misc/affine_geodesic.jl @@ -27,7 +27,8 @@ tspan = (0.0, 1.0) function bc1!(residual, u, p, t) mid = div(length(u[1]), 2) residual[1:mid] = u[1][1:mid] - a1 - return residual[(mid + 1):end] = u[end][1:mid] - a2 + residual[(mid + 1):end] = u[end][1:mid] - a2 + return end function chart_log_problem!(du, u, params, t)