23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
38 if (
auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<
bool>();
57 if (
auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(
indices.size() == 1 &&
"must have exactly one index for a vector");
62 if (
auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!
indices.empty() &&
"must have at least one index for an array");
72 bool div0 =
b.isZero();
73 bool overflow = a.isMinSignedValue() &&
b.isAllOnes();
75 return div0 || overflow;
83#include "SPIRVCanonicalization.inc"
94struct CombineChainedAccessChain final
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter)
const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 if (!parentAccessChainOp) {
108 SmallVector<Value, 4>
indices(parentAccessChainOp.getIndices());
109 llvm::append_range(
indices, accessChainOp.getIndices());
112 accessChainOp, parentAccessChainOp.getBasePtr(),
indices);
119void spirv::AccessChainOp::getCanonicalizationPatterns(
121 results.
add<CombineChainedAccessChain>(context);
128template <
typename Op>
132 static constexpr bool IsSub = std::is_same_v<Op, spirv::ISubBorrowOp>;
142 std::array<Value, 2> constituents =
157 [](
const APInt &a,
const APInt &
b) {
return IsSub ? a -
b : a +
b; });
162 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
163 bool wrapped =
IsSub ? a.ult(
b) : (a +
b).ult(a);
164 return APInt(a.getBitWidth(), wrapped ? 1 : 0);
170 op, op.getType(), rewriter.
getArrayAttr({lowBits, wrapBit}));
176void spirv::IAddCarryOp::getCanonicalizationPatterns(
182void spirv::ISubBorrowOp::getCanonicalizationPatterns(
191template <
typename MulOp,
bool IsSigned>
200 Type constituentType =
lhs.getType();
204 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
205 Value constituents[2] = {zero, zero};
227 [](
const APInt &a,
const APInt &
b) {
return a *
b; });
233 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
235 return llvm::APIntOps::mulhs(a,
b);
237 return llvm::APIntOps::mulhu(a,
b);
244 op, op.getType(), rewriter.
getArrayAttr({lowBits, highBits}));
250void spirv::SMulExtendedOp::getCanonicalizationPatterns(
263 Type constituentType =
lhs.getType();
267 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
268 Value constituents[2] = {
lhs, zero};
279void spirv::UMulExtendedOp::getCanonicalizationPatterns(
302 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
314 bool isApplicable =
false;
315 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
316 auto currInt = cast<IntegerAttr>(currValue);
317 if (currInt.getValue().isZero())
319 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
320 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
321 auto currVec = cast<DenseElementsAttr>(currValue);
322 if (llvm::any_of(currVec.getValues<APInt>(),
323 [](
const APInt &curr) { return curr.isZero(); }))
325 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
326 currVec.getValues<APInt>()),
327 [](
const auto &pair) {
328 auto &[prev, curr] = pair;
329 return prev.urem(curr) == 0;
339 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
355 Value curInput = getOperand();
360 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
361 Value prevInput = prevCast.getOperand();
365 getOperandMutable().assign(prevInput);
377OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
378 Value compositeOp = getComposite();
380 while (
auto insertOp =
383 return insertOp.getObject();
384 compositeOp = insertOp.getComposite();
387 if (
auto constructOp =
389 auto type = cast<spirv::CompositeType>(constructOp.getType());
391 constructOp.getConstituents().size() == type.getNumElements()) {
392 auto i = cast<IntegerAttr>(*
getIndices().begin());
393 if (i.getValue().getSExtValue() <
394 static_cast<int64_t>(constructOp.getConstituents().size()))
395 return constructOp.getConstituents()[i.getValue().getSExtValue()];
400 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
420 return getOperand1();
428 adaptor.getOperands(),
429 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
439 return getOperand2();
442 return getOperand1();
450 adaptor.getOperands(),
451 [](
const APInt &a,
const APInt &
b) { return a * b; });
460 if (getOperand1() == getOperand2())
469 adaptor.getOperands(),
470 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
480 return getOperand1();
490 bool div0OrOverflow =
false;
492 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
493 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
494 div0OrOverflow = true;
499 return div0OrOverflow ?
Attribute() : res;
521 bool div0OrOverflow =
false;
523 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
524 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
525 div0OrOverflow = true;
528 APInt c = a.abs().urem(
b.abs());
531 if (
b.isNegative()) {
532 APInt zero = APInt::getZero(c.getBitWidth());
533 return a.isNegative() ? (zero - c) : (b + c);
535 return a.isNegative() ? (
b - c) : c;
537 return div0OrOverflow ?
Attribute() : res;
559 bool div0OrOverflow =
false;
561 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
562 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
563 div0OrOverflow = true;
568 return div0OrOverflow ?
Attribute() : res;
578 return getOperand1();
588 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
589 if (div0 || b.isZero()) {
615 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
616 if (div0 || b.isZero()) {
629OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
631 auto op = getOperand();
632 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
633 return negateOp->getOperand(0);
639 adaptor.getOperands(), [](
const APInt &a) {
640 APInt zero = APInt::getZero(a.getBitWidth());
649OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
651 auto op = getOperand();
652 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
653 return notOp->getOperand(0);
668OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
669 if (std::optional<bool>
rhs =
673 return getOperand1();
677 return adaptor.getOperand2();
688spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
690 if (getOperand1() == getOperand2()) {
692 if (isa<IntegerType>(
getType()))
694 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
699 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
700 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
708OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
709 if (std::optional<bool>
rhs =
713 return getOperand1();
717 if (getOperand1() == getOperand2()) {
719 if (isa<IntegerType>(
getType()))
721 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
726 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
727 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
735OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
737 auto op = getOperand();
738 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
739 return notOp->getOperand(0);
746 APInt zero = APInt::getZero(1);
747 return a == 1 ? zero : (zero + 1);
751void spirv::LogicalNotOp::getCanonicalizationPatterns(
754 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
755 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
763OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
767 return adaptor.getOperand2();
772 return getOperand1();
783OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
785 Value trueVals = getTrueValue();
786 Value falseVals = getFalseValue();
787 if (trueVals == falseVals)
795 return *boolAttr ? trueVals : falseVals;
798 if (!operands[0] || !operands[1] || !operands[2])
804 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
805 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
806 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
807 if (!condAttrs || !trueAttrs || !falseAttrs)
810 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
811 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
813 for (
auto [
result, cond, falseRes] : iters) {
814 if (!cond.getValue())
818 auto resultType = trueAttrs.getType();
826OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
828 if (getOperand1() == getOperand2()) {
830 if (isa<IntegerType>(
getType()))
832 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
837 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
838 return a ==
b ? APInt::getAllOnes(1) : APInt::
getZero(1);
846OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
848 if (getOperand1() == getOperand2()) {
850 if (isa<IntegerType>(
getType()))
852 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
857 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
858 return a ==
b ? APInt::getZero(1) : APInt::getAllOnes(1);
867spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
869 if (getOperand1() == getOperand2()) {
871 if (isa<IntegerType>(
getType()))
873 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
878 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
879 return a.sgt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
888 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
890 if (getOperand1() == getOperand2()) {
892 if (isa<IntegerType>(
getType()))
894 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
899 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
900 return a.sge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
909spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
911 if (getOperand1() == getOperand2()) {
913 if (isa<IntegerType>(
getType()))
915 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
920 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
921 return a.ugt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
930 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
932 if (getOperand1() == getOperand2()) {
934 if (isa<IntegerType>(
getType()))
936 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
941 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
942 return a.uge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
950OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
952 if (getOperand1() == getOperand2()) {
954 if (isa<IntegerType>(
getType()))
956 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
961 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
962 return a.slt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
971spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
973 if (getOperand1() == getOperand2()) {
975 if (isa<IntegerType>(
getType()))
977 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
982 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
983 return a.sle(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
991OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
993 if (getOperand1() == getOperand2()) {
995 if (isa<IntegerType>(
getType()))
997 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1002 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1003 return a.ult(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1012spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1014 if (getOperand1() == getOperand2()) {
1016 if (isa<IntegerType>(
getType()))
1018 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1023 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1024 return a.ule(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1033 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1036 return getOperand1();
1047 bool shiftToLarge =
false;
1049 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1050 if (shiftToLarge || b.uge(a.getBitWidth())) {
1051 shiftToLarge = true;
1056 return shiftToLarge ?
Attribute() : res;
1064 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1067 return getOperand1();
1078 bool shiftToLarge =
false;
1080 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1081 if (shiftToLarge || b.uge(a.getBitWidth())) {
1082 shiftToLarge = true;
1087 return shiftToLarge ?
Attribute() : res;
1095 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1098 return getOperand1();
1109 bool shiftToLarge =
false;
1111 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1112 if (shiftToLarge || b.uge(a.getBitWidth())) {
1113 shiftToLarge = true;
1118 return shiftToLarge ?
Attribute() : res;
1126spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1128 if (getOperand1() == getOperand2()) {
1129 return getOperand1();
1135 if (rhsMask.isZero())
1136 return getOperand2();
1139 if (rhsMask.isAllOnes())
1140 return getOperand1();
1143 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1146 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1147 return getOperand1();
1157 adaptor.getOperands(),
1158 [](
const APInt &a,
const APInt &
b) { return a & b; });
1165OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1167 if (getOperand1() == getOperand2()) {
1168 return getOperand1();
1174 if (rhsMask.isZero())
1175 return getOperand1();
1178 if (rhsMask.isAllOnes())
1179 return getOperand2();
1188 adaptor.getOperands(),
1189 [](
const APInt &a,
const APInt &
b) { return a | b; });
1197spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1200 return getOperand1();
1204 if (getOperand1() == getOperand2())
1213 adaptor.getOperands(),
1214 [](
const APInt &a,
const APInt &
b) { return a ^ b; });
1247struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1250 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1251 PatternRewriter &rewriter)
const override {
1252 Operation *op = selectionOp.getOperation();
1261 if (llvm::range_size(body) != 4) {
1265 Block *headerBlock = selectionOp.getHeaderBlock();
1266 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1270 auto brConditionalOp =
1271 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1273 Block *trueBlock = brConditionalOp.getSuccessor(0);
1274 Block *falseBlock = brConditionalOp.getSuccessor(1);
1275 Block *mergeBlock = selectionOp.getMergeBlock();
1277 if (
failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1280 Value trueValue = getSrcValue(trueBlock);
1281 Value falseValue = getSrcValue(falseBlock);
1282 Value ptrValue = getDstPtr(trueBlock);
1283 auto storeOpAttributes =
1284 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1286 auto selectOp = spirv::SelectOp::create(
1287 rewriter, selectionOp.getLoc(), trueValue.
getType(),
1288 brConditionalOp.getCondition(), trueValue, falseValue);
1289 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1290 selectOp.getResult(), storeOpAttributes);
1304 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1305 Block *mergeBlock)
const;
1307 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1308 return llvm::hasSingleElement(*block) &&
1309 isa<spirv::BranchConditionalOp>(block->
front());
1312 bool isSameAttrList(spirv::StoreOp
lhs, spirv::StoreOp
rhs)
const {
1313 return lhs->getDiscardableAttrDictionary() ==
1314 rhs->getDiscardableAttrDictionary() &&
1315 lhs.getProperties() ==
rhs.getProperties();
1319 Value getSrcValue(
Block *block)
const {
1320 auto storeOp = cast<spirv::StoreOp>(block->
front());
1321 return storeOp.getValue();
1325 Value getDstPtr(
Block *block)
const {
1326 auto storeOp = cast<spirv::StoreOp>(block->
front());
1327 return storeOp.getPtr();
1331LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1334 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1338 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1339 auto trueBrBranchOp =
1340 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1341 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1342 auto falseBrBranchOp =
1343 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1345 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1355 bool isScalarOrVector =
1356 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1357 .isScalarOrVector();
1361 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1362 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1366 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1367 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1375void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1376 MLIRContext *context) {
1377 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
ArithmeticExtendedBinaryFold< spirv::ISubBorrowOp > ISubBorrowFold
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
ArithmeticExtendedBinaryFold< spirv::IAddCarryOp > IAddCarryFold
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override
static constexpr bool IsSub
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})