diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
index 837d02838cf1..2c10bb6fb36e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
@@ -32,12 +32,18 @@ struct GPUDistributePass : public GPUDistributeBase<GPUDistributePass> {
         getEntryPoint(funcOp)->getWorkgroupSize().value(),
         [&](Attribute attr) { return llvm::cast<IntegerAttr>(attr).getInt(); });
+    // TODO: Thread through subgroup size everywhere.
+    std::optional<llvm::APInt> maybeSubgroupSize =
+        getEntryPoint(funcOp)->getSubgroupSize();
+    // TODO: Don't hard code kCudaWarpSize here.
+    int64_t subgroupSize =
+        maybeSubgroupSize ? maybeSubgroupSize->getSExtValue() : 32;
+
     IRRewriter rewriter(funcOp->getContext());
     rewriter.setInsertionPointToStart(&funcOp.getBody().front());
     DiagnosedSilenceableFailure result =
         mlir::transform::gpu::mapNestedForallToThreadsImpl(
-            rewriter, std::nullopt, funcOp, workgroupSize, /*warpDims=*/{},
-            false);
+            rewriter, std::nullopt, funcOp, workgroupSize, subgroupSize, false);
     if (!result.succeeded()) return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index fc63de684d4d..fbbfc53a3016 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -227,9 +227,9 @@ static LogicalResult tileParallelDims(func::FuncOp funcOp,
   SmallVector<Attribute> idDims;
   auto getThreadMapping = [&](int64_t dim) {
     return mlir::gpu::GPUThreadMappingAttr::get(
-        tilingOp->getContext(), dim == 0   ? mlir::gpu::Threads::DimX
-                                : dim == 1 ? mlir::gpu::Threads::DimY
-                                           : mlir::gpu::Threads::DimZ);
+        tilingOp->getContext(), dim == 0   ? mlir::gpu::MappingId::DimX
+                                : dim == 1 ? mlir::gpu::MappingId::DimY
+                                           : mlir::gpu::MappingId::DimZ);
   };
   for (unsigned loop : llvm::reverse(partitionedLoops)) {
     int64_t num = elementPerWorkgroup[id++];
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index b1b431f8d588..0466442bdd06 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -458,9 +458,9 @@ LogicalResult rewriteForallToWorkgroup(RewriterBase &rewriter,
   MLIRContext *ctx = forallOp->getContext();
   Location loc = forallOp->getLoc();
   // TODO iree should have own device mapping like #hal.workgroup
-  Attribute bX = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimX);
-  Attribute bY = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimY);
-  Attribute bZ = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimZ);
+  Attribute bX = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimX);
+  Attribute bY = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimY);
+  Attribute bZ = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimZ);
   if (forallOp.getNumResults() > 0)
     return forallOp->emitError(
         "only bufferized scf.forall lowers to workgroup");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
index e4ff4144ca5b..e7b0db3569f1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
@@ -135,7 +135,6 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:LLVMCommonConversion",
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:LinalgDialect",
-        "@llvm-project//mlir:LinalgToLLVM",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:LinalgUtils",
        "@llvm-project//mlir:MathDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 39390453aee5..692db84dbc5b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -107,7 +107,6 @@ iree_cc_library(
     MLIRLLVMCommonConversion
     MLIRLLVMDialect
     MLIRLinalgDialect
-    MLIRLinalgToLLVM
     MLIRLinalgTransforms
     MLIRLinalgUtils
     MLIRMathDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index b3a1d33f450a..8aaf7a4e0057 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -29,7 +29,6 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
