diff --git a/include/mlir/Dialect/LoopOps/LoopOps.h b/include/mlir/Dialect/LoopOps/LoopOps.h index fdadf4a40dd5..cc1af11c4eed 100644 --- a/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/include/mlir/Dialect/LoopOps/LoopOps.h @@ -52,6 +52,8 @@ void ensureLoopTerminator(Region ®ion, Builder &builder, Location loc); /// not an induction variable, then return nullptr. ForOp getForInductionVarOwner(Value *val); +/// Returns the trip count of the loop if it's a constant, None otherwise. +Optional getConstantTripCount(ForOp forOp); } // end namespace loop } // end namespace mlir #endif // MLIR_LOOPOPS_OPS_H_ diff --git a/include/mlir/Transforms/LoopLikeInterface.h b/include/mlir/Transforms/LoopLikeInterface.h index a8bc0d113786..af3a9aea1198 100644 --- a/include/mlir/Transforms/LoopLikeInterface.h +++ b/include/mlir/Transforms/LoopLikeInterface.h @@ -22,6 +22,7 @@ #ifndef MLIR_TRANSFORMS_LOOPLIKEINTERFACE_H_ #define MLIR_TRANSFORMS_LOOPLIKEINTERFACE_H_ +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" diff --git a/include/mlir/Transforms/LoopLikeInterface.td b/include/mlir/Transforms/LoopLikeInterface.td index a7479cab4a91..09d0e0cfafe4 100644 --- a/include/mlir/Transforms/LoopLikeInterface.td +++ b/include/mlir/Transforms/LoopLikeInterface.td @@ -56,6 +56,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { }], "LogicalResult", "moveOutOfLoop", (ins "ArrayRef":$ops) >, + InterfaceMethod<"Get the trip count if it is a constant.", + "llvm::Optional", "getConstantTripCount", (ins), [{ + return getConstantTripCount(op); + }]>, ]; } diff --git a/lib/Dialect/LoopOps/LoopOps.cpp b/lib/Dialect/LoopOps/LoopOps.cpp index 1dc7debd9a6c..27308660f8f8 100644 --- a/lib/Dialect/LoopOps/LoopOps.cpp +++ b/lib/Dialect/LoopOps/LoopOps.cpp @@ -140,7 +140,7 @@ bool ForOp::isDefinedOutsideOfLoop(Value *value) { LogicalResult ForOp::moveOutOfLoop(ArrayRef ops) { for (auto *op : ops) - op->moveBefore(this->getOperation()); + op->moveBefore(*this); return success(); } @@ -153,6 +153,31 @@ ForOp mlir::loop::getForInductionVarOwner(Value *val) { return dyn_cast_or_null(containingInst); } +Optional mlir::loop::getConstantTripCount(ForOp forOp) { + Value *lb = forOp.lowerBound(); + Value *ub = forOp.upperBound(); + + if (lb == ub) + return 0; + + IntegerAttr lbCst, ubCst, step; + if (!matchPattern(lb, m_Constant(&lbCst)) || + !matchPattern(ub, m_Constant(&ubCst))) + return llvm::None; + + int64_t lbConst = lbCst.getValue().getSExtValue(); + int64_t ubConst = ubCst.getValue().getSExtValue(); + if (ubConst - lbConst <= 0) + return 0; + + if (!matchPattern(forOp.step(), m_Constant(&step))) + return llvm::None; + + // Step is guaranteed to be positive. + int64_t stepConst = step.getValue().getSExtValue(); + return llvm::divideCeil(ubConst - lbConst, stepConst); +} + //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/lib/Transforms/AffineLoopInvariantCodeMotion.cpp index f384f6d3fb18..2b9b2e005e9d 100644 --- a/lib/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/lib/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -40,17 +40,17 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#define DEBUG_TYPE "licm" +#define DEBUG_TYPE "affine-licm" using namespace mlir; namespace { -/// Loop invariant code motion (LICM) pass. +/// Affine loop invariant code motion (LICM) pass. +/// TODO: This pass should be removed once the new LICM pass can handle its +/// uses. /// TODO(asabne) : The pass is missing zero-trip tests. /// TODO(asabne) : Check for the presence of side effects before hoisting. -/// TODO: This code should be removed once the new LICM pass can handle its -/// uses. struct LoopInvariantCodeMotion : public FunctionPass { void runOnFunction() override; void runOnAffineForOp(AffineForOp forOp); @@ -245,4 +245,4 @@ mlir::createAffineLoopInvariantCodeMotionPass() { static PassRegistration pass("affine-loop-invariant-code-motion", - "Hoist loop invariant instructions outside of the loop"); + "Hoist loop invariant operations outside of the loop"); diff --git a/lib/Transforms/LoopInvariantCodeMotion.cpp b/lib/Transforms/LoopInvariantCodeMotion.cpp index 738524aa6ec9..6590ac1de0f2 100644 --- a/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -62,7 +62,7 @@ static bool canBeHoisted(Operation *op, auto thisOpIsSideEffecting = sideEffecting; if (thisOpIsSideEffecting != SideEffecting::Never) { thisOpIsSideEffecting = interface.isSideEffecting(op); - // If the op always has sideeffects, we cannot hoist. + // If the op always has side effects, we cannot hoist. if (thisOpIsSideEffecting == SideEffecting::Always) return false; } @@ -70,9 +70,7 @@ static bool canBeHoisted(Operation *op, // can be hoisted. for (auto ®ion : op->getRegions()) { for (auto &block : region.getBlocks()) { - for (auto &innerOp : block) { - if (innerOp.isKnownTerminator()) - continue; + for (auto &innerOp : block.without_terminator()) { if (!canBeHoisted(&innerOp, definedOutside, thisOpIsSideEffecting, interface)) return false; @@ -112,7 +110,7 @@ static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, } } - // For all instructions that we found to be invariant, move outside of the + // For all operations that we found to be invariant, move outside of the // loop. auto result = looplike.moveOutOfLoop(opsToMove); LLVM_DEBUG(looplike.print(llvm::dbgs() << "Modified loop\n")); @@ -128,6 +126,13 @@ void LoopInvariantCodeMotion::runOnOperation() { // the outer loop, which in turn can be further LICM'ed. getOperation()->walk([&](Operation *op) { if (auto looplike = dyn_cast(op)) { + // Skip zero trip count loops. For unknown trip counts, we still move + // invariant code since it is side-effect free, and in general profitable. + // TODO: when necessary, we could only move when the trip count is + // guaranteed to be at least one. + auto tripCount = looplike.getConstantTripCount(); + if (tripCount == 0UL) + return; LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n")); if (failed(moveLoopInvariantCode(looplike, interface))) signalPassFailure(); @@ -146,4 +151,4 @@ std::unique_ptr mlir::createLoopInvariantCodeMotionPass() { static PassRegistration pass("loop-invariant-code-motion", - "Hoist loop invariant instructions outside of the loop"); + "Hoist loop invariant operations outside of the loop"); diff --git a/lib/Transforms/SimplifyAffineStructures.cpp b/lib/Transforms/SimplifyAffineStructures.cpp index 9512ff738aa3..cfded6ebb8ae 100644 --- a/lib/Transforms/SimplifyAffineStructures.cpp +++ b/lib/Transforms/SimplifyAffineStructures.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopLikeInterface.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -93,12 +94,12 @@ std::unique_ptr> mlir::createSimplifyAffineStructuresPass() { void SimplifyAffineStructures::runOnFunction() { auto func = getFunction(); simplifiedAttributes.clear(); - func.walk([&](Operation *opInst) { - for (auto attr : opInst->getAttrs()) { + func.walk([&](Operation *op) { + for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) - simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); + simplifyAndUpdateAttribute(op, attr.first, mapAttr); else if (auto setAttr = attr.second.dyn_cast()) - simplifyAndUpdateAttribute(opInst, attr.first, setAttr); + simplifyAndUpdateAttribute(op, attr.first, setAttr); } }); @@ -110,8 +111,16 @@ void SimplifyAffineStructures::runOnFunction() { for (auto allocOp : allocOps) { normalizeMemRef(allocOp); } + + // Remove zero trip count loops. + // TODO: this could be moved to a more appropriate place. + func.walk([&](LoopLikeOpInterface loopOp) { + if (loopOp.getConstantTripCount() == 0UL) + loopOp.erase(); + }); } static PassRegistration - pass("simplify-affine-structures", - "Simplify affine expressions in maps/sets and normalize memrefs"); + pass("simplify-affine-structures", "Simplify expressions in afine map/set " + "attributes, normalize memrefs, remove " + "zero trip-count loops"); diff --git a/test/Transforms/affine-loop-invariant-code-motion.mlir b/test/Transforms/affine-loop-invariant-code-motion.mlir index f7143b7ad7db..c24bb4f0d401 100644 --- a/test/Transforms/affine-loop-invariant-code-motion.mlir +++ b/test/Transforms/affine-loop-invariant-code-motion.mlir @@ -17,27 +17,23 @@ func @nested_loops_both_having_invariant_code() { // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %1 = addf %cst, %cst_0 : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { - // CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32> + // CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32> return } -// The store-load forwarding can see through affine apply's since it relies on -// dependence information. -// CHECK-LABEL: func @store_affine_apply -func @store_affine_apply() -> memref<10xf32> { +// CHECK-LABEL: func @store_affine_for +func @store_affine_for() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %arg0 = 0 to 10 { - %t0 = affine.apply (d1) -> (d1 + 1)(%arg0) - affine.store %cf7, %m[%t0] : memref<10xf32> + affine.store %cf7, %m[%arg0 + 1] : memref<10xf32> } return %m : memref<10xf32> // CHECK: %cst = constant 7.000000e+00 : f32 // CHECK-NEXT: %0 = alloc() : memref<10xf32> // CHECK-NEXT: affine.for %arg0 = 0 to 10 { -// CHECK-NEXT: %1 = affine.apply #map3(%arg0) -// CHECK-NEXT: affine.store %cst, %0[%1] : memref<10xf32> +// CHECK-NEXT: affine.store %cst, %0[%arg0 + 1] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : memref<10xf32> } diff --git a/test/Transforms/loop-invariant-code-motion.mlir b/test/Transforms/loop-invariant-code-motion.mlir index 4d742acf246f..dd0e222ceccd 100644 --- a/test/Transforms/loop-invariant-code-motion.mlir +++ b/test/Transforms/loop-invariant-code-motion.mlir @@ -199,10 +199,54 @@ func @invariant_affine_nested_if_else() { // CHECK-NEXT: } // CHECK-NEXT: } + return +} + +// CHECK-LABEL: func @zero_trip_count_affine +func @zero_trip_count_affine() { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + %N = constant 0 : index + + affine.for %arg0 = 0 to %N { + affine.for %arg1 = 0 to 10 { + %v0 = addf %cf7, %cf7 : f32 + } + } + // CHECK: alloc() : memref<10xf32> + // CHECK-NEXT: %cst = constant 7.000000e+00 : f32 + // CHECK-NEXT: %c0 = constant 0 : index + // CHECK-NEXT: affine.for + // CHECK-NEXT: addf + // CHECK-NEXT: affine.for + + return +} + +// CHECK-LABEL: func @zero_trip_count_loop +func @zero_trip_count_loop(%N : index) { + %m = alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + %c1 = constant 1 : index + %c5 = constant 5 : index + + loop.for %i = %N to %N step %c1 { + loop.for %j = %c5 to %c5 step %c1 { + addf %cf7, %cf7 : f32 + } + } + // CHECK: alloc() : memref<10xf32> + // CHECK-NEXT: %cst = constant 7.000000e+00 : f32 + // CHECK-NEXT: %c1 = constant 1 : index + // CHECK-NEXT: %c5 = constant 5 : index + // CHECK-NEXT: loop.for + // CHECK-NEXT: loop.for + // CHECK-NEXT: addf return } +// CHECK-LABEL: func @invariant_loop_dialect func @invariant_loop_dialect() { %ci0 = constant 0 : index %ci10 = constant 10 : index @@ -211,7 +255,7 @@ func @invariant_loop_dialect() { %cf7 = constant 7.0 : f32 %cf8 = constant 8.0 : f32 loop.for %arg0 = %ci0 to %ci10 step %ci1 { - loop.for %arg1 = %ci0 to %ci10 step %ci1 { + loop.for %arg1 = %ci0 to %ci1 step %ci10 { %v0 = addf %cf7, %cf8 : f32 } } @@ -237,8 +281,8 @@ func @variant_loop_dialect() { // CHECK: %0 = alloc() : memref<10xf32> // CHECK-NEXT: loop.for - // CHECK-NEXT: loop.for - // CHECK-NEXT: addi + // CHECK-NEXT: loop.for + // CHECK-NEXT: addi return } diff --git a/test/Transforms/simplify-affine-structures.mlir b/test/Transforms/simplify-affine-structures.mlir index 07a2482f105a..20bdf13512ec 100644 --- a/test/Transforms/simplify-affine-structures.mlir +++ b/test/Transforms/simplify-affine-structures.mlir @@ -235,3 +235,22 @@ func @test_empty_set(%N : index) { return } + +// CHECK-LABEL: func @zero_trip_count_loops +func @zero_trip_count_loops(%N : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c-1 = constant -1 : index + %M = affine.apply (d0) -> ((2*d0 + 4) mod 2)(%N) + affine.for %i = 0 to %M { + } + affine.for %i = 0 to -1 { + } + loop.for %i = %M to %M step %c1 { + } + loop.for %i = %c0 to %c-1 step %N { + } + // All loops above should disappear. + // CHECK-NOT: loop.for + return +} diff --git a/tools/mlir-tblgen/OpInterfacesGen.cpp b/tools/mlir-tblgen/OpInterfacesGen.cpp index 4da412c2f438..2f872a4b58d4 100644 --- a/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -41,7 +41,10 @@ using mlir::tblgen::OpInterfaceMethod; // beginning of the argument list. static void emitMethodNameAndArgs(const OpInterfaceMethod &method, raw_ostream &os, bool addOperationArg) { - os << method.getName() << '('; + // Whenever an operation argument is added, suffix helper method name with an + // underscore to avoid conflicts with free functions of same name on the + // concrete ops using this interface. + os << method.getName() << (addOperationArg ? "_(" : "("); if (addOperationArg) os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", "); interleaveComma(method.getArguments(), os, @@ -64,9 +67,11 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) { emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); // Forward to the method on the concrete operation type. - os << " {\n return getImpl()->" << method.getName() << '('; + os << " {\n return getImpl()->" << method.getName(); if (!method.isStatic()) - os << "getOperation()" << (method.arg_empty() ? "" : ", "); + os << "_(getOperation()" << (method.arg_empty() ? "" : ", "); + else + os << "("; interleaveComma( method.getArguments(), os, [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });