Can't differentiate an ODE solver due to lack of isnan and other type errors. #73

Closed
orebas opened this issue May 16, 2024 · 8 comments

@orebas

orebas commented May 16, 2024

TaylorDiff.jl throws an error when I try to differentiate through a fairly simple ODE solve. There is also an error on the MTK side, but even after applying the workaround for that (see https://discourse.julialang.org/t/error-trying-to-forwarddiff-through-an-ode-solver/114339/6) I can't get TaylorDiff.jl to work.

MWE:

using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using DifferentiationInterface, Enzyme, Zygote, ReverseDiff
using SciMLSensitivity
#import Base.isnan
#function isnan(x::TaylorScalar{Float64, 2})
#	return false
#end

function ADTest()
	@parameters a
	@variables t x1(t) 
	D = Differential(t)
	states = [x1]
	parameters = [a]

	@named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
	model = structural_simplify(pre_model)

	ic = Dict(x1 => 1.0)
	p_true = Dict(a => 2.0)

	problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
	soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
	display(soln(0.5, idxs = [x1]))

	function different_time(new_ic, new_params, new_t)
		#newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params)
		#newprob = remake(problem, u0=new_ic, tspan = [0.0, new_t], p = new_params)
		newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p=new_params)
		newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
		return (soln(new_t, idxs = [x1]))
	end

	function just_t(new_t)
		return different_time(ic, p_true, new_t)[1]
	end
	display(different_time(ic, p_true, 2e-5))
	display(just_t(0.5))

	
    #display(ForwardDiff.derivative(just_t,1.0))
	display(TaylorDiff.derivative(just_t,1.0,1))  #isnan error
    #display(value_and_gradient(just_t, AutoForwardDiff(), 1.0)) 
	#display(value_and_gradient(just_t, AutoReverseDiff(), 1.0)) 	
    #display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0)) 
	#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0)) 
    #display(value_and_gradient(just_t, AutoZygote(), 1.0)) 
	
end

ADTest()

@orebas
Author

orebas commented May 16, 2024

Running the above, the error is

ERROR: LoadError: MethodError: no method matching isnan(::TaylorScalar{Float64, 2})

Closest candidates are:
  isnan(::Missing)
   @ Base missing.jl:101
  isnan(::BigFloat)
   @ Base mpfr.jl:982
  isnan(::Complex)
   @ Base complex.jl:151
  ...

Stacktrace:
  [1] _any(f::typeof(isnan), itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}}, ::Colon)
    @ Base ./reduce.jl:1220
  [2] any(f::Function, itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}})
    @ Base ./reduce.jl:1235
  [3] get_concrete_tspan(prob::ODEProblem{…}, isadapt::Bool, kwargs::@Kwargs{…}, p::ModelingToolkit.MTKParameters{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1287
  [4] get_concrete_problem(prob::ODEProblem{…}, isadapt::Bool; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1169
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1074
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [7] (::var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#2"{var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
 [11] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [12] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [14] top-level scope
    @ REPL[1]:1

@orebas
Author

orebas commented May 16, 2024

If I define isnan myself (by uncommenting the four lines near the top of the MWE), the error becomes:

ERROR: LoadError: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl. For example:

```julia
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
```

Element type:
Any

Some of the types have been truncated in the stacktrace for improved reading. To emit complete information
in the stack trace, evaluate TruncatedStacktraces.VERBOSE[] = true and re-run the code.

Stacktrace:
  [1] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:592
  [2] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1080
  [3] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [4] (::var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [5] (::var"#just_t#4"{var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [6] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
  [7] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
  [8] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
  [9] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [11] top-level scope
    @ REPL[1]:1
in expression starting at /home/orebas/learning/ODETests/PLI/MWE3.jl:53
Some type information was truncated. Use show(err) to see complete types.
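
For readers unfamiliar with the "non-concrete element type" check, here is a minimal sketch in plain Julia (no SciML packages) of what the message is complaining about. The variable names `u0_mixed` and `u0_fixed` are illustrative only and do not come from the MWE; the elementwise conversion mirrors the `typeof(new_t).(newprob.u0)` line in the MWE, with `Float64` standing in for the Taylor number type.

```julia
# A vector whose element type is Any, similar in spirit to the u0 rejected above.
u0_mixed = Any[1.0, 2]

println(eltype(u0_mixed))                  # Any
println(isconcretetype(eltype(u0_mixed)))  # false

# Converting every element to one concrete type yields a concretely typed vector.
u0_fixed = Float64.(u0_mixed)

println(eltype(u0_fixed))                  # Float64
println(isconcretetype(eltype(u0_fixed)))  # true
```

This only illustrates the element-type condition itself; it does not explain why the array ended up as `Any` in the MWE.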

@tansongchen
Member

This was identified previously in #35 and is caused by inconsistencies with the type system. Unfortunately, I haven't figured out a good way to handle it yet.

@tansongchen
Member

OK, I now believe that TaylorScalar not being <: Real is a design error that needs to be fixed. I have started a fix at https://github.com/JuliaDiff/TaylorDiff.jl/tree/subtype-number ; once it is done, this application should work.
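
A minimal sketch of why the supertype matters, using two hypothetical wrapper types (`PlainJet` and `RealJet` are illustrative names, not TaylorDiff's actual definitions). As far as I know, Base Julia provides a generic `isnan` fallback for numeric types, so a type declared `<: Real` picks it up, while a type with no numeric supertype hits a `MethodError` like the one reported above.

```julia
# Hypothetical stand-ins for a Taylor-mode number; not TaylorDiff's actual definitions.
struct PlainJet            # no numeric supertype
    value::Float64
end

struct RealJet <: Real     # same data, but declared as a Real
    value::Float64
end

# Returns true if calling isnan on x raises a MethodError.
function isnan_is_missing(x)
    try
        isnan(x)
        return false
    catch err
        return err isa MethodError
    end
end

println(isnan_is_missing(PlainJet(1.0)))  # true: no generic method applies
println(isnan_is_missing(RealJet(1.0)))   # false: Base's numeric fallback applies
```

Note that the generic fallback does not inspect the stored value, so a real number type would probably still want its own `isnan` method that forwards to its coefficients.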

@tansongchen
Member

Fixed in the latest version, 0.2.2.

@orebas
Author

orebas commented May 22, 2024

I'm still getting this error with the above MWE:

ERROR: LoadError: MethodError: no method matching TaylorScalar{Float64, 2}(::Tuple{Float64, ChainRulesCore.ZeroTangent})

Closest candidates are:
  TaylorScalar{T, N}(::TaylorScalar{T, M}) where {T, N, M}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:65
  TaylorScalar{T, N}(::S, ::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:58
  TaylorScalar{T, N}(::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:46
  ...

Stacktrace:
  [1] sign(t::TaylorScalar{Float64, 2})
    @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/codegen.jl:20
  [2] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::TaylorScalar{…}, dtmin::TaylorScalar{…}, dtmax::TaylorScalar{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:120
  [3] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:6
  [4] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003
  [7] (::var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#6"{var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivatives
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:66 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:54 [inlined]
 [11] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:35 [inlined]
 [12] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:30 [inlined]
 [13] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [14] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [16] top-level scope
    @ REPL[5]:1

@tansongchen
Member

Oh, that's a problem with codegen. I will run your example and make it work tomorrow.

tansongchen reopened this May 23, 2024

@tansongchen
Member

OK, I fixed a minor problem with converting the special tangent types from ChainRules. Both calls should work now:

julia> ForwardDiff.derivative(just_t, 1.0)
14.778112197861631

julia> TaylorDiff.derivative(just_t, 1.0, 1)
14.77811219786163
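
As a sanity check on these numbers, the MWE's equation D(x1) ~ a * x1 has the closed-form solution x1(t) = x1(0) * exp(a * t), so the derivative with respect to t can be computed directly. The sketch below uses only Base Julia; the helper names `x1`, `dx1`, `exact`, and `reported` are introduced here for illustration and assume the values a = 2.0 and x1(0) = 1.0 from the MWE.

```julia
# Closed-form solution of the MWE's ODE: x1(t) = x0 * exp(a * t).
a = 2.0     # parameter value from p_true
x0 = 1.0    # initial condition from ic

x1(t) = x0 * exp(a * t)
dx1(t) = a * x0 * exp(a * t)   # derivative with respect to t

exact = dx1(1.0)               # 2 * exp(2)
reported = 14.778112197861631  # ForwardDiff value quoted above

println(isapprox(exact, reported; rtol = 1e-9))  # true
```

Both reported derivatives agree with 2 * exp(2) to well within the solver tolerances used in the MWE.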