@@ -1051,7 +1050,6 @@ void ConvertToLLVMPass::runOnOperation() {
   populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
   populateVectorToLLVMConversionPatterns(
       typeConverter, patterns, targetReassociateFpReductions.getValue());
-  populateLinalgToLLVMConversionPatterns(typeConverter, patterns);
   populateReconcileUnrealizedCastsPatterns(patterns);

   HALDispatchABI abi(&typeConverter);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index f101d90d2993..8a8431855bb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -102,35 +102,19 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne(
   rewriter.setInsertionPointToStart(&target.getBody().front());
   DiagnosedSilenceableFailure diag =
       mlir::transform::gpu::mapNestedForallToThreadsImpl(
-          rewriter, transformOp, target, getWorkgroupDims(), getWarpDims(),
+          rewriter, transformOp, target, getWorkgroupDims(), getSubgroupSize(),
           true);
   if (!diag.succeeded()) return diag;

   auto newAttr = rewriter.getIndexArrayAttr(getWorkgroupDims());
+  auto subgroupSizeAttr = rewriter.getIndexAttr(getSubgroupSize());
   rewriter.startRootUpdate(exportOp);
   exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
-  if (std::optional<int64_t> subgroupSize = getSubgroupSize()) {
-    auto subgroupSizeAttr = rewriter.getIndexAttr(*subgroupSize);
-    exportOp->setAttr(exportOp.getSubgroupSizeAttrName(), subgroupSizeAttr);
-  }
+  exportOp->setAttr(exportOp.getSubgroupSizeAttrName(), subgroupSizeAttr);
   rewriter.finalizeRootUpdate(exportOp);
   return DiagnosedSilenceableFailure::success();
 }

-void transform_dialect::MapNestedForallToGpuThreadsOp::build(
-    OpBuilder &builder, OperationState &state, Value target,
-    ArrayRef<int64_t> workgroupDims, ArrayRef<int64_t> warpDims) {
-  build(builder, state, {}, target, workgroupDims, warpDims, IntegerAttr());
-}
-
-void transform_dialect::MapNestedForallToGpuThreadsOp::build(
-    OpBuilder &builder, OperationState &state, Value target,
-    ArrayRef<int64_t> workgroupDims, ArrayRef<int64_t> warpDims,
-    int64_t subgroupSize) {
-  build(builder, state, {}, target, workgroupDims, warpDims,
-        builder.getI64IntegerAttr(subgroupSize));
-}
-
 void transform_dialect::MapNestedForallToGpuThreadsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::onlyReadsHandle(getTarget(), effects);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 8956a6ba77ab..718a82d84d51 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -89,29 +89,18 @@ def MapNestedForallToGpuThreadsOp :
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr:$workgroup_dims,
-                   DefaultValuedOptionalAttr:$warp_dims,
-                   OptionalAttr:$subgroup_size);
+                   DefaultValuedOptionalAttr:$subgroup_size);

   let results = (outs);

   let assemblyFormat = [{
     $target
     `workgroup_dims` `=` $workgroup_dims
-    (`warp_dims` `=` $warp_dims^)?
     (`subgroup_size` `=` $subgroup_size^)?
attr-dict `:` functional-type($target, results) }]; let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; - let builders = [ - OpBuilder<(ins "Value":$target, - "ArrayRef":$workgroup_dims, - "ArrayRef":$warp_dims)>, - OpBuilder<(ins "Value":$target, - "ArrayRef":$workgroup_dims, - "ArrayRef":$warp_dims, - "int64_t":$subgroupSize)> - ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp index a07271b5cb13..824251be55c6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp @@ -114,8 +114,7 @@ static Value getMaskValue(RewriterBase &rewriter, Operation *op) { vector::ExtractOp maybeExtractOp = maskResult.maybeExtractOp; if (maybeExtractOp) { assert(maybeExtractOp.getPosition().size() == 1 && "expected single pos"); - int64_t sliceNum = - llvm::cast(maybeExtractOp.getPosition()[0]).getInt(); + int64_t sliceNum = maybeExtractOp.getPosition()[0]; // TODO: to support >2-D mask + extract, and all the cmp. Location loc = op->getLoc(); Value zero = rewriter.create(loc, 0); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir index 70b8cc6b38e5..37645faf6be7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir @@ -76,7 +76,7 @@ transform.sequence failures(propagate) { // Tile and fuse attention ops // ========================================== - %forall, %tiled_matmul = transform.structured.tile_to_forall_op %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %forall, %tiled_matmul = transform.structured.tile_to_forall_op %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) @@ -101,7 +101,7 @@ transform.sequence failures(propagate) { // Distribute fills and last truncate // ========================================== %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill, %last_truncate : !transform.any_op - %fill_grid, %tiled_fill = transform.structured.tile_to_forall_op %fills tile_sizes[32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fill_grid, %tiled_fill = transform.structured.tile_to_forall_op %fills tile_sizes[32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // Vectorize function // ========================================== @@ -137,7 +137,7 @@ transform.sequence failures(propagate) { // =========================================================================== %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () - transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = 
[4, 8, 4] warp_dims = [4, 1, 1] : (!transform.any_op) -> () + transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [4, 8, 4] subgroup_size = 32 : (!transform.any_op) -> () transform.apply_patterns to %func_7 { transform.apply_patterns.memref.fold_memref_alias_ops @@ -158,7 +158,7 @@ transform.sequence failures(propagate) { // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 128)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32 - ((s2 + (s0 + s1 * 4) floordiv 32) floordiv 4) * 128)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir index df8c894cdbd1..867ac4b31bf6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir @@ -34,7 +34,7 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // CHECK-NEXT: return // workgroup_size is explicitly set to [10, 11]. - // FOREACH-TO-GPU-DAG: hal.executable.export {{.*}}{translation_info = #translation, workgroup_size = [10 : index, 11 : index, 1 : index]} + // FOREACH-TO-GPU-DAG: hal.executable.export {{.*}}{subgroup_size = 32 : index, translation_info = #translation, workgroup_size = [10 : index, 11 : index, 1 : index]} // FOREACH-TO-GPU-DAG: %[[C0:.*]] = arith.constant 0 : index // FOREACH-TO-GPU-DAG: %[[C1:.*]] = arith.constant 1 : index // FOREACH-TO-GPU-DAG: %[[C5:.*]] = arith.constant 5 : index diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir index d68fc9bf7f10..dcd19fc5a9b9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir @@ -114,8 +114,8 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // CHECK: %[[RHS_DPS:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RHS]] // CHECK: transform.structured.tile_to_forall_op %[[LHS]] -// DEFAULT: num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) -// OPTIONS: num_threads [1, 64, 2] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) +// DEFAULT: num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) +// OPTIONS: num_threads [1, 64, 2] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) // CHECK: apply_patterns // CHECK: transform.iree.apply_licm // CHECK: transform.iree.apply_cse @@ -123,15 +123,15 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch // CHECK: transform.structured.tile_to_forall_op %[[RHS_DPS]] -// DEFAULT: num_threads [8, 16, 1] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) -// OPTIONS: num_threads [2, 8, 8] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) +// DEFAULT: num_threads [8, 16, 1] tile_sizes 
[](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) +// OPTIONS: num_threads [2, 8, 8] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) // CHECK: apply_patterns // CHECK: transform.iree.apply_licm // CHECK: transform.iree.apply_cse // CHECK: transform.structured.tile_to_forall_op -// DEFAULT: num_threads [2, 64, 1] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) -// OPTIONS: num_threads [1, 16, 8] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) +// DEFAULT: num_threads [2, 64, 1] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) +// OPTIONS: num_threads [1, 16, 8] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) // CHECK: apply_patterns // CHECK: transform.iree.apply_licm // CHECK: transform.iree.apply_cse @@ -175,8 +175,8 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // CHECK: transform.iree.apply_buffer_optimizations // CHECK: transform.iree.forall_to_workgroup // CHECK: transform.iree.map_nested_forall_to_gpu_threads -// DEFAULT: workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1] -// OPTIONS: workgroup_dims = [32, 4, 1] warp_dims = [1, 4, 1] +// DEFAULT: workgroup_dims = [64, 2, 1] +// OPTIONS: workgroup_dims = [32, 4, 1] // CHECK: transform.iree.eliminate_gpu_barriers // CHECK: apply_patterns // CHECK: transform.iree.apply_licm diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_convolution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_convolution.mlir index a75b5716b343..f51b499f98a4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_convolution.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_convolution.mlir @@ -49,8 +49,8 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: %[[LHS:.+]] = get_producer_of_operand %{{.*}}[0] // CHECK: %[[RHS:.+]] = get_producer_of_operand %{{.*}}[1] // CHECK: transform.structured.rewrite_in_destination_passing_style %[[LHS]] -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [32, 4] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) -// CHECK: transform.structured.tile_to_forall_op %[[RHS]] num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) +// CHECK: transform.structured.tile_to_forall_op %[[RHS]] num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [1, 2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp, #gpu.warp]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [1, 2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp, #gpu.warp]) // CHECK: transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface @@ -61,7 +61,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: transform.iree.bufferize {target_gpu} // CHECK: transform.iree.apply_buffer_optimizations // CHECK: transform.iree.forall_to_workgroup -// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1] +// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] // CHECK: transform.iree.hoist_static_alloc %{{.*}} // CHECK: transform.apply_patterns.memref.fold_memref_alias_ops // 
CHECK: transform.apply_patterns.memref.extract_address_computations @@ -108,11 +108,11 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: %[[LHS:.+]] = get_producer_of_operand %{{.*}}[0] // CHECK: %[[RHS:.+]] = get_producer_of_operand %{{.*}}[1] // CHECK: transform.structured.rewrite_in_destination_passing_style %[[RHS]] -// CHECK: transform.structured.tile_to_forall_op %[[LHS]] num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.linear, #gpu.linear, #gpu.linear]) -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %[[LHS]] num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [1, 2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp, #gpu.warp]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [1, 2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp, #gpu.warp]) -// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1] +// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_matmul.mlir index 2c573f3881fb..5bc3b6cb3ba4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_matmul.mlir @@ -76,11 +76,11 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} // CHECK: transform.structured.hoist_pad %{{.}} by 1 loops // CHECK: transform.structured.insert_slice_to_copy %{{.*}} : (!transform.any_op) -> !transform.any_op -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [32, 4] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!transform.any_op) -> () -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!transform.any_op) -> () -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // CHECK: 
transform.structured.masked_vectorize %{{.*}} vector_sizes [4, 4] @@ -91,7 +91,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: transform.iree.eliminate_empty_tensors %{{.*}} // CHECK: transform.iree.bufferize {target_gpu} %{{.*}} // CHECK: transform.iree.forall_to_workgroup %{{.*}} -// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1] +// CHECK: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [64, 2, 1] // CHECK: transform.iree.hoist_static_alloc %{{.*}} // CHECK: apply_patterns to %{{.*}} { // CHECK: transform.apply_patterns.memref.fold_memref_alias_ops @@ -138,11 +138,11 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // WITH_OPTIONS: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} // WITH_OPTIONS: transform.structured.hoist_pad %{{.}} by 1 loops // WITH_OPTIONS: transform.structured.insert_slice_to_copy %{{.*}} : (!transform.any_op) -> !transform.any_op -// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [64, 2] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [64, 2] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // WITH_OPTIONS: transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!transform.any_op) -> () -// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // WITH_OPTIONS: transform.scf.take_assumed_branch %{{.*}} take_else_branch : (!transform.any_op) -> () -// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 1] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [4, 1] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // WITH_OPTIONS: transform.structured.masked_vectorize %{{.*}} vector_sizes [4, 4] @@ -155,7 +155,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // WITH_OPTIONS: transform.iree.forall_to_workgroup %{{.*}} // The workgroup dimensions are controled by td-matmul-strategy-num-threads-XX. // The warp dimensions are controled by td-matmul-strategy-num-warps-XX. 
-// WITH_OPTIONS: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [32, 4, 1] warp_dims = [1, 4, 1] +// WITH_OPTIONS: transform.iree.map_nested_forall_to_gpu_threads %{{.*}} workgroup_dims = [32, 4, 1] // WITH_OPTIONS: transform.iree.hoist_static_alloc %{{.*}} // WITH_OPTIONS: apply_patterns to %{{.*}} { // WITH_OPTIONS: transform.apply_patterns.memref.fold_memref_alias_ops @@ -239,11 +239,11 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: transform.iree.populate_workgroup_count_region_using_num_threads_slice // CHECK: transform.structured.tile %{{.*}}[0, 0, 16] // align1 -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // align2 -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // align2 -// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // align1 @@ -339,10 +339,10 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: %[[RES_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RES_PAD]] // CHECK: %[[LHS_PAD:.+]] = get_producer_of_operand %{{.*}}[0] // CHECK: %[[RHS_PAD:.+]] = get_producer_of_operand %{{.*}}[1] -// CHECK: %{{.*}}, %[[TILED_LHS:.+]] = transform.structured.tile_to_forall_op %[[LHS_PAD]] num_threads [32, 4] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: %{{.*}}, %[[TILED_LHS:.+]] = transform.structured.tile_to_forall_op %[[LHS_PAD]] num_threads [32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.match ops{["scf.if"]} // CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch -// CHECK: %{{.*}}, %[[TILED_RHS:.+]] = transform.structured.tile_to_forall_op %[[RHS_PAD]] num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: %{{.*}}, %[[TILED_RHS:.+]] = transform.structured.tile_to_forall_op %[[RHS_PAD]] num_threads [4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.match ops{["scf.if"]} // CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) @@ -415,8 +415,8 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: %[[RHS_PAD:.+]] = get_producer_of_operand %{{.*}}[1] // CHECK: %[[LHS_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[LHS_PAD]] // CHECK: %[[RHS_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RHS_PAD]] -// CHECK: transform.structured.tile_to_forall_op %[[LHS_COPY]] num_threads [32, 4] tile_sizes [](mapping = [#gpu.linear, 
#gpu.linear]) -// CHECK: transform.structured.tile_to_forall_op %[[RHS_COPY]] num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear, #gpu.linear]) +// CHECK: transform.structured.tile_to_forall_op %[[LHS_COPY]] num_threads [32, 4] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) +// CHECK: transform.structured.tile_to_forall_op %[[RHS_COPY]] num_threads [4, 32] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp, #gpu.warp]) // CHECK: transform.apply_patterns.canonicalization diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir index 427c5ccb682e..92206855b4ea 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir @@ -84,7 +84,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // CHECK: transform.iree.apply_buffer_optimizations {{.*}} : (!transform.any_op) -> () // CHECK: {{.*}} = transform.structured.match ops{["func.func"]} in {{.*}} : (!transform.any_op) -> !transform.any_op // CHECK: transform.iree.forall_to_workgroup {{.*}} : (!transform.any_op) -> () -// CHECK: transform.iree.map_nested_forall_to_gpu_threads {{.*}} workgroup_dims = [16, 16, 1] warp_dims = [] : (!transform.any_op) -> () +// CHECK: transform.iree.map_nested_forall_to_gpu_threads {{.*}} workgroup_dims = [16, 16, 1] subgroup_size = 32 : (!transform.any_op) -> () // CHECK: transform.apply_patterns.vector.lower_masks // CHECK: transform.apply_patterns.vector.materialize_masks // CHECK: apply_patterns to %{{.*}} { @@ -99,7 +99,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", // WITH_OPTIONS: transform.structured.tile_to_forall_op {{.*}} num_threads [] tile_sizes [32, 16](mapping = [#gpu.block, #gpu.block]) // WITH_OPTIONS: {{.*}} = transform.structured.tile_to_forall_op {{.*}} num_threads [4, 8] tile_sizes [](mapping = [#gpu.thread, #gpu.thread]) // WITH_OPTIONS: transform.structured.masked_vectorize {{.*}} vector_sizes [2, 4] : !transform.any_op -// WITH_OPTIONS: transform.iree.map_nested_forall_to_gpu_threads {{.*}} workgroup_dims = [8, 4, 1] warp_dims = [] +// WITH_OPTIONS: transform.iree.map_nested_forall_to_gpu_threads {{.*}} workgroup_dims = [8, 4, 1] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir index d29f891d5a7e..ea440b10cf1d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir @@ -1,6 +1,6 @@ // RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(iree-transform-dialect-interpreter,transform-dialect-drop-schedule))" | FileCheck %s -// CHECK: #[[$DIV32MOD8:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 8)> +// CHECK: #[[$DIV32:.*]] = affine_map<()[s0] -> (s0 floordiv 32)> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> #map = affine_map<()[s0] -> (s0 * 8)> #map1 = affine_map<(d0) -> (d0)> @@ -38,7 +38,7 @@ hal.executable private 
@distribute { {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>> } {mapping = [#gpu.thread]} -// CHECK: %[[WX:.+]] = affine.apply #[[$DIV32MOD8]]()[%[[TX]]] +// CHECK: %[[WX:.+]] = affine.apply #[[$DIV32]]()[%[[TX]]] // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[WX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>> scf.forall (%arg0) in (%c8) { vector.transfer_write %cst_0, %subview[%arg0] @@ -52,7 +52,7 @@ hal.executable private @distribute { %17 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op transform.iree.map_nested_forall_to_gpu_threads %17 - workgroup_dims = [256, 1, 1] warp_dims = [8, 1, 1] subgroup_size = 32 : (!transform.any_op) -> () + workgroup_dims = [256, 1, 1] subgroup_size = 32 : (!transform.any_op) -> () // Late canonicalizations to cleanup and pass the checks. // Needs to occur on the whole variant to perform cse on the workgroup_count region diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h index e0ff79c2017e..1e2a01d24830 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h @@ -21,13 +21,13 @@ namespace iree_compiler { // Base quantities generally useful for all CPU and GPU strategies. //===----------------------------------------------------------------------===// inline Attribute blockX(MLIRContext *ctx) { - return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::Blocks::DimX); + return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::MappingId::DimX); } inline Attribute blockY(MLIRContext *ctx) { - return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::Blocks::DimY); + return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::MappingId::DimY); } inline Attribute blockZ(MLIRContext *ctx) { - return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::Blocks::DimZ); + return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::MappingId::DimZ); } struct AbstractReductionStrategy; diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp index f15901791622..9c1726bdfda8 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp @@ -141,10 +141,9 @@ static std::pair computeSplitPoint(int64_t upperBound, /// Takes a handle to a func.func and returns an updated handle to a /// func.func. Value mlir::iree_compiler::gpu::buildMapToBlockAndThreads( - ImplicitLocOpBuilder &b, Value funcH, ArrayRef blockSize, - ArrayRef warpDims) { + ImplicitLocOpBuilder &b, Value funcH, ArrayRef blockSize) { b.create(funcH); - b.create(funcH, blockSize, warpDims); + b.create(funcH, blockSize); return funcH; } diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h index 1a6746aab0ba..f65f8768cbb7 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h @@ -24,31 +24,34 @@ struct GPUModel; // Base quantities generally useful for all GPU strategies. 
//===----------------------------------------------------------------------===// inline Attribute threadX(MLIRContext *ctx) { - return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::Threads::DimX); + return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::MappingId::DimX); } inline Attribute threadY(MLIRContext *ctx) { - return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::Threads::DimY); + return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::MappingId::DimY); } inline Attribute threadZ(MLIRContext *ctx) { - return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::Threads::DimZ); + return mlir::gpu::GPUThreadMappingAttr::get(ctx, mlir::gpu::MappingId::DimZ); } inline Attribute warpX(MLIRContext *ctx) { - return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::Warps::DimX); + return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::MappingId::DimX); } inline Attribute warpY(MLIRContext *ctx) { - return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::Warps::DimY); + return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::MappingId::DimY); } inline Attribute warpZ(MLIRContext *ctx) { - return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::Warps::DimZ); + return mlir::gpu::GPUWarpMappingAttr::get(ctx, mlir::gpu::MappingId::DimZ); } -inline Attribute linearIdX(MLIRContext *ctx) { - return mlir::gpu::GPULinearIdMappingAttr::get(ctx, mlir::gpu::LinearId::DimX); +inline Attribute linearId0(MLIRContext *ctx) { + return mlir::gpu::GPUThreadMappingAttr::get(ctx, + mlir::gpu::MappingId::LinearDim0); } -inline Attribute linearIdY(MLIRContext *ctx) { - return mlir::gpu::GPULinearIdMappingAttr::get(ctx, mlir::gpu::LinearId::DimY); +inline Attribute linearId1(MLIRContext *ctx) { + return mlir::gpu::GPUThreadMappingAttr::get(ctx, + mlir::gpu::MappingId::LinearDim1); } -inline Attribute linearIdZ(MLIRContext *ctx) { - return mlir::gpu::GPULinearIdMappingAttr::get(ctx, mlir::gpu::LinearId::DimZ); +inline Attribute linearId2(MLIRContext *ctx) { + return mlir::gpu::GPUThreadMappingAttr::get(ctx, + mlir::gpu::MappingId::LinearDim2); } //===----------------------------------------------------------------------===// @@ -74,8 +77,7 @@ int64_t adjustNumberOfWarpsForBlockShuffle(int64_t numWarpsToUse, /// dimensions to consider along various dimensions and avoid second-guessing /// how the mapping to warps should occur. Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH, - ArrayRef blockSize, - ArrayRef warpDims = {}); + ArrayRef blockSize); /// Post-bufferization vector distribution with rank-reduction. /// Takes a handle to a func.func and returns an updated handle to a diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionImplicitGemmStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionImplicitGemmStrategy.cpp index 08fc4b5e95fb..1840e931d249 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionImplicitGemmStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionImplicitGemmStrategy.cpp @@ -346,8 +346,7 @@ void iree_compiler::gpu::buildConvolutionImplicitGemmStrategy( // TODO: assumes a single func::FuncOp to transform, needs hardening. // TODO: extract info from strategy. funcH = b.create(variantH, func::FuncOp::getOperationName()); - funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads, - strategy.numWarps); + funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads); funcH = b.create(funcH); // Step 9. 
Convert to tensor core ops. diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp index c3a957be348a..259c86f2861b 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp @@ -128,8 +128,8 @@ iree_compiler::gpu::MappingInfo iree_compiler::gpu::CopyMapping::getMappingInfo( std::tie(size, numThreads) = pair; return mlir::ceilDiv(size, numThreads); })); - SmallVector allThreadMappings{linearIdZ(ctx), linearIdY(ctx), - linearIdX(ctx)}; + SmallVector allThreadMappings{linearId2(ctx), linearId1(ctx), + linearId0(ctx)}; auto threadMapping = llvm::to_vector(ArrayRef(allThreadMappings).take_back(tileSizes.size())); diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp index f92f8f21631d..13b8bf241c9c 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp @@ -238,8 +238,7 @@ buildCommonMatmulLikeThreadSchedule(ImplicitLocOpBuilder &b, Value variantH, // Need to match again since bufferize invalidated all handles. // TODO: assumes a single func::FuncOp to transform, needs hardening. funcH = b.create(variantH, func::FuncOp::getOperationName()); - funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads, - strategy.numWarps); + funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads); funcH = b.create(funcH); // Step 9. Convert to tensor core ops. diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index 9552156e7604..0abd2fa81219 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -561,10 +561,9 @@ canonicalizeSubViewParts(OpTy op, RankedTensorType sliceType, mixedOffsets.assign(op.getMixedOffsets()); mixedSizes.assign(op.getMixedSizes()); mixedStrides.assign(op.getMixedStrides()); - Builder builder(op.getContext()); - if (failed(foldDynamicIndexList(builder, mixedOffsets)) && - failed(foldDynamicIndexList(builder, mixedSizes)) && - failed(foldDynamicIndexList(builder, mixedStrides))) { + if (failed(foldDynamicIndexList(mixedOffsets)) && + failed(foldDynamicIndexList(mixedSizes)) && + failed(foldDynamicIndexList(mixedStrides))) { return failure(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index dd6a7959c901..f55251c673e0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -941,7 +941,7 @@ struct ElideEmptyFenceJoin : public OpRewritePattern { // Produces a deduplicated and null-elided operand list. // Returns std::nullopt if nothing changed. 
-static std::optional> +static std::optional> deduplicateFenceOperands(ValueRange operands) { SetVector newOperands; for (auto operand : operands) { diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index b1dcbce3c3be..d4fd84e8a63d 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -119,7 +119,6 @@ iree_compiler_cc_library( "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgPassIncGen", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:MathDialect", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index f1e31cafc1d0..55f6086ba0b7 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -136,7 +136,6 @@ iree_cc_library( MLIRIR MLIRLLVMDialect MLIRLinalgDialect - MLIRLinalgToLLVM MLIRLinalgTransforms MLIRMLProgramDialect MLIRQuantDialect diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index 77c5e34b7d1c..1f8a3fd63d3b 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel @@ -598,7 +598,6 @@ cc_library( "@llvm-project//mlir:AsyncToLLVM", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:IndexToLLVM", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgToStandard", "@llvm-project//mlir:MathToLLVM", "@llvm-project//mlir:MemRefToLLVM", diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 772bdd3831ee..eec1bb554923 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -17,7 +17,6 @@ add_mlir_library(IREELinalgExtTransforms MLIRAffineToStandard MLIRAsyncDialect MLIRSCFToControlFlow - MLIRLinalgToLLVM MLIRDialectUtils MLIRVectorToLLVM MLIRMathToLLVM diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt index cfba284ec00d..899c80159a1b 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt @@ -38,6 +38,5 @@ add_mlir_library(IREELinalgTransformDialect MLIRMemRefToLLVM MLIRMathToLLVM MLIRVectorToLLVM - MLIRLinalgToLLVM MLIRSCFToControlFlow ) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp index 88b5587ff3fc..4ae2935ee18a 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp @@ -14,7 +14,6 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include 
"mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp index ed3dfd68c4ec..61e71a0f3de1 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp @@ -15,7 +15,6 @@ #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -415,8 +414,6 @@ DiagnosedSilenceableFailure transform_ext::LowerToLLVMOp::apply( // Sprinkle some cleanups. pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - // Blanket-convert any remaining linalg ops to LLVM if any remain. - pm.addPass(createConvertLinalgToLLVMPass()); { auto options = ConvertVectorToLLVMPassOptions(); options.reassociateFPReductions = getReassociateFpReductions(); diff --git a/third_party/llvm-project b/third_party/llvm-project index 4706251a3186..2dc1a27449a9 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 4706251a3186c34da0ee8fd894f7e6b095da8fdc +Subproject commit 2dc1a27449a98cf18214174d626c29c7bb72c88f