From e5fc9412991ea5331c1d7891e131bc22fe0fccbe Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Wed, 18 Sep 2024 22:31:23 +0000 Subject: [PATCH 1/2] [mlir][linalg] Vectorization support for convolution of i1 type --- .../Linalg/Transforms/Vectorization.cpp | 34 +++++++++++++------ .../Dialect/Linalg/vectorize-convolution.mlir | 29 ++++++++++++++++ 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a376afa5ddab12..1cdf937742fd2e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2947,12 +2947,14 @@ struct Conv1DGenerator if (!setOperKind(reduceOp)) return; - auto maybeKind = getCombinerOpKind(reduceOp); - if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD && + maybeKind = getCombinerOpKind(reduceOp); + // Typically convolution will have a `Add` CombiningKind but for i1 type it + // can get strength reduced to `OR` which is also supported. + if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && + *maybeKind != vector::CombiningKind::OR) && (oper != Pool || !isSupportedPoolKind(*maybeKind)))) { return; } - auto rhsRank = rhsShapedType.getRank(); switch (oper) { case Conv: @@ -3156,9 +3158,9 @@ struct Conv1DGenerator lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } else { - resVals[w] = conv1dSliceAsContraction(rewriter, loc, - lhsVals[linearIndex(kw, w)], - rhsVals[kw], resVals[w]); + resVals[w] = conv1dSliceAsContraction( + rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], + resVals[w], maybeKind); } break; case Pool: @@ -3226,18 +3228,24 @@ struct Conv1DGenerator } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, - Value lhs, Value rhs, Value res) { + Value + conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs, + Value rhs, Value res, + std::optional maybeKind) { vector::IteratorType par = vector::IteratorType::parallel; vector::IteratorType red = vector::IteratorType::reduction; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); lhs = promote(rewriter, loc, lhs, res.getType()); rhs = promote(rewriter, loc, rhs, res.getType()); - return rewriter.create( + auto ContrationOp = rewriter.create( loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); + if (maybeKind) { + ContrationOp.setKind(*maybeKind); + } + return ContrationOp; } // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel @@ -3627,6 +3635,7 @@ struct Conv1DGenerator int strideW, dilationW; Value lhsShaped, rhsShaped, resShaped; ShapedType lhsShapedType, rhsShapedType, resShapedType; + std::optional maybeKind; // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. // Returns true iff it is a valid conv/pooling op. @@ -3642,7 +3651,8 @@ struct Conv1DGenerator switch (numBlockArguments) { case 1: { // Will be convolution if feeder is a MulOp. - // Otherwise, if it can be pooling. + // A strength reduced version of MulOp for i1 type is AndOp which is also + // supported. Otherwise, it can be pooling. auto feedValIt = llvm::find_if_not(reduceOp->getOperands(), llvm::IsaPred); Operation *feedOp = (*feedValIt).getDefiningOp(); @@ -3650,7 +3660,9 @@ struct Conv1DGenerator oper = Pool; isPoolExt = true; poolExtOp = feedOp->getName().getIdentifier(); - } else if (!(isa(feedOp) && + } else if (!((isa(feedOp) || + (isa(feedOp) && + feedOp->getResultTypes()[0].isInteger(1))) && llvm::all_of(feedOp->getOperands(), [](Value v) { if (isa(v)) return true; diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir index 93e36a69567bd5..84e790954b4d02 100644 --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind} %[[CAST0]], %[[CAST1]], %[[READ2]] // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]] +// ----- + +func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> { + %0 = linalg.conv_1d_ncw_fcw + {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} + ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) + outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1> + return %0 : tensor<1x8x6xi1> +} + +// CHECK-LABEL: func @conv2d_i1_i1_i1 +// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> { +// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] +// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] +// CHECK-DAG: %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1> +// CHECK-DAG: %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1> +// CHECK-DAG: %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1> +// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1> +// CHECK: %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1> +// CHECK: %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] +// CHECK: return %[[RESULT]] : tensor<1x8x6xi1> + + + // ----- func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) { From 6f839f3d5da6a787c8e67cbb9e109deffe1365bd Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 23 Sep 2024 22:14:16 +0000 Subject: [PATCH 2/2] address reviwer comments --- .../Linalg/Transforms/Vectorization.cpp | 32 +++---- .../Dialect/Linalg/vectorize-convolution.mlir | 83 ++++++++++++------- 2 files changed, 70 insertions(+), 45 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 1cdf937742fd2e..c06ad50ff4bfe0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2947,14 +2947,17 @@ struct Conv1DGenerator if (!setOperKind(reduceOp)) return; - maybeKind = getCombinerOpKind(reduceOp); + auto maybeKind = getCombinerOpKind(reduceOp); // Typically convolution will have a `Add` CombiningKind but for i1 type it - // can get strength reduced to `OR` which is also supported. + // can get strength reduced to `OR` which is also supported. This strength + // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && *maybeKind != vector::CombiningKind::OR) && (oper != Pool || !isSupportedPoolKind(*maybeKind)))) { return; } + reductionKind = maybeKind.value(); + auto rhsRank = rhsShapedType.getRank(); switch (oper) { case Conv: @@ -3158,9 +3161,9 @@ struct Conv1DGenerator lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } else { - resVals[w] = conv1dSliceAsContraction( - rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], - resVals[w], maybeKind); + resVals[w] = conv1dSliceAsContraction(rewriter, loc, + lhsVals[linearIndex(kw, w)], + rhsVals[kw], resVals[w]); } break; case Pool: @@ -3228,24 +3231,20 @@ struct Conv1DGenerator } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - Value - conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs, - Value rhs, Value res, - std::optional maybeKind) { + Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, + Value lhs, Value rhs, Value res) { vector::IteratorType par = vector::IteratorType::parallel; vector::IteratorType red = vector::IteratorType::reduction; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); lhs = promote(rewriter, loc, lhs, res.getType()); rhs = promote(rewriter, loc, rhs, res.getType()); - auto ContrationOp = rewriter.create( + auto contrationOp = rewriter.create( loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); - if (maybeKind) { - ContrationOp.setKind(*maybeKind); - } - return ContrationOp; + contrationOp.setKind(reductionKind); + return contrationOp; } // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel @@ -3635,7 +3634,7 @@ struct Conv1DGenerator int strideW, dilationW; Value lhsShaped, rhsShaped, resShaped; ShapedType lhsShapedType, rhsShapedType, resShapedType; - std::optional maybeKind; + vector::CombiningKind reductionKind; // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. // Returns true iff it is a valid conv/pooling op. @@ -3652,7 +3651,8 @@ struct Conv1DGenerator case 1: { // Will be convolution if feeder is a MulOp. // A strength reduced version of MulOp for i1 type is AndOp which is also - // supported. Otherwise, it can be pooling. + // supported. Otherwise, it can be pooling. This strength reduction logic + // is in `buildBinaryFn` helper in the Linalg dialect. auto feedValIt = llvm::find_if_not(reduceOp->getOperands(), llvm::IsaPred); Operation *feedOp = (*feedValIt).getDefiningOp(); diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir index 84e790954b4d02..d6cf57fc3c1448 100644 --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -61,6 +61,33 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x // ----- +// This test is same as above but for i1 type. +func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) { + linalg.conv_1d_nwc_wcf + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>) + outs(%output : memref<4x2x8xi1>) + return +} +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_nwc_4x2x8_memref_i1 +/// w == 0, kw == 0 +// CHECK: %[[CONTRACT_0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1> + +/// w == 1, kw == 0 +// CHECK: %[[CONTRACT_1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1> + +// ----- + // The i8i8i32 case is similar to f32 case, so checking one case is enough for // test coverage. func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) { @@ -324,6 +351,33 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x // ----- +// This test is same as above but for i1 type. +func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) { + linalg.conv_1d_ncw_fcw + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x6xi1>, memref<8x3x1xi1>) + outs(%output : memref<4x8x2xi1>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_ncw_4x8x2_memref_i1 +/// w == 0, kw == 0 +// CHECK: vector.contract { +// CHECK-SAME: kind = #vector.kind +// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1> + +/// w == 1, kw == 0 +// CHECK: vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1> + +// ----- + func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) { linalg.conv_1d_ncw_fcw {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} @@ -654,35 +708,6 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter // CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind} %[[CAST0]], %[[CAST1]], %[[READ2]] // CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]] -// ----- - -func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> { - %0 = linalg.conv_1d_ncw_fcw - {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} - ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>) - outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1> - return %0 : tensor<1x8x6xi1> -} - -// CHECK-LABEL: func @conv2d_i1_i1_i1 -// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> { -// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[FALSE:.+]] = arith.constant false -// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] -// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] -// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]] -// CHECK-DAG: %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1> -// CHECK-DAG: %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1> -// CHECK-DAG: %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1> -// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind} -// CHECK-SAME: %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1> -// CHECK: %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1> -// CHECK: %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]] -// CHECK: return %[[RESULT]] : tensor<1x8x6xi1> - - - // ----- func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {