Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ImmutableNonlinearProblem #153

Merged
merged 5 commits into from
Jul 22, 2024
Merged

Conversation

m-bossart
Copy link
Contributor

Note this is replacing #151.

I'm having trouble figuring out all the changes that need to take place to make all the AD dispatches work correctly. So far I have changed the dispatches in this package and also in SciMLSensitivity (SciML/SciMLSensitivity.jl#1066) -- essentially replacing every dispatch for NonlinearProblem with ImmutableNonlinearProblem.

The sensitivity test using ForwardDiff does not pass, there is a method ambiguity error. Do I need to go change the dispatches here as well: https://github.com/SciML/NonlinearSolve.jl/blob/master/src/internal/forward_diff.jl

When I try to test the adjoint methods in SciMLSensitivity.jl I don't hit the rule here because it uses the method from SimpleNonlinearSolve: https://github.com/SciML/SimpleNonlinearSolve.jl/blob/main/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

My overarching question is in what packages should this require changes?

@m-bossart
Copy link
Contributor Author

@ChrisRackauckas Can you give me a bit more direction about the strategy to make all the AD stuff work?

@m-bossart m-bossart marked this pull request as draft July 16, 2024 21:28
@ChrisRackauckas
Copy link
Member

I think see how this would be ambiguous with the definitions that require NonlinearProblem. What is the method ambiguity? Show the error message.

When I try to test the adjoint methods in SciMLSensitivity.jl I don't hit the rule here because it uses the method from SimpleNonlinearSolve: https://github.com/SciML/SimpleNonlinearSolve.jl/blob/main/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

That should get the immutable problem dispatch too.

I think that's all that's needed to complete this.

@m-bossart
Copy link
Contributor Author

Here is the ambiguous method error message from this test:

∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)

julia> show(err)
1-element ExceptionStack:
LoadError: MethodError: solve(::NonlinearProblem{Vector{Float64}, false, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}, NonlinearFunction{false, SciMLBase.FullSpecialize, typeof(ff), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardNonlinearProblem}, ::SimpleNewtonRaphson{Nothing}) is ambiguous.

Candidates:
  solve(prob::NonlinearProblem{<:Union{Number, var"#s145"} where var"#s145"<:AbstractArray, iip, <:Union{var"#s144", var"#s143"} where {var"#s144"<:ForwardDiff.Dual{T, V, P}, var"#s143"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}}, alg::Union{Nothing, SciMLBase.AbstractNonlinearAlgorithm}, args...; kwargs...) where {T, V, P, iip}
    @ NonlinearSolve C:\Users\Matt Bossart\.julia\packages\NonlinearSolve\82B4C\src\internal\forward_diff.jl:5
  solve(prob::NonlinearProblem, alg::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg, u0, p, kwargs...)
    @ SimpleNonlinearSolve C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\src\SimpleNonlinearSolve.jl:72

Possible fix, define
  solve(::NonlinearProblem{uType, iip, P1} where {uType<:(Union{Number, var"#s145"} where var"#s145"<:AbstractArray), P1<:(Union{var"#s144", var"#s143"} where {var"#s144"<:ForwardDiff.Dual{T, V, P}, var"#s143"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})})}, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, ::Vararg{Any}) where {T, V, P, iip}

