Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Vectorization support for convolution of i1 type #109480

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

nirvedhmeshram
Copy link
Contributor

Normally convolutions present with the following linalg op region

^bb0(%arg14: i4, %arg15: i4, %arg16: i4):
  %17 = arith.muli %arg14, %arg15 : i4
  %18 = arith.addi %arg16, %17 : i4
  linalg.yield %18 : i4

However, for i1 due to strength reduction we get something like

^bb0(%arg14: i1, %arg15: i1, %arg16: i1):
%17 = arith.andi %arg14, %arg15 : i1
%18 = arith.ori %arg16, %17 : i1
linalg.yield %18 : i1

This PR updates the logic to support this region for i1 types.

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

Normally convolutions present with the following linalg op region

^bb0(%arg14: i4, %arg15: i4, %arg16: i4):
  %17 = arith.muli %arg14, %arg15 : i4
  %18 = arith.addi %arg16, %17 : i4
  linalg.yield %18 : i4

However, for i1 due to strength reduction we get something like

^bb0(%arg14: i1, %arg15: i1, %arg16: i1):
%17 = arith.andi %arg14, %arg15 : i1
%18 = arith.ori %arg16, %17 : i1
linalg.yield %18 : i1

This PR updates the logic to support this region for i1 types.


Full diff: https://github.com/llvm/llvm-project/pull/109480.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+23-11)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+29)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..1cdf937742fd2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator
 
     if (!setOperKind(reduceOp))
       return;
-    auto maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+    maybeKind = getCombinerOpKind(reduceOp);
+    // Typically convolution will have a `Add` CombiningKind but for i1 type it
+    // can get strength reduced to `OR` which is also supported.
+    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                        *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }
-
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
                                                    lhsVals[linearIndex(kw, w)],
                                                    rhsVals[kw], resVals[w]);
           } else {
-            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+            resVals[w] = conv1dSliceAsContraction(
+                rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+                resVals[w], maybeKind);
           }
           break;
         case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
-                                 Value lhs, Value rhs, Value res) {
+  Value
+  conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+                           Value rhs, Value res,
+                           std::optional<vector::CombiningKind> maybeKind) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
     lhs = promote(rewriter, loc, lhs, res.getType());
     rhs = promote(rewriter, loc, rhs, res.getType());
-    return rewriter.create<vector::ContractionOp>(
+    auto ContrationOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+    if (maybeKind) {
+      ContrationOp.setKind(*maybeKind);
+    }
+    return ContrationOp;
   }
 
   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  std::optional<vector::CombiningKind> maybeKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
@@ -3642,7 +3651,8 @@ struct Conv1DGenerator
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
-      // Otherwise, if it can be pooling.
+      // A strength reduced version of MulOp for i1 type is AndOp which is also
+      // supported. Otherwise, it can be pooling.
       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3650,7 +3660,9 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                    (isa<arith::AndIOp>(feedOp) &&
+                     feedOp->getResultTypes()[0].isInteger(1))) &&
                    llvm::all_of(feedOp->getOperands(), [](Value v) {
                      if (isa<BlockArgument>(v))
                        return true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..84e790954b4d02 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
 // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
 // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
 
+// -----
+
+func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+  %0 = linalg.conv_1d_ncw_fcw 
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) 
+   outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
+  return %0 : tensor<1x8x6xi1>
+}
+
+// CHECK-LABEL:  func @conv2d_i1_i1_i1
+// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+// CHECK-DAG:    %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[FALSE:.+]] = arith.constant false
+// CHECK-DAG:    %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK-DAG:    %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
+// CHECK-DAG:    %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK:        %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
+// CHECK:        %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>} 
+// CHECK-SAME:   %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
+// CHECK:        %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
+// CHECK:        %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] 
+// CHECK:        return %[[RESULT]] : tensor<1x8x6xi1>
+
+
+
 // -----
 
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

Normally convolutions present with the following linalg op region

