Skip to content

Commit

Permalink
fermionic ad stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
qmortier committed May 28, 2024
1 parent e7edab8 commit a2d891d
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 33 deletions.
248 changes: 216 additions & 32 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using TensorOperations: promote_contract, Backend
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
Expand All @@ -11,6 +13,14 @@ using TupleTools

# Flip a conjugation flag: `:C` maps to `:N`, and every other symbol maps to `:C`.
function _conj(conjA::Symbol)
    conjA == :C && return :N
    return :C
end
# Return the tuple `(1, 2, …, N)`; yields the empty tuple for `N == 0`.
function trivtuple(N)
    return ntuple(i -> i, N)
end
# Combine a tuple `Es` of two-index tensors into a single tensor via repeated
# `tensorproduct` calls. Any `backend` arguments are forwarded unchanged.
#
# Base case: a single factor is returned as is.
_kron(Es::NTuple{1}, backend::Backend...) = Es[1]

# Recursive case: the tail `Es[2:end]` is combined first, then multiplied with `Es[1]`.
#
# NOTE(review): an empty tuple (`N == 0`) is not handled; `Es[1]` would throw.
# NOTE: the Codecov annotation on this commit reported the added lines of this method
# as not covered by tests.
function _kron(Es::NTuple{N,Any}, backend::Backend...) where {N}
    E1 = Es[1]
    E2 = _kron(Base.tail(Es), backend...)
    # `E2` carries 2N - 2 indices, all placed in the second index group.
    p2 = ((), trivtuple(2 * N - 2))
    # Output ordering: index 1 of `E1` followed by product indices 3:(N + 1) in the
    # first group, index 2 of `E1` followed by product indices (N + 2):2N in the second.
    p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...))
    return tensorproduct(p, E1, ((1, 2), ()), :N, E2, p2, :N, One(), backend...)
end

function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
Expand Down Expand Up @@ -114,7 +124,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
dA = zerovector(A,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
tB::typeof(B) = copy(B)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
for i in allind(B)
if isdual(space(B, i))
twist!(tB, i)
end
end
end
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C)
return projectA(dA)
end
dB_ = @thunk begin
Expand All @@ -123,7 +141,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
dB = zerovector(B,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(A)))
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
tA::typeof(A) = copy(A)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
for i in allind(A)
if isdual(space(A, i))
twist!(tA, i)
end
end
end
dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N)
return projectB(dB)
end
return NoTangent(), dA_, dB_
Expand Down Expand Up @@ -650,11 +676,12 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))
end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number, backend::Backend...)
function ChainRulesCore.rrule( ::typeof(TensorOperations.tensorcontract!),

Check warning on line 679 in ext/TensorKitChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt.jl#L679

Added line #L679 was not covered by tests
C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number, backend::TensorOperations.Backend...) where {S}

C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...)

