Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jul 6, 2023
1 parent 7edacce commit 5854a45
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ module EnzymeExt
return nothing
end

function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, ::Val{tt}, subtape, args...) where {ModifiedBetween, tt, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, tt.parameters...)
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
return nothing
end

function rev(ctx, f::FT, ::Val{ModifiedBetween}, ::Val{tt}, subtape, args...) where {ModifiedBetween, tt, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, tt.parameters...)
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
tp = subtape[__groupindex(ctx)]
reverse(Const(f), Const(ctx), args..., tp)
return nothing
Expand All @@ -45,39 +45,37 @@ module EnzymeExt

ctx = mkcontext(kernel, block, ndrange, iterspace, dynamic)
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}
tt′ = Tuple{Const{ctxTy}, map(Core.Typeof, args)...}

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

FT = Const{Core.Typeof(f)}

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, tt′.parameters...)
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args)...)

subtape = Array{TapeType}(undef, __groupsize(ctx))

aug_kernel = similar(kernel, aug_fwd)

vtt = Val(tt′)
aug_kernel(f, ModifiedBetween, vtt, subtape, args...; ndrange, workgroupsize)
aug_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Tuple{Vector, Val{T} where T, Val{T} where T}}(nothing, nothing, (subtape, ModifiedBetween, vtt))

res = AugmentedReturn{Nothing, Nothing, Vector}(nothing, nothing, subtape)
return res
end

function EnzymeRules.reverse(::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args...; ndrange=nothing, workgroupsize=nothing)
function EnzymeRules.reverse(::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, subtape, args...; ndrange=nothing, workgroupsize=nothing)
kernel = func.val
f = kernel.f
(subtape, ModifiedBetween, vtt) = tape

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = similar(func.val, rev)
rev_kernel(f, ModifiedBetween, vtt, subtape, args...; ndrange, workgroupsize)
rev_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize)
return ((nothing for a in args)...,)
end
end

0 comments on commit 5854a45

Please sign in to comment.