# Patch summary (explanatory text ahead of the first `diff --git` header).
#
# src/KernelAbstractions.jl
#   * The hunk at old line 665 adds, right after `construct` and before the
#     `# Compiler` section: the documented `argconvert` fallback (which raises
#     an error for an unknown Kernel{T}), the Enzyme stubs `supports_enzyme`
#     and `__fake_compiler_job`, and the `include("extras/extras.jl")` /
#     `include("reflection.jl")` statements.
#   * The hunk at old line 714 removes the same text from its previous spot
#     after the `__size` helpers. Net effect: a pure move, no content change.
#   * NOTE(review): presumably the move makes `Extras` available before the
#     code in the `# Compiler` section is loaded -- confirm against the full file.
#
# src/extras/loopinfo.jl
#   * Adds `simd()` and `ivdep()` to module `MD`; they return
#     Symbol("julia.simdloop") and Symbol("julia.ivdep").
#   * NOTE(review): unlike the neighbouring `unroll_*` entries, which return
#     (Symbol, 1) tuples, these return bare Symbols -- confirm the consumer
#     of these values accepts both shapes.
#
# NOTE(review): the patch text below is collapsed onto a single line; line
# breaks must be restored before it can be applied with `git apply` / `patch`.
diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index d1f21b32..51ab8bd1 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -665,6 +665,27 @@ end function construct(backend::Backend, ::S, ::NDRange, xpu_name::XPUName) where {Backend<:Union{CPU,GPU}, S<:_Size, NDRange<:_Size, XPUName} return Kernel{Backend, S, NDRange, XPUName}(backend, xpu_name) end + +""" + argconvert(::Kernel, arg) + +Convert arguments to the device side representation. +""" +argconvert(k::Kernel{T}, arg) where T = + error("Don't know how to convert arguments for Kernel{$T}") + +# Enzyme support +supports_enzyme(::Backend) = false +function __fake_compiler_job end + +### +# Extras +# - LoopInfo +### + +include("extras/extras.jl") + +include("reflection.jl") ### # Compiler @@ -714,27 +735,6 @@ end __size(args::Tuple) = Tuple{args...} __size(i::Int) = Tuple{i} -""" - argconvert(::Kernel, arg) - -Convert arguments to the device side representation. -""" -argconvert(k::Kernel{T}, arg) where T = - error("Don't know how to convert arguments for Kernel{$T}") - -# Enzyme support -supports_enzyme(::Backend) = false -function __fake_compiler_job end - -### -# Extras -# - LoopInfo -### - -include("extras/extras.jl") - -include("reflection.jl") - # Initialized @kernel function init_kernel(arr, f::F, ::Type{T}) where {F, T} diff --git a/src/extras/loopinfo.jl b/src/extras/loopinfo.jl index bb7e6385..b18a9561 100644 --- a/src/extras/loopinfo.jl +++ b/src/extras/loopinfo.jl @@ -17,6 +17,8 @@ module MD unroll_disable() = (Symbol("llvm.loop.unroll.disable"), 1) unroll_enable() = (Symbol("llvm.loop.unroll.enable"), 1) unroll_full() = (Symbol("llvm.loop.unroll.full"), 1) + simd() = Symbol("julia.simdloop") + ivdep() = Symbol("julia.ivdep") end function loopinfo(expr, nodes...) 
# Patch summary for src/macros.jl (explanatory text ahead of the `diff --git` header).
#
#   * The hunk at old line 238 adds `import .Extras: LoopInfo` between the end
#     of `split` and the start of `emit`, bringing `LoopInfo` into scope.
#   * The hunk at old line 285 (inside `emit(loop)`) appends
#     `Expr(:loopinfo, LoopInfo.MD.simd(), LoopInfo.MD.ivdep())` as the last
#     statement of the generated loop body, after `$(unblock(body))`.
#   * NOTE(review): in Julia's `@simd ivdep`, the ivdep marker asserts that the
#     loop has no loop-carried memory dependencies. It is attached here to
#     every loop `emit` generates -- confirm that assumption holds for all
#     kernel bodies that reach this code path.
#
# NOTE(review): the patch text below is collapsed onto a single line; line
# breaks must be restored before it can be applied with `git apply` / `patch`.
diff --git a/src/macros.jl b/src/macros.jl index a659551d..8766a487 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -238,6 +238,8 @@ function split(stmts, return new_stmts end +import .Extras: LoopInfo + function emit(loop) idx = gensym(:I) for stmt in loop.indicies @@ -285,6 +287,7 @@ function emit(loop) $__validindex(__ctx__, $idx) || continue $(loop.indicies...) $(unblock(body)) + $(Expr(:loopinfo, LoopInfo.MD.simd(), LoopInfo.MD.ivdep())) end end push!(stmts, loopexpr)