projectA = ProjectTo(A)
Expand All @@ -666,52 +693,209 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipC = invperm(linearize(pC))
pΔC = (TupleTools.getindices(ipC, trivtuple(numout(pA))),
TupleTools.getindices(ipC, numout(pA) .+ trivtuple(numin(pB))))
pΔC = (TupleTools.getindices(ipC, trivtuple(TensorOperations.numout(pA))),
TupleTools.getindices(ipC, TensorOperations.numout(pA) .+ trivtuple(TensorOperations.numin(pB))))
tΔC::typeof(ΔC) = copy(ΔC)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
for i in allind(ΔC)
if isdual(space(ΔC, i))
twist!(tΔC, i)
end
end
end
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
_dA = tensorcontract!(_dA, ipA,
ΔC, pΔC, conjΔC,
B, reverse(pB), conjB′,
conjA == :C ? α : conj(α), Zero(), backend...)
tB::typeof(B) = copy(B)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
pcodom = space.(Ref(B),pB[1])
for i in 1:length(pB[1])
if !isdual(pcodom[i])
twist!(tB, pB[1][i])
end
end
pdom = space.(Ref(B),pB[2])
for i in 1:length(pB[2])
if isdual(pdom[i])
twist!(tB, pB[2][i])
end
end
end
_dA = tensorcontract!( _dA, ipA,
ΔC, pΔC, conjΔC,
tB, reverse(pB), conjB′,
conjA == :C ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dB = @thunk begin
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
_dB = tensorcontract!(_dB, ipB,
A, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
conjB == :C ? α : conj(α), Zero(), backend...)
tA::typeof(A) = copy(A)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
pcodom = space.(Ref(A),pA[1])
for i in 1:length(pA[1])
if isdual(pcodom[i])
twist!(tA, pA[1][i])
end
end
pdom = space.(Ref(A),pA[2])
for i in 1:length(pA[2])
if !isdual(pdom[i])
twist!(tA, pA[2][i])
end
end
end
_dB = tensorcontract!( _dB, ipB,
tA, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
conjB == :C ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
= let tΔC = tΔC
@thunk begin
_dα = tensorscalar( tensorcontract(((), ()),
tensorcontract(pC, A, pA, conjA, B, pB, conjB),
((), trivtuple(TensorOperations.numind(pC))),
:C, tΔC,
(trivtuple(TensorOperations.numind(pC)), ()), :N,
backend...))
return projectα(_dα)
end
end
= @thunk begin
_dα = tensorscalar(tensorcontract(((), ()),
tensorcontract(pC, A, pA, conjA, B, pB,
conjB),
((), trivtuple(numind(pC))),
:C, ΔC,
(trivtuple(numind(pC)), ()), :N,
backend...))
return projectα(_dα)
= @thunk begin
_dβ = tensorscalar( tensorcontract(((), ()), C,
((), trivtuple(TensorOperations.numind(pC))), :C, tΔC,
(trivtuple(TensorOperations.numind(pC)), ()), :N,
backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end
return C′, pullback
end

# Reverse-mode rule for `TensorOperations.tensoradd!` acting on tensor maps.
#
# The primal is computed on a copy of `C`, so the caller's `C` is not modified here.
# The pullback returns tangents in argument order:
#   (function, C, pC, A, conjA, α, β, backend...)
# where `pC`, `conjA` and every backend argument receive `NoTangent()`, and the
# remaining tangents are lazy thunks projected onto the corresponding primal.
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
                              C::AbstractTensorMap{S}, pC::Index2Tuple,
                              A::AbstractTensorMap{S}, conjA::Symbol,
                              α::Number, β::Number, backend::Backend...) where {S}
    C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...)

    projectA = ProjectTo(A)
    projectC = ProjectTo(C)
    projectα = ProjectTo(α)
    projectβ = ProjectTo(β)

    function pullback(ΔC′)
        ΔC = unthunk(ΔC′)
        # Copy of the cotangent in which, for fermionic braiding, every dual index is
        # twisted. It is only used in the scalar contractions for `dα` and `dβ`;
        # `dC` and `dA` are built from the untwisted `ΔC`.
        tΔC::typeof(ΔC) = copy(ΔC)
        if BraidingStyle(sectortype(ΔC)) isa Fermionic
            for i in allind(ΔC)
                if isdual(space(ΔC, i))
                    twist!(tΔC, i)
                end
            end
        end
        dC = @thunk projectC(scale(ΔC, conj(β)))
        dA = @thunk begin
            # Undo the output permutation `pC` when mapping the cotangent back onto `A`.
            ipC = invperm(linearize(pC))
            _dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
            _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α,
                             Zero(), backend...)
            return projectA(_dA)
        end
        # `let` rebinds `tΔC` so that each thunk captures its own local binding.
        dα = let tΔC = tΔC
            @thunk begin
                _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
                                                  _conj(conjA), tΔC,
                                                  (trivtuple(TensorOperations.numind(pC)),
                                                   ()), :N, One(), backend...))
                return projectα(_dα)
            end
        end
        dβ = let tΔC = tΔC
            @thunk begin
                _dβ = tensorscalar(tensorcontract(((), ()), C,
                                                  ((),
                                                   trivtuple(TensorOperations.numind(pC))),
                                                  :C, tΔC,
                                                  (trivtuple(TensorOperations.numind(pC)),
                                                   ()), :N, One(),
                                                  backend...))
                return projectβ(_dβ)
            end
        end
        dbackend = map(x -> NoTangent(), backend)
        return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend...
    end

    return C′, pullback
end

function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, pC::Index2Tuple, A::AbstractTensorMap{S},
pA::Index2Tuple, conjA::Symbol, α::Number, β::Number,
backend::Backend...) where {S}
C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
tΔC::typeof(ΔC) = copy(ΔC)
if BraidingStyle(sectortype(ΔC)) isa Fermionic
for i in allind(ΔC)
if isdual(space(ΔC, i))
twist!(tΔC, i)
end
end
end
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...))
Es = map(pA[1], pA[2]) do i1, i2
E = one(TensorOperations.tensoralloc_add(scalartype(A), ((i1,), (i2,)),
A, conjA))
if BraidingStyle(sectortype(ΔC)) isa Fermionic
conjA == :N ? twist!(E, 1) : E
end
return E
end
E = _kron(Es, backend...)
_dA = zerovector(A, VectorInterface.promote_scale(ΔC, α))
_dA = tensorproduct!(_dA, (ipC, ()), ΔC, (trivtuple(TensorOperations.numind(pC)), ()), conjA, E,
((), trivtuple(TensorOperations.numind(pA))), conjA,
conjA == :N ? conj(α) : α, Zero(), backend...)
return projectA(_dA)
end
= let tΔC = tΔC
@thunk begin
_dα = tensorscalar(tensorcontract(((), ()),
tensortrace(pC, A, pA),
((), trivtuple(TensorOperations.numind(pC))),
_conj(conjA), tΔC,
(trivtuple(TensorOperations.numind(pC)), ()), :N, One(),
backend...))
return projectα(_dα)
end
end
= @thunk begin
_dβ = tensorscalar(tensorcontract(((), ()), C,
((), trivtuple(numind(pC))), :C, ΔC,
(trivtuple(numind(pC)), ()), :N,
backend...))
((), trivtuple(TensorOperations.numind(pC))), :C, tΔC,
(trivtuple(TensorOperations.numind(pC)), ()), :N, One(),
backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, pullback
Expand Down
2 changes: 1 addition & 1 deletion test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
ℂ[Z2Irrep](0 => 3, 1 => 2)',
ℂ[Z2Irrep](0 => 2, 1 => 3),
ℂ[Z2Irrep](0 => 2, 1 => 2)),
(ℂ[FermionParity](0 => 1, 1 => 1),
(ℂ[FermionParity](0 => 1, 1 => 1),
ℂ[FermionParity](0 => 1, 1 => 2)',
ℂ[FermionParity](0 => 3, 1 => 2)',
ℂ[FermionParity](0 => 2, 1 => 3),
Expand Down

0 comments on commit a2d891d

Please sign in to comment.