23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
38 if (
auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<
bool>();
57 if (
auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(
indices.size() == 1 &&
"must have exactly one index for a vector");
62 if (
auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!
indices.empty() &&
"must have at least one index for an array");
72 bool div0 =
b.isZero();
73 bool overflow = a.isMinSignedValue() &&
b.isAllOnes();
75 return div0 || overflow;
83#include "SPIRVCanonicalization.inc"
94struct CombineChainedAccessChain final
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter)
const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 if (!parentAccessChainOp) {
108 SmallVector<Value, 4>
indices(parentAccessChainOp.getIndices());
109 llvm::append_range(
indices, accessChainOp.getIndices());
112 accessChainOp, parentAccessChainOp.getBasePtr(),
indices);
119void spirv::AccessChainOp::getCanonicalizationPatterns(
121 results.
add<CombineChainedAccessChain>(context);
128template <
typename Op>
132 static constexpr bool IsSub = std::is_same_v<Op, spirv::ISubBorrowOp>;
142 std::array<Value, 2> constituents =
157 [](
const APInt &a,
const APInt &
b) {
return IsSub ? a -
b : a +
b; });
162 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
163 bool wrapped =
IsSub ? a.ult(
b) : (a +
b).ult(a);
164 return APInt(a.getBitWidth(), wrapped ? 1 : 0);
170 op, op.getType(), rewriter.
getArrayAttr({lowBits, wrapBit}));
176void spirv::IAddCarryOp::getCanonicalizationPatterns(
182void spirv::ISubBorrowOp::getCanonicalizationPatterns(
191template <
typename MulOp,
bool IsSigned>
200 Type constituentType =
lhs.getType();
204 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
205 Value constituents[2] = {zero, zero};
227 [](
const APInt &a,
const APInt &
b) {
return a *
b; });
233 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
235 return llvm::APIntOps::mulhs(a,
b);
237 return llvm::APIntOps::mulhu(a,
b);
244 op, op.getType(), rewriter.
getArrayAttr({lowBits, highBits}));
250void spirv::SMulExtendedOp::getCanonicalizationPatterns(
263 Type constituentType =
lhs.getType();
267 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
268 Value constituents[2] = {
lhs, zero};
279void spirv::UMulExtendedOp::getCanonicalizationPatterns(
302 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
314 bool isApplicable =
false;
315 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
316 auto currInt = cast<IntegerAttr>(currValue);
317 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
318 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
319 auto currVec = cast<DenseElementsAttr>(currValue);
320 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
321 currVec.getValues<APInt>()),
322 [](
const auto &pair) {
323 auto &[prev, curr] = pair;
324 return prev.urem(curr) == 0;
334 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
350 Value curInput = getOperand();
355 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
356 Value prevInput = prevCast.getOperand();
360 getOperandMutable().assign(prevInput);
372OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
373 Value compositeOp = getComposite();
375 while (
auto insertOp =
378 return insertOp.getObject();
379 compositeOp = insertOp.getComposite();
382 if (
auto constructOp =
384 auto type = cast<spirv::CompositeType>(constructOp.getType());
386 constructOp.getConstituents().size() == type.getNumElements()) {
387 auto i = cast<IntegerAttr>(*
getIndices().begin());
388 if (i.getValue().getSExtValue() <
389 static_cast<int64_t>(constructOp.getConstituents().size()))
390 return constructOp.getConstituents()[i.getValue().getSExtValue()];
395 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
415 return getOperand1();
423 adaptor.getOperands(),
424 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
434 return getOperand2();
437 return getOperand1();
445 adaptor.getOperands(),
446 [](
const APInt &a,
const APInt &
b) { return a * b; });
455 if (getOperand1() == getOperand2())
464 adaptor.getOperands(),
465 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
475 return getOperand1();
485 bool div0OrOverflow =
false;
487 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
488 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
489 div0OrOverflow = true;
494 return div0OrOverflow ?
Attribute() : res;
516 bool div0OrOverflow =
false;
518 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
523 APInt c = a.abs().urem(
b.abs());
526 if (
b.isNegative()) {
527 APInt zero = APInt::getZero(c.getBitWidth());
528 return a.isNegative() ? (zero - c) : (b + c);
530 return a.isNegative() ? (
b - c) : c;
532 return div0OrOverflow ?
Attribute() : res;
554 bool div0OrOverflow =
false;
556 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
557 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
558 div0OrOverflow = true;
563 return div0OrOverflow ?
Attribute() : res;
573 return getOperand1();
583 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
584 if (div0 || b.isZero()) {
610 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
611 if (div0 || b.isZero()) {
624OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
626 auto op = getOperand();
627 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
628 return negateOp->getOperand(0);
634 adaptor.getOperands(), [](
const APInt &a) {
635 APInt zero = APInt::getZero(a.getBitWidth());
644OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
646 auto op = getOperand();
647 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
648 return notOp->getOperand(0);
663OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
664 if (std::optional<bool>
rhs =
668 return getOperand1();
672 return adaptor.getOperand2();
683spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
685 if (getOperand1() == getOperand2()) {
687 if (isa<IntegerType>(
getType()))
689 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
694 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
695 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
703OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
704 if (std::optional<bool>
rhs =
708 return getOperand1();
712 if (getOperand1() == getOperand2()) {
714 if (isa<IntegerType>(
getType()))
716 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
721 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
722 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
730OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
732 auto op = getOperand();
733 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
734 return notOp->getOperand(0);
741 APInt zero = APInt::getZero(1);
742 return a == 1 ? zero : (zero + 1);
746void spirv::LogicalNotOp::getCanonicalizationPatterns(
749 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
750 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
758OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
762 return adaptor.getOperand2();
767 return getOperand1();
778OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
780 Value trueVals = getTrueValue();
781 Value falseVals = getFalseValue();
782 if (trueVals == falseVals)
790 return *boolAttr ? trueVals : falseVals;
793 if (!operands[0] || !operands[1] || !operands[2])
799 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
800 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
801 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
802 if (!condAttrs || !trueAttrs || !falseAttrs)
805 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
806 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
808 for (
auto [
result, cond, falseRes] : iters) {
809 if (!cond.getValue())
813 auto resultType = trueAttrs.getType();
821OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
823 if (getOperand1() == getOperand2()) {
825 if (isa<IntegerType>(
getType()))
827 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
832 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
833 return a ==
b ? APInt::getAllOnes(1) : APInt::
getZero(1);
841OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
843 if (getOperand1() == getOperand2()) {
845 if (isa<IntegerType>(
getType()))
847 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
852 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
853 return a ==
b ? APInt::getZero(1) : APInt::getAllOnes(1);
862spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
864 if (getOperand1() == getOperand2()) {
866 if (isa<IntegerType>(
getType()))
868 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
873 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
874 return a.sgt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
883 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
885 if (getOperand1() == getOperand2()) {
887 if (isa<IntegerType>(
getType()))
889 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
894 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
895 return a.sge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
904spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
906 if (getOperand1() == getOperand2()) {
908 if (isa<IntegerType>(
getType()))
910 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
915 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
916 return a.ugt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
925 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
927 if (getOperand1() == getOperand2()) {
929 if (isa<IntegerType>(
getType()))
931 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
936 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
937 return a.uge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
945OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
947 if (getOperand1() == getOperand2()) {
949 if (isa<IntegerType>(
getType()))
951 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
956 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
957 return a.slt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
966spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
968 if (getOperand1() == getOperand2()) {
970 if (isa<IntegerType>(
getType()))
972 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
977 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
978 return a.sle(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
986OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
988 if (getOperand1() == getOperand2()) {
990 if (isa<IntegerType>(
getType()))
992 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
997 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
998 return a.ult(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1007spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1009 if (getOperand1() == getOperand2()) {
1011 if (isa<IntegerType>(
getType()))
1013 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1018 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1019 return a.ule(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1028 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1031 return getOperand1();
1042 bool shiftToLarge =
false;
1044 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1045 if (shiftToLarge || b.uge(a.getBitWidth())) {
1046 shiftToLarge = true;
1051 return shiftToLarge ?
Attribute() : res;
1059 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1062 return getOperand1();
1073 bool shiftToLarge =
false;
1075 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1076 if (shiftToLarge || b.uge(a.getBitWidth())) {
1077 shiftToLarge = true;
1082 return shiftToLarge ?
Attribute() : res;
1090 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1093 return getOperand1();
1104 bool shiftToLarge =
false;
1106 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1107 if (shiftToLarge || b.uge(a.getBitWidth())) {
1108 shiftToLarge = true;
1113 return shiftToLarge ?
Attribute() : res;
1121spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1123 if (getOperand1() == getOperand2()) {
1124 return getOperand1();
1130 if (rhsMask.isZero())
1131 return getOperand2();
1134 if (rhsMask.isAllOnes())
1135 return getOperand1();
1138 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1141 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1142 return getOperand1();
1152 adaptor.getOperands(),
1153 [](
const APInt &a,
const APInt &
b) { return a & b; });
1160OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1162 if (getOperand1() == getOperand2()) {
1163 return getOperand1();
1169 if (rhsMask.isZero())
1170 return getOperand1();
1173 if (rhsMask.isAllOnes())
1174 return getOperand2();
1183 adaptor.getOperands(),
1184 [](
const APInt &a,
const APInt &
b) { return a | b; });
1192spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1195 return getOperand1();
1199 if (getOperand1() == getOperand2())
1208 adaptor.getOperands(),
1209 [](
const APInt &a,
const APInt &
b) { return a ^ b; });
1242struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1245 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1246 PatternRewriter &rewriter)
const override {
1247 Operation *op = selectionOp.getOperation();
1256 if (llvm::range_size(body) != 4) {
1260 Block *headerBlock = selectionOp.getHeaderBlock();
1261 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1265 auto brConditionalOp =
1266 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1268 Block *trueBlock = brConditionalOp.getSuccessor(0);
1269 Block *falseBlock = brConditionalOp.getSuccessor(1);
1270 Block *mergeBlock = selectionOp.getMergeBlock();
1272 if (
failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1275 Value trueValue = getSrcValue(trueBlock);
1276 Value falseValue = getSrcValue(falseBlock);
1277 Value ptrValue = getDstPtr(trueBlock);
1278 auto storeOpAttributes =
1279 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1281 auto selectOp = spirv::SelectOp::create(
1282 rewriter, selectionOp.getLoc(), trueValue.
getType(),
1283 brConditionalOp.getCondition(), trueValue, falseValue);
1284 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1285 selectOp.getResult(), storeOpAttributes);
1299 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1300 Block *mergeBlock)
const;
1302 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1303 return llvm::hasSingleElement(*block) &&
1304 isa<spirv::BranchConditionalOp>(block->
front());
1307 bool isSameAttrList(spirv::StoreOp
lhs, spirv::StoreOp
rhs)
const {
1308 return lhs->getDiscardableAttrDictionary() ==
1309 rhs->getDiscardableAttrDictionary() &&
1310 lhs.getProperties() ==
rhs.getProperties();
1314 Value getSrcValue(
Block *block)
const {
1315 auto storeOp = cast<spirv::StoreOp>(block->
front());
1316 return storeOp.getValue();
1320 Value getDstPtr(
Block *block)
const {
1321 auto storeOp = cast<spirv::StoreOp>(block->
front());
1322 return storeOp.getPtr();
1326LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1329 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1333 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1334 auto trueBrBranchOp =
1335 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1336 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1337 auto falseBrBranchOp =
1338 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1340 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1350 bool isScalarOrVector =
1351 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1352 .isScalarOrVector();
1356 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1357 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1361 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1362 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1370void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1371 MLIRContext *context) {
1372 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
ArithmeticExtendedBinaryFold< spirv::ISubBorrowOp > ISubBorrowFold
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
ArithmeticExtendedBinaryFold< spirv::IAddCarryOp > IAddCarryFold
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override
static constexpr bool IsSub
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})