Skip to content

Commit

Permalink
[CPU] Introduce TilingConfig class (#14082)
Browse files Browse the repository at this point in the history
TilingConfig is a simple step towards separating the API to retrieve
the tile size information from the actual representation and
implementation of such information. It will let us implement different
tiling configuration scenarios without exposing the implementation
details or even replacing LoweringConfig with something else without
impacting TilingConfig users.
  • Loading branch information
dcaballe committed Jun 15, 2023
1 parent 77eda48 commit 9303825
Show file tree
Hide file tree
Showing 11 changed files with 421 additions and 127 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ iree_compiler_cc_library(
"LLVMCPUVectorization.cpp",
"Passes.cpp",
"TargetMLTransformInfo.cpp",
"TileSizeSelection.cpp",
"Utils.cpp",
"VectorContractCustomKernels.cpp",
"VerifyLinalgTransformLegality.cpp",
Expand All @@ -47,6 +48,7 @@ iree_compiler_cc_library(
"KernelDispatch.h",
"LLVMCPUPasses.h",
"TargetMLTransformInfo.h",
"TileSizeSelection.h",
"Utils.h",
],
deps = [
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_cc_library(
"KernelDispatch.h"
"LLVMCPUPasses.h"
"TargetMLTransformInfo.h"
"TileSizeSelection.h"
"Utils.h"
SRCS
"ConvertToLLVM.cpp"
Expand All @@ -43,6 +44,7 @@ iree_cc_library(
"LLVMCPUVectorization.cpp"
"Passes.cpp"
"TargetMLTransformInfo.cpp"
"TileSizeSelection.cpp"
"Utils.cpp"
"VectorContractCustomKernels.cpp"
"VerifyLinalgTransformLegality.cpp"
Expand Down
16 changes: 0 additions & 16 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,6 @@
namespace mlir {
namespace iree_compiler {

// TODO(hanchung): Create a pass to handle detailed logic about splitting tiling
// sizes for parallel dims and reduction dims.
// The fill + named_op + generic ops must first be fused along the parallel
// dims; vectorization is not applied at that stage. For the matmul + generic
// op case the reduction dim is not tiled by that step, so it has to be tiled
// again afterwards, which requires the ops to be TilingInterface ops.
//
// Identifies the tiling levels applied to TilingInterface operations. The
// enumerators are numbered consecutively starting at 0.
enum class TilingLevel : unsigned {
  // Level 0: tile TilingInterface operations to threads.
  WorkGroupTiles,
  // Level 1: tile TilingInterface operations on a workgroup thread along the
  // parallel dims.
  ParallelTiles,
  // Level 2: tile TilingInterface operations on a workgroup thread along the
  // reduction dims.
  ReductionTiles,
  // Not a level itself: its value (3) equals the number of levels above.
  NumTileLevels
};

LogicalResult initCPULaunchConfig(ModuleOp moduleOp);

} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
#include "iree/compiler/Codegen/LLVMCPU/LLVMCPUPasses.h"
#include "iree/compiler/Codegen/LLVMCPU/TileSizeSelection.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
Expand All @@ -24,6 +26,8 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttr;

namespace mlir {
namespace iree_compiler {

Expand Down Expand Up @@ -106,12 +110,53 @@ static LogicalResult verifyLoweringConfiguration(
auto walkResult = module.walk([&](Operation *op) -> WalkResult {
IREE::Codegen::LoweringConfigAttr loweringConfig = getLoweringConfig(op);
if (!loweringConfig) return WalkResult::advance();
return verificationFn(op, loweringConfig, translationInfo,
TilingConfig tilingConfig(loweringConfig);
return verificationFn(op, tilingConfig, translationInfo,
ArrayRef<int64_t>{});
});
return failure(walkResult.wasInterrupted());
}

// TODO(dcaballe): We temporarily need this utility to retrieve a valid
// lowering config. We should be able to remove this once we have a lowering
// config attribute per op.
//
// Returns the lowering config attached to `op` itself when there is one.
// Otherwise walks `op` and returns the single lowering config shared by every
// visited operation that carries one. Fails when no visited operation carries
// a lowering config, or when two of them carry different configs.
static FailureOr<LoweringConfigAttr> getRootLoweringConfig(Operation *op) {
  // A config on the queried op itself wins; no need to look any further.
  if (auto selfConfig = iree_compiler::getLoweringConfig(op)) {
    return selfConfig;
  }

  // Null until the first config is found during the walk.
  LoweringConfigAttr uniqueConfig;
  WalkResult walkStatus = op->walk([&](Operation *visitedOp) -> WalkResult {
    auto candidate = iree_compiler::getLoweringConfig(visitedOp);
    if (!candidate) {
      return WalkResult::advance();
    }
    if (!uniqueConfig) {
      // First config seen: remember it and keep checking the remaining ops.
      uniqueConfig = candidate;
      return WalkResult::advance();
    }
    // Any config that differs from the remembered one aborts the walk.
    return uniqueConfig == candidate ? WalkResult::advance()
                                     : WalkResult::interrupt();
  });

  if (walkStatus.wasInterrupted() || !uniqueConfig) {
    return failure();
  }
  return uniqueConfig;
}

// Builds the TilingConfig for a pipeline from the root lowering config found
// under `variantOp`. The lookup is only guarded by an assert, so callers must
// only use this for pipelines whose dispatch carries a lowering config.
static TilingConfig getTilingConfigForPipeline(
    IREE::HAL::ExecutableVariantOp variantOp) {
  auto rootConfig = getRootLoweringConfig(variantOp);
  assert(succeeded(rootConfig) && "Pipeline requires a lowering config");
  return TilingConfig(*rootConfig);
}

void LLVMCPULowerExecutableTargetPass::runOnOperation() {
IREE::HAL::ExecutableVariantOp variantOp = getOperation();
ModuleOp moduleOp = variantOp.getInnerModule();
Expand Down Expand Up @@ -179,7 +224,8 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
moduleOp, translationInfo.value(),
verifyConvTileAndDecomposeExpertConfig);
break;
default:;
default:
break;
}
if (failed(verificationStatus)) {
return signalPassFailure();
Expand All @@ -190,6 +236,7 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
bool enableVectorMasking =
isX86(target) || isRISCV(target) ||
(isAArch64(target) && hasAnySVEFeature(target));

bool enableMicrokernels = hasMicrokernels(target);
bool enableAArch64SSVE = isAArch64(target) && hasAnySVEFeature(target) &&
hasSMEFeature(target);
Expand All @@ -200,44 +247,56 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
addCPUDefaultPassPipeline(executableLoweringPipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUBufferOpsTileAndVectorize:
addCPUBufferOpsTileAndVectorizePipeline(executableLoweringPipeline,
enableVectorMasking,
enableAArch64SSVE);
CPUBufferOpsTileAndVectorize: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addCPUBufferOpsTileAndVectorizePipeline(
executableLoweringPipeline, tilingConfig, enableVectorMasking,
enableAArch64SSVE);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUDoubleTilingExpert:
CPUDoubleTilingExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addMultiTilingExpertPassPipeline(
executableLoweringPipeline,
static_cast<int>(TilingLevel::NumTileLevels),
executableLoweringPipeline, tilingConfig,
/*enablePeeling=*/false, enableVectorMasking, lowerToAVX2);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUDoubleTilingPadExpert:
addDoubleTilingPadExpertPassPipeline(executableLoweringPipeline,
enableVectorMasking);
CPUDoubleTilingPadExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addDoubleTilingPadExpertPassPipeline(
executableLoweringPipeline, tilingConfig, enableVectorMasking);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUDoubleTilingPeelingExpert:
CPUDoubleTilingPeelingExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addMultiTilingExpertPassPipeline(
executableLoweringPipeline,
static_cast<int>(TilingLevel::NumTileLevels),
executableLoweringPipeline, tilingConfig,
/*enablePeeling=*/true, enableVectorMasking, lowerToAVX2,
enableAArch64SSVE);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUConvTileAndDecomposeExpert:
CPUConvTileAndDecomposeExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addConvTileAndDecomposeExpertPassPipeline(
executableLoweringPipeline, enableVectorMasking,
executableLoweringPipeline, tilingConfig, enableVectorMasking,
enableAArch64SSVE);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::Mmt4dTilingExpert:
}
case IREE::Codegen::DispatchLoweringPassPipeline::Mmt4dTilingExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addMmt4dTilingExpertPassPipeline(executableLoweringPipeline,
enableMicrokernels);
tilingConfig, enableMicrokernels);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling:
addCPUDataTilingPipeline(executableLoweringPipeline);
}
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling: {
TilingConfig tilingConfig = getTilingConfigForPipeline(variantOp);
addCPUDataTilingPipeline(executableLoweringPipeline, tilingConfig);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::VMVXDefault:
addVMVXDefaultPassPipeline(executableLoweringPipeline,
enableMicrokernels);
Expand Down
22 changes: 14 additions & 8 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
namespace mlir {
namespace iree_compiler {

class TilingConfig;

/// Performs the final conversion to LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass(
bool reassociateFpReordering = false);
Expand Down Expand Up @@ -116,38 +118,42 @@ void populateVectorContractCustomKernelsPatterns(
//----------------------------------------------------------------------------//
// LLVMCPU backend Pass Pipelines.
//----------------------------------------------------------------------------//

/// Populates the passes to lower linalg ops on buffers. Currently this
/// pipeline is only used for dispatches that just copy data from input
/// interfaces to output interface.
void addCPUBufferOpsTileAndVectorizePipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking,
bool enableAArch64SSVE = false);

/// Populates the passes to lower ops through data tiling transformations.
void addCPUDataTilingPipeline(OpPassManager &passManager);
void addCPUDataTilingPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig);

/// Populates the passes to lower to scalars operations for linalg based
/// code-generation. This pipeline does not vectorize, but instead just
/// converts to memrefs
void addCPUDefaultPassPipeline(OpPassManager &passManager);

void addConvTileAndDecomposeExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking,
bool enableAArch64SSVE = false);

void addDoubleTilingPadExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking);

/// Populates the passes needed to multi level tile, fuse and vectorize
/// lowering of linalg ops on tensors to vectors operations.
void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableMicrokernels);

void addMultiTilingExpertPassPipeline(OpPassManager &passManager,
int64_t numLevels, bool enablePeeling,
bool enableVectorMasking,
bool lowerToAVX2,
bool enableAArch64SSVE = false);
void addMultiTilingExpertPassPipeline(
OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE = false);

void addTensorToVectorsPassPipeline(OpPassManager &passManager,
bool lowerToVectors = true);
Expand All @@ -162,13 +168,13 @@ void addVMVXDefaultPassPipeline(OpPassManager &passManager,
// Populates the passes needed to do tiling, decomposing, and vectorizing the
// convolution ops.
LogicalResult verifyConvTileAndDecomposeExpertConfig(
Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig,
Operation *op, TilingConfig &tilingConfig,
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize = {});

/// Populates the passes needed to do two-level tile + vectorize of linalg ops.
LogicalResult verifyDoubleTilingExpertPassPipelineConfig(
Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig,
Operation *op, TilingConfig &tilingConfig,
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize = {});

Expand Down
Loading

0 comments on commit 9303825

Please sign in to comment.