[mlir][linalg] Vectorization support for convolution of i1 type #109480
base: main
Conversation
@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

Normally, convolutions present with the following linalg op region:
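(The inline IR example did not survive extraction; below is a representative sketch, with assumed block-argument names and an i8 element type, of the multiply-accumulate region the vectorizer expects:)

```mlir
^bb0(%in: i8, %filter: i8, %acc: i8):
  %mul = arith.muli %in, %filter : i8
  %sum = arith.addi %mul, %acc : i8
  linalg.yield %sum : i8
```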
However, for i1, due to strength reduction, we get something like:
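(Again reconstructed, since the original snippet is missing from this page; per the change below, the multiply is strength-reduced to `arith.andi` and the accumulate to `arith.ori`:)

```mlir
^bb0(%in: i1, %filter: i1, %acc: i1):
  %and = arith.andi %in, %filter : i1
  %or = arith.ori %and, %acc : i1
  linalg.yield %or : i1
```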
This PR updates the logic to support this region for i1 types.

Full diff: https://github.com/llvm/llvm-project/pull/109480.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..1cdf937742fd2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator
if (!setOperKind(reduceOp))
return;
- auto maybeKind = getCombinerOpKind(reduceOp);
- if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+ maybeKind = getCombinerOpKind(reduceOp);
+ // Typically convolution will have an `ADD` CombiningKind, but for i1 it can
+ // get strength-reduced to `OR`, which is also supported.
+ if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+ *maybeKind != vector::CombiningKind::OR) &&
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
return;
}
-
auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
lhsVals[linearIndex(kw, w)],
rhsVals[kw], resVals[w]);
} else {
- resVals[w] = conv1dSliceAsContraction(rewriter, loc,
- lhsVals[linearIndex(kw, w)],
- rhsVals[kw], resVals[w]);
+ resVals[w] = conv1dSliceAsContraction(
+ rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+ resVals[w], maybeKind);
}
break;
case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
}
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
- Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs, Value res) {
+ Value
+ conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+ Value rhs, Value res,
+ std::optional<vector::CombiningKind> maybeKind) {
vector::IteratorType par = vector::IteratorType::parallel;
vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
- return rewriter.create<vector::ContractionOp>(
+ auto contractionOp = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+ if (maybeKind) {
+ contractionOp.setKind(*maybeKind);
+ }
+ return contractionOp;
}
// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
int strideW, dilationW;
Value lhsShaped, rhsShaped, resShaped;
ShapedType lhsShapedType, rhsShapedType, resShapedType;
+ std::optional<vector::CombiningKind> maybeKind;
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
// Returns true iff it is a valid conv/pooling op.
@@ -3642,7 +3651,8 @@ struct Conv1DGenerator
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
- // Otherwise, if it can be pooling.
+ // A strength-reduced version of MulOp for i1 is AndOp, which is also
+ // supported. Otherwise, it can be pooling.
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3650,7 +3660,9 @@ struct Conv1DGenerator
oper = Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
- } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+ } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+ (isa<arith::AndIOp>(feedOp) &&
+ feedOp->getResultTypes()[0].isInteger(1))) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
if (isa<BlockArgument>(v))
return true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..84e790954b4d02 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
+// -----
+
+func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+ %0 = linalg.conv_1d_ncw_fcw
+ {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+ ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
+ outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
+ return %0 : tensor<1x8x6xi1>
+}
+
+// CHECK-LABEL: func @conv2d_i1_i1_i1
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
+// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK-DAG: %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
+// CHECK-DAG: %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
+// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>}
+// CHECK-SAME: %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
+// CHECK: %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
+// CHECK: %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]]
+// CHECK: return %[[RESULT]] : tensor<1x8x6xi1>
+
+
+
// -----
func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
Nice! A few minor comments.
func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
  %0 = linalg.conv_1d_ncw_fcw
    {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
    ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
    outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
  return %0 : tensor<1x8x6xi1>
}
Could you adapt one of the existing tests for `linalg.conv_1d_ncw_fcw` instead? And move it next to the original example that you would adapt? Thanks! This way it will be much easier to see all the cases that are tested for a particular Op. Also, it would be good to add a few more examples for `i1`.

IIUC, the only difference between `i1` and e.g. `i32` would be the combining Op in `vector.contract`? If that's the case, then IMHO you can write rather reduced CHECK lines that primarily verify the contract Op. Everything else should be identical to what we get today for `i32`, right? As per https://mlir.llvm.org/getting_started/TestingGuide/: "focus on testing the minimal set of functionalities needed."

There's obviously room for interpretation, and having more CHECK lines is also totally fine :)