Skip to content

Commit

Permalink
Use custom vjp
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 26, 2024
1 parent 051c24c commit 3f9124e
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 16 deletions.
49 changes: 33 additions & 16 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,39 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..

uu = sol.u

if !SciMLBase.has_jac(prob.f)
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
_F = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
_F = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
Expand Down Expand Up @@ -103,21 +135,6 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
end
end
end
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
prob.jac(J, u, p)
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
_F = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.jac(u, p), size(u))
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
Expand Down
118 changes: 118 additions & 0 deletions test/core/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,121 @@ end
end
end
end

@testsetup module ForwardADNLLSTesting
using Reexport
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
Zygote

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

const θ_true = [1.0, 0.1, 2.0, 0.5]
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
const y_target = true_function(x, θ_true)

function loss_function(θ, p)
= true_function(p, θ)
return.- y_target
end

function loss_function_jac(θ, p)
return ForwardDiff.jacobian-> loss_function(θ, p), θ)
end

loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

function loss_function!(resid, θ, p)
= true_function(p, θ)
@. resid =- y_target
return
end

function loss_function_jac!(J, θ, p)
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
return
end

function loss_function_vjp!(vJ, v, θ, p)
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
return
end

θ_init = θ_true .+ 0.1

export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
end

@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleGaussNewton(),
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))
function obj_1(p)
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_2(p)
ff = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_3(p)
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_1, x)

fdiff1 = ForwardDiff.gradient(obj_1, x)
fdiff2 = ForwardDiff.gradient(obj_2, x)
fdiff3 = ForwardDiff.gradient(obj_3, x)

@test finitedifffdiff1 atol=1e-5
@test finitedifffdiff2 atol=1e-5
@test finitedifffdiff3 atol=1e-5
@test fdiff1 fdiff2 fdiff3

function obj_4(p)
prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_5(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_6(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_4, x)

fdiff4 = ForwardDiff.gradient(obj_4, x)
fdiff5 = ForwardDiff.gradient(obj_5, x)
fdiff6 = ForwardDiff.gradient(obj_6, x)

@test finitedifffdiff4 atol=1e-5
@test finitedifffdiff5 atol=1e-5
@test finitedifffdiff6 atol=1e-5
@test fdiff4 fdiff5 fdiff6
end
end

0 comments on commit 3f9124e

Please sign in to comment.