Skip to content

Commit

Permalink
Fix issue NVIDIA#2196
Browse files Browse the repository at this point in the history
Canonicalization patterns should return failure if they don't modify the
IR.
  • Loading branch information
schweitzpgi committed Sep 6, 2024
1 parent 5b030df commit b616ed6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
31 changes: 19 additions & 12 deletions lib/Optimizer/Dialect/CC/CCOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,9 @@ struct FuseCastCascade : public OpRewritePattern<cudaq::cc::CastOp> {
// %5 = cc.cast %3 : (!cc.ptr<T>) -> !cc.ptr<V>
rewriter.replaceOpWithNewOp<cudaq::cc::CastOp>(castOp, castOp.getType(),
castToCast.getValue());
return success();
}
return success();
return failure();
}
};

Expand All @@ -406,11 +407,13 @@ struct SimplifyIntegerCompare : public OpRewritePattern<arith::CmpIOp> {
auto rhsVal = rhsCast.getValue();
if (lhsVal.getType() == rhsVal.getType() &&
lhsCast.getSint() == rhsCast.getSint() &&
lhsCast.getZint() == rhsCast.getZint())
lhsCast.getZint() == rhsCast.getZint()) {
rewriter.replaceOpWithNewOp<arith::CmpIOp>(
compare, compare.getType(), compare.getPredicate(), lhsVal, rhsVal);
return success();
}
}
return success();
return failure();
}
};

Expand All @@ -429,8 +432,9 @@ struct FuseComplexCreate : public OpRewritePattern<complex::CreateOp> {
auto arrAttr = rewriter.getArrayAttr({rePart, imPart});
rewriter.replaceOpWithNewOp<complex::ConstantOp>(
create, ComplexType::get(eleTy), arrAttr);
return success();
}
return success();
return failure();
}
};
} // namespace
Expand Down Expand Up @@ -764,7 +768,7 @@ struct FuseAddressArithmetic
return success();
}
}
return success();
return failure();
}
};
} // namespace
Expand Down Expand Up @@ -967,7 +971,7 @@ struct FuseWithConstantArray
return success();
}
}
return success();
return failure();
}
};
} // namespace
Expand Down Expand Up @@ -1058,8 +1062,9 @@ struct ForwardStdvecInitData
Value cast = rewriter.create<cudaq::cc::CastOp>(
data.getLoc(), data.getType(), ini.getBuffer());
rewriter.replaceOp(data, cast);
return success();
}
return success();
return failure();
}
};
} // namespace
Expand All @@ -1085,8 +1090,9 @@ struct ForwardStdvecInitSize
Value cast = rewriter.create<cudaq::cc::CastOp>(
size.getLoc(), size.getType(), ini.getLength());
rewriter.replaceOp(size, cast);
return success();
}
return success();
return failure();
}
};
} // namespace
Expand Down Expand Up @@ -1507,9 +1513,9 @@ struct HoistLoopInvariantArgs : public OpRewritePattern<cudaq::cc::LoopOp> {
}
}
}
return success();
}

return success();
return failure();
}
};
} // namespace
Expand Down Expand Up @@ -1640,7 +1646,7 @@ struct EraseScopeWhenNotNeeded : public OpRewritePattern<cudaq::cc::ScopeOp> {
LogicalResult matchAndRewrite(cudaq::cc::ScopeOp scope,
PatternRewriter &rewriter) const override {
if (scope.hasAllocation())
return success();
return failure();

// scope does not allocate, so the region can be inlined into the parent.
auto loc = scope.getLoc();
Expand Down Expand Up @@ -2149,8 +2155,9 @@ struct ReplaceInLoop : public OpRewritePattern<FROM> {
rewriter.splitBlock(scopeBlock, scopePt);
rewriter.setInsertionPointToEnd(scopeBlock);
rewriter.replaceOpWithNewOp<WITH>(fromOp, fromOp.getOperands());
return success();
}
return success();
return failure();
}
};

Expand Down
6 changes: 4 additions & 2 deletions lib/Optimizer/Dialect/Quake/QuakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,13 @@ struct ForwardConcatExtractPattern
auto index = extract.getConstantIndex();
if (index < concatQubits.size()) {
auto qOpValue = concatQubits[index];
if (isa<quake::RefType>(qOpValue.getType()))
if (isa<quake::RefType>(qOpValue.getType())) {
rewriter.replaceOp(extract, {qOpValue});
return success();
}
}
}
return success();
return failure();
}
};

Expand Down

0 comments on commit b616ed6

Please sign in to comment.