Skip to content

Commit

Permalink
Improve warnings in AD rules (#122)
Browse files Browse the repository at this point in the history
* Improve warnings in AD rules

* Refactor variable names and reduce number of warnings
  • Loading branch information
lkdvos committed May 12, 2024
1 parent 7ee775b commit 4607691
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,15 +385,14 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(Sp' .- Sp) .< tol
gaugepart = view(aUΔU, mask) + view(aVΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
if p > r
rprange = (r + 1):p
norm(view(aUΔU, rprange, rprange), Inf) < tol ||
@warn "cotangents sensitive to gauge choice"
norm(view(aVΔV, rprange, rprange), Inf) < tol ||
@warn "cotangents sensitive to gauge choice"
Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
end
Δgauge < tol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
(aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
Expand Down Expand Up @@ -461,8 +460,9 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix
VdΔV = V' * ΔV

mask = abs.(transpose(D) .- D) .< tol
gaugepart = view(VdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
Δgauge = norm(view(VdΔV, mask), Inf)
Δgauge < tol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol))

Expand Down Expand Up @@ -504,8 +504,9 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

mask = abs.(D' .- D) .< tol
gaugepart = view(aVdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
Δgauge = norm(view(aVdΔV, mask))
Δgauge < tol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

aVdΔV .*= safe_inv.(D' .- D, tol)

Expand Down Expand Up @@ -567,8 +568,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
Q2 = view(Q, :, (p + 1):m)
ΔQ2 = view(ΔQ, :, (p + 1):m)
Q1dΔQ2 = Q1' * ΔQ2
gaugepart = mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge < tol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
mul!(ΔA1, Q2, Q1dΔQ2', -1, 1)
end
rdiv!(ΔA1, UpperTriangular(R11)')
Expand Down Expand Up @@ -621,8 +623,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
Q2 = view(Q, (p + 1):n, :)
ΔQ2 = view(ΔQ, (p + 1):n, :)
ΔQ2Q1d = ΔQ2 * Q1'
gaugepart = mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1))
Δgauge < tol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1)
end
ldiv!(LowerTriangular(L11)', ΔA1)
Expand Down

0 comments on commit 4607691

Please sign in to comment.