^bb0(%arg14: i4, %arg15: i4, %arg16: i4):
  %17 = arith.muli %arg14, %arg15 : i4
  %18 = arith.addi %arg16, %17 : i4
  linalg.yield %18 : i4

However, for i1 due to strength reduction we get something like

^bb0(%arg14: i1, %arg15: i1, %arg16: i1):
%17 = arith.andi %arg14, %arg15 : i1
%18 = arith.ori %arg16, %17 : i1
linalg.yield %18 : i1

This PR updates the logic to support this region for i1 types.


Full diff: https://github.com/llvm/llvm-project/pull/109480.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+23-11)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+29)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..1cdf937742fd2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator
 
     if (!setOperKind(reduceOp))
       return;
-    auto maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+    maybeKind = getCombinerOpKind(reduceOp);
+    // Typically convolution will have a `Add` CombiningKind but for i1 type it
+    // can get strength reduced to `OR` which is also supported.
+    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                        *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }
-
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
                                                    lhsVals[linearIndex(kw, w)],
                                                    rhsVals[kw], resVals[w]);
           } else {
-            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+            resVals[w] = conv1dSliceAsContraction(
+                rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+                resVals[w], maybeKind);
           }
           break;
         case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
-                                 Value lhs, Value rhs, Value res) {
+  Value
+  conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+                           Value rhs, Value res,
+                           std::optional<vector::CombiningKind> maybeKind) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
     lhs = promote(rewriter, loc, lhs, res.getType());
     rhs = promote(rewriter, loc, rhs, res.getType());
-    return rewriter.create<vector::ContractionOp>(
+    auto ContrationOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+    if (maybeKind) {
+      ContrationOp.setKind(*maybeKind);
+    }
+    return ContrationOp;
   }
 
   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  std::optional<vector::CombiningKind> maybeKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
   // Returns true iff it is a valid conv/pooling op.
@@ -3642,7 +3651,8 @@ struct Conv1DGenerator
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
-      // Otherwise, if it can be pooling.
+      // A strength reduced version of MulOp for i1 type is AndOp which is also
+      // supported. Otherwise, it can be pooling.
       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3650,7 +3660,9 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                    (isa<arith::AndIOp>(feedOp) &&
+                     feedOp->getResultTypes()[0].isInteger(1))) &&
                    llvm::all_of(feedOp->getOperands(), [](Value v) {
                      if (isa<BlockArgument>(v))
                        return true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..84e790954b4d02 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
 // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
 // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
 
+// -----
+
+func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+  %0 = linalg.conv_1d_ncw_fcw 
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) 
+   outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
+  return %0 : tensor<1x8x6xi1>
+}
+
+// CHECK-LABEL:  func @conv2d_i1_i1_i1
+// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+// CHECK-DAG:    %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[FALSE:.+]] = arith.constant false
+// CHECK-DAG:    %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG:    %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK-DAG:    %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
+// CHECK-DAG:    %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK:        %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
+// CHECK:        %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>} 
+// CHECK-SAME:   %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
+// CHECK:        %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
+// CHECK:        %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] 
+// CHECK:        return %[[RESULT]] : tensor<1x8x6xi1>
+
+
+
 // -----
 
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! A few minor comments.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp Outdated Show resolved Hide resolved
Comment on lines 659 to 665
func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
%0 = linalg.conv_1d_ncw_fcw
{dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
return %0 : tensor<1x8x6xi1>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you adopt one of the existing tests for linalg.conv_1d_ncw_fcw instead? And move this next to the original example that you would adopt? Thanks! This way it will be much easier to see all the cases that are tested for a particular Op. Also, it would be good to add a few more examples for i1.

IIUC, the only difference between i1 and e.g. i32 would be the combining Op in vector.contract? If that's the case, then IMHO you can write rather reduced CHECK lines that primarily verify the contract Op. Everything else should be identical to what we get today for i32, right? As per https://mlir.llvm.org/getting_started/TestingGuide/:

focus on testing the minimal set of functionalities needed

There's obviously room for interpretation and having more CHECK lines is also totally fine :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants