diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index 687c8ba8..4a56c719 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -15,14 +15,14 @@ module EnzymeExt return nothing end - function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, ::Val{tt}, subtape, args...) where {ModifiedBetween, tt, FT} - forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, tt.parameters...) + function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT} + forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1] return nothing end - function rev(ctx, f::FT, ::Val{ModifiedBetween}, ::Val{tt}, subtape, args...) where {ModifiedBetween, tt, FT} - forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, tt.parameters...) + function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT} + forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) 
tp = subtape[__groupindex(ctx)] reverse(Const(f), Const(ctx), args..., tp) return nothing @@ -45,7 +45,6 @@ module EnzymeExt ctx = mkcontext(kernel, block, ndrange, iterspace, dynamic) ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)} - tt′ = Tuple{Const{ctxTy}, map(Core.Typeof, args)...} # TODO autodiff_deferred on the func.val ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) @@ -53,31 +52,30 @@ module EnzymeExt FT = Const{Core.Typeof(f)} # TODO in KA backends like CUDAKernels, etc have a version with a parent job type - TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, tt′.parameters...) + TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args)...) subtape = Array{TapeType}(undef, __groupsize(ctx)) aug_kernel = similar(kernel, aug_fwd) - vtt = Val(tt′) - aug_kernel(f, ModifiedBetween, vtt, subtape, args...; ndrange, workgroupsize) + aug_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize) # TODO the fact that ctxTy is type unstable means this is all type unstable. # Since custom rules require a fixed return type, explicitly cast to Any, rather # than returning a AugmentedReturn{Nothing, Nothing, T} where T. 
- res = AugmentedReturn{Nothing, Nothing, Tuple{Vector, Val{T} where T, Val{T} where T}}(nothing, nothing, (subtape, ModifiedBetween, vtt)) - + res = AugmentedReturn{Nothing, Nothing, Vector}(nothing, nothing, subtape) return res end - function EnzymeRules.reverse(::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args...; ndrange=nothing, workgroupsize=nothing) + function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, subtape, args...; ndrange=nothing, workgroupsize=nothing) kernel = func.val f = kernel.f - (subtape, ModifiedBetween, vtt) = tape + + ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) rev_kernel = similar(func.val, rev) - rev_kernel(f, ModifiedBetween, vtt, subtape, args...; ndrange, workgroupsize) + rev_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize) return ((nothing for a in args)...,) end end