Stacktrace:
  [1] solve_nlprob(p::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}})
    @ Main c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:9
  [2] vector_mode_dual_eval!(f::typeof(solve_nlprob), cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}}, x::Vector{Float64})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\apiutils.jl:24
  [3] vector_mode_gradient(f::typeof(solve_nlprob), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:89
  [4] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}}, ::Val{true})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:19
  [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:17
  [6] top-level scope
    @ c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:17
  [7] eval
    @ .\boot.jl:385 [inlined]
  [8] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base .\loading.jl:2076
  [9] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Base .\essentials.jl:892
 [10] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base .\essentials.jl:889
 [11] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:271
 [12] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:181
 [13] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\repl.jl:276
 [14] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:179
 [15] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})        
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\repl.jl:38
 [16] (::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:150
 [17] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging .\logging.jl:515
 [18] with_logger
    @ .\logging.jl:627 [inlined]
 [19] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:263
 [20] #invokelatest#2
    @ .\essentials.jl:892 [inlined]
 [21] invokelatest(::Any)
    @ Base .\essentials.jl:889
 [22] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:34
in expression starting at c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:17
julia> 

@ChrisRackauckas
Copy link
Member

solve(prob::NonlinearProblem, alg::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg, u0, p, kwargs...)
    @ SimpleNonlinearSolve C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\src\SimpleNonlinearSolve.jl:72

fix this one so the alg dispatch is Union{Nothing,SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm}

@ChrisRackauckas
Copy link
Member

nevermind.

@ChrisRackauckas
Copy link
Member

Add SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm to the union of the first one.

@m-bossart
Copy link
Contributor Author

I still get the method ambiguity error:

LoadError: MethodError: solve(::NonlinearProblem{Vector{Float64}, false, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}, NonlinearFunction{false, SciMLBase.FullSpecialize, typeof(ff), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardNonlinearProblem}, ::SimpleNewtonRaphson{Nothing}) is ambiguous.

Candidates:
  solve(prob::NonlinearProblem{<:Union{Number, var"#s145"} where var"#s145"<:AbstractArray, iip, <:Union{var"#s144", var"#s143"} where {var"#s144"<:ForwardDiff.Dual{T, V, P}, var"#s143"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}}, alg::Union{Nothing, SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm}, args...; kwargs...) where {T, V, P, iip}
    @ NonlinearSolve C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\NonlinearSolve.jl\src\internal\forward_diff.jl:5
  solve(prob::NonlinearProblem, alg::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg, u0, p, kwargs...)
    @ SimpleNonlinearSolve C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\src\SimpleNonlinearSolve.jl:72

Possible fix, define
  solve(::NonlinearProblem{uType, iip, P1} where {uType<:(Union{Number, var"#s145"} where var"#s145"<:AbstractArray), P1<:(Union{var"#s144", var"#s143"} where {var"#s144"<:ForwardDiff.Dual{T, V, P}, var"#s143"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})})}, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, ::Vararg{Any}) where {T, V, P, iip}

Stacktrace:
  [1] solve_nlprob(p::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}})
    @ Main c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:10
  [2] vector_mode_dual_eval!(f::typeof(solve_nlprob), cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}}, x::Vector{Float64})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\apiutils.jl:24
  [3] vector_mode_gradient(f::typeof(solve_nlprob), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:89
  [4] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}}, ::Val{true})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:19
  [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(solve_nlprob), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\Matt Bossart\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:17
  [6] top-level scope
    @ c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:18
  [7] eval
    @ .\boot.jl:385 [inlined]
  [8] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base .\loading.jl:2076
  [9] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Base .\essentials.jl:892
 [10] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base .\essentials.jl:889
 [11] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:271
 [12] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:181
 [13] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\repl.jl:276
 [14] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:179
 [15] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\repl.jl:38
 [16] (::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:150
 [17] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging .\logging.jl:515
 [18] with_logger
    @ .\logging.jl:627 [inlined]
 [19] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:263
 [20] #invokelatest#2
    @ .\essentials.jl:892 [inlined]
 [21] invokelatest(::Any)
    @ Base .\essentials.jl:889
 [22] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer c:\Users\Matt Bossart\.vscode\extensions\julialang.language-julia-1.83.2\scripts\packages\VSCodeServer\src\eval.jl:34
in expression starting at c:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\Forked Repositories\SimpleNonlinearSolve.jl\test\core\adjoint_tests.jl:18

@m-bossart
Copy link
Contributor Author

In addition to the method ambiguity with ForwardDiff...
At this point Zygote.gradient works on a problem solved with SimpleNewtonRaphson() but not NewtonRaphson(). I've updated the chain rules rule here to dispatch on ImmutableNonlinearProblem but my impression is that the outstanding issue is the rule is defined for ::typeof(SimpleNonlinearSolve.__internal_solve_up)? Should this be changed to work for both SimpleNonlinearSolve and NonlinearSolve algorithms? If so to what?

@ChrisRackauckas
Copy link
Member

Split the nothing dispatch to a separate one?

@m-bossart
Copy link
Contributor Author

I think this is ready and the adjoint tests in SimpleNonlinearSolve.jl pass with the changes to the dispatch included in SciML/SciMLSensitivity.jl#1066

@m-bossart m-bossart marked this pull request as ready for review July 17, 2024 23:25
@m-bossart
Copy link
Contributor Author

The adjoint tests i n SciMLSensitivity that use the regular Nonlinear solvers (not the simple ones) still aren't working properly. There needs to be an additional conversion in those dispatches... where is the equivalent of

function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
but for AbstractNonlinearSolveAlgorithm ?

@ChrisRackauckas
Copy link
Member

That one just hits the highest level https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L993. Only the simple methods skip it.

src/ad.jl Outdated Show resolved Hide resolved
m-bossart and others added 3 commits July 22, 2024 08:50
Co-authored-by: Christopher Rackauckas <[email protected]>
Co-authored-by: Christopher Rackauckas <[email protected]>
Co-authored-by: Christopher Rackauckas <[email protected]>
@ChrisRackauckas ChrisRackauckas merged commit e63b1a8 into SciML:main Jul 22, 2024
15 of 17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants