diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 0d1d0825..27208c3a 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -385,15 +385,14 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector # check whether cotangents arise from gauge-invariance objective function mask = abs.(Sp' .- Sp) .< tol - gaugepart = view(aUΔU, mask) + view(aVΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) if p > r rprange = (r + 1):p - norm(view(aUΔU, rprange, rprange), Inf) < tol || - @warn "cotangents sensitive to gauge choice" - norm(view(aVΔV, rprange, rprange), Inf) < tol || - @warn "cotangents sensitive to gauge choice" + Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf)) + Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf)) end + Δgauge < tol || + @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+ (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol) @@ -461,8 +460,9 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix VdΔV = V' * ΔV mask = abs.(transpose(D) .- D) .< tol - gaugepart = view(VdΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + Δgauge = norm(view(VdΔV, mask), Inf) + Δgauge < tol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol)) @@ -504,8 +504,9 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2) mask = abs.(D' .- D) .< tol - gaugepart = view(aVdΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + Δgauge = norm(view(aVdΔV, mask)) + Δgauge < tol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" aVdΔV .*= safe_inv.(D' .- D, tol) @@ -567,8 +568,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, Q2 = view(Q, :, (p + 1):m) ΔQ2 = view(ΔQ, :, (p + 1):m) Q1dΔQ2 = Q1' * ΔQ2 - gaugepart = mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + Δgauge < tol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" mul!(ΔA1, Q2, Q1dΔQ2', -1, 1) end rdiv!(ΔA1, UpperTriangular(R11)') @@ -621,8 +623,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, Q2 = view(Q, (p + 1):n, :) ΔQ2 = view(ΔQ, (p + 1):n, :) ΔQ2Q1d = ΔQ2 * Q1' - gaugepart = mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)) + Δgauge < tol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1) end ldiv!(LowerTriangular(L11)', ΔA1)