Skip to content

Commit

Permalink
Refactor ChainRulesCoreExt (#133)
Browse files Browse the repository at this point in the history
* Refactor ChainRulesCoreExt into separate files

* Add rrule test `dot`

* Add tests `ProjectTo`

* Add tests rrules converters

* Restrict convert rrule to trivialtensormap

* Add kwargs to rrule

* Add unthunk

* Refactor array -> tensormap conversion

* Re-enable constructor AD tests

* Refactor _interleave

* add converter for fusiontreepair to array

* Refactor `project_symmetric!` into separate function

* Add rule for (not) generating tangents for vector spaces
  • Loading branch information
lkdvos committed Jul 2, 2024
1 parent 3275ffb commit 1331067
Show file tree
Hide file tree
Showing 11 changed files with 431 additions and 397 deletions.
28 changes: 28 additions & 0 deletions ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Package extension providing ChainRulesCore reverse-mode AD rules (rrules)
# for TensorKit tensors and operations.
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

# Locate TensorOperations' own ChainRulesCore extension so we can reuse its
# internal helpers. `Base.get_extension` exists on Julia ≥ 1.9; on older
# versions the extension is reachable as a regular submodule.
ext = @static if isdefined(Base, :get_extension)
    Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
    TensorOperations.TensorOperationsChainRulesCoreExt
end
# Borrowed helpers from the TensorOperations extension (not part of its
# public API — NOTE(review): may break if that extension is refactored).
const _conj = ext._conj
const trivtuple = ext.trivtuple

# Rule definitions are grouped by topic into separate files.
include("utility.jl")
include("constructors.jl")
include("linalg.jl")
include("tensoroperations.jl")
include("factorizations.jl")

end
49 changes: 49 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Constructors of fixed/structural tensors carry no meaningful derivative
# information, so mark them non-differentiable for ChainRules-based AD.
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.id(args...)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

# Reverse rule for constructing a `TensorMap` from a dense array: the pullback
# converts the tensor cotangent back to an `Array`; the trailing (space)
# arguments receive no tangents.
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
    tmap = TensorMap(d, args...; kwargs...)
    nargs = length(args)
    function TensorMap_pullback(Δt)
        Δd = convert(Array, unthunk(Δt))
        return (NoTangent(), Δd, ntuple(_ -> NoTangent(), nargs)...)
    end
    return tmap, TensorMap_pullback
end

# `copy` is the identity up to aliasing, hence linear: the pullback simply
# forwards the cotangent unchanged.
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    function copy_pullback(Δt)
        return NoTangent(), Δt
    end
    return copy(t), copy_pullback
end

# Reverse rule for `convert(Array, t)`: the pullback rebuilds a `TensorMap`
# from the array cotangent over the same codomain/domain.
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
                              t::AbstractTensorMap)
    cod, dom = codomain(t), domain(t)
    function convert_pullback(ΔA)
        # the constructor (unconditionally, via tol=Inf) projects the array
        # cotangent back onto the symmetric subspace
        Δt = TensorMap(unthunk(ΔA), cod, dom; tol=Inf)
        return NoTangent(), NoTangent(), Δt
    end
    return convert(T, t), convert_pullback
end

# Reverse rule for `convert(Dict, t)`: only the `:data` entry of the dict
# cotangent carries derivative information; all other entries are structural.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
    out = convert(Dict, t)
    function convert_pullback(c′)
        c = unthunk(c′)
        # without a :data cotangent there is nothing to pull back; zero(t)
        # keeps the return type stable (ZeroTangent() would be type-unstable)
        haskey(c, :data) || return (NoTangent(), NoTangent(), zero(t))
        # :data is the only entry for which this dual makes sense
        dual = copy(out)
        dual[:data] = c[:data]
        return (NoTangent(), NoTangent(), convert(TensorMap, dual))
    end
    return out, convert_pullback
end
# Reverse rule for reconstructing a `TensorMap` from its `Dict` representation.
# The pullback converts the tensor cotangent back to `Dict` form.
# Fix: unthunk the incoming cotangent before converting, consistent with every
# other rrule in this file (a thunked cotangent would otherwise be handed to
# `convert(Dict, ...)` raw and fail or silently misbehave).
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
                              t::Dict{Symbol,Any})
    convert_pullback(v) = (NoTangent(), NoTangent(), convert(Dict, unthunk(v)))
    return convert(TensorMap, t), convert_pullback
end
Loading

0 comments on commit 1331067

Please sign in to comment.