[mlir][linalg] Vectorization support for convolution of i1 type #109480

Open · wants to merge 2 commits into main (changes shown from 1 commit)
34 changes: 23 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator

if (!setOperKind(reduceOp))
return;
-    auto maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+    maybeKind = getCombinerOpKind(reduceOp);
+    // Typically a convolution will have an `Add` CombiningKind, but for the
+    // i1 type it can get strength-reduced to `OR`, which is also supported.
+    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                        *maybeKind != vector::CombiningKind::OR) &&
                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
       return;
     }

auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
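For context, the combiner that this check now accepts shows up in the linalg op's region roughly as follows. This is a hand-written sketch, not taken from the patch, and it assumes upstream rewrites have already turned the i1 mul/add pair into `andi`/`ori` as the new comment describes:

```mlir
// Region of an i1 convolution after strength reduction (sketch).
^bb0(%in: i1, %flt: i1, %acc: i1):
  %m = arith.andi %in, %flt : i1  // elementwise product, reduced from muli
  %r = arith.ori %m, %acc : i1    // combiner; getCombinerOpKind returns OR
  linalg.yield %r : i1
```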
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
lhsVals[linearIndex(kw, w)],
rhsVals[kw], resVals[w]);
} else {
-          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                lhsVals[linearIndex(kw, w)],
-                                                rhsVals[kw], resVals[w]);
+          resVals[w] = conv1dSliceAsContraction(
+              rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+              resVals[w], maybeKind);
}
break;
case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
}

// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
-                                 Value lhs, Value rhs, Value res) {
+  Value
+  conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+                           Value rhs, Value res,
+                           std::optional<vector::CombiningKind> maybeKind) {
vector::IteratorType par = vector::IteratorType::parallel;
vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
-    return rewriter.create<vector::ContractionOp>(
+    auto contractionOp = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+    if (maybeKind) {
+      contractionOp.setKind(*maybeKind);
+    }
+    return contractionOp;
}

// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
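To illustrate what `conv1dSliceAsContraction` now emits when `maybeKind` is `OR`, here is a minimal well-formed instance of the resulting op (a sketch with shapes borrowed from the new test below; the function name is made up for illustration):

```mlir
#lhs = affine_map<(n, w, f, c) -> (n, w, c)>
#rhs = affine_map<(n, w, f, c) -> (c, f)>
#res = affine_map<(n, w, f, c) -> (n, w, f)>

// lhs{n, w, c} * rhs{c, f} -> res{n, w, f}, combined with OR instead of ADD.
func.func @contract_or(%lhs: vector<1x6x8xi1>, %rhs: vector<8x8xi1>,
                       %acc: vector<1x6x8xi1>) -> vector<1x6x8xi1> {
  %0 = vector.contract {
         indexing_maps = [#lhs, #rhs, #res],
         iterator_types = ["parallel", "parallel", "parallel", "reduction"],
         kind = #vector.kind<or>}
       %lhs, %rhs, %acc : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
  return %0 : vector<1x6x8xi1>
}
```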
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
int strideW, dilationW;
Value lhsShaped, rhsShaped, resShaped;
ShapedType lhsShapedType, rhsShapedType, resShapedType;
+  std::optional<vector::CombiningKind> maybeKind;

// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
// Returns true iff it is a valid conv/pooling op.
@@ -3642,15 +3651,18 @@
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
-      // Otherwise, if it can be pooling.
+      // A strength-reduced version of MulOp for the i1 type is AndOp, which
+      // is also supported. Otherwise, it can be pooling.
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
oper = Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                    (isa<arith::AndIOp>(feedOp) &&
+                     feedOp->getResultTypes()[0].isInteger(1))) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
if (isa<BlockArgument>(v))
return true;
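Putting the two halves together, the shape of i1 convolution that `setOperKind` now classifies as `Conv` can be written as a `linalg.generic` like the sketch below. This is a hypothetical example, not taken from the patch; names are invented and the shapes mirror the new test:

```mlir
#in  = affine_map<(n, f, w, c, kw) -> (n, c, w + kw)>
#flt = affine_map<(n, f, w, c, kw) -> (f, c, kw)>
#out = affine_map<(n, f, w, c, kw) -> (n, f, w)>

func.func @conv1d_i1_generic(%in: tensor<1x8x6xi1>, %flt: tensor<8x8x1xi1>,
                             %out: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
  %0 = linalg.generic
      {indexing_maps = [#in, #flt, #out],
       iterator_types = ["parallel", "parallel", "parallel",
                         "reduction", "reduction"]}
      ins(%in, %flt : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
      outs(%out : tensor<1x8x6xi1>) {
  ^bb0(%a: i1, %b: i1, %acc: i1):
    // feedOp: an AndIOp with an i1 result, accepted as a strength-reduced mul.
    %m = arith.andi %a, %b : i1
    // reduceOp: the OR combiner accepted by the earlier change.
    %r = arith.ori %m, %acc : i1
    linalg.yield %r : i1
  } -> tensor<1x8x6xi1>
  return %0 : tensor<1x8x6xi1>
}
```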
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]

// -----

func.func @conv1d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
%0 = linalg.conv_1d_ncw_fcw
{dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
return %0 : tensor<1x8x6xi1>
}
Contributor:

Could you adapt one of the existing tests for linalg.conv_1d_ncw_fcw instead, and move this next to the original example that you adapt? Thanks! That way it will be much easier to see all the cases that are tested for a particular Op. Also, it would be good to add a few more examples for i1.

IIUC, the only difference between i1 and e.g. i32 would be the combining Op in vector.contract? If that's the case, then IMHO you can write rather reduced CHECK lines that primarily verify the contract Op. Everything else should be identical to what we get today for i32, right? As per https://mlir.llvm.org/getting_started/TestingGuide/:

> focus on testing the minimal set of functionalities needed

There's obviously room for interpretation and having more CHECK lines is also totally fine :)
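For illustration, reduced CHECK lines along those lines might look like the sketch below; only the combining kind and element types differ from the existing i32 coverage (this is a suggestion, not part of the patch):

```mlir
// CHECK-LABEL: func @conv1d_i1_i1_i1
// CHECK: vector.contract
// CHECK-SAME: kind = #vector.kind<or>
// CHECK-SAME: vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
```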


// CHECK-LABEL: func @conv1d_i1_i1_i1
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
// CHECK-DAG: %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
// CHECK-DAG: %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
// CHECK-DAG: %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>}
// CHECK-SAME: %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
// CHECK: %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
// CHECK: %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]]
// CHECK: return %[[RESULT]] : tensor<1x8x6xi1>

// -----

func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {