23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
38 if (
auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<
bool>();
57 if (
auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(
indices.size() == 1 &&
"must have exactly one index for a vector");
62 if (
auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!
indices.empty() &&
"must have at least one index for an array");
72 bool div0 =
b.isZero();
73 bool overflow = a.isMinSignedValue() &&
b.isAllOnes();
75 return div0 || overflow;
83#include "SPIRVCanonicalization.inc"
94struct CombineChainedAccessChain final
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter)
const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 if (!parentAccessChainOp) {
108 SmallVector<Value, 4>
indices(parentAccessChainOp.getIndices());
109 llvm::append_range(
indices, accessChainOp.getIndices());
112 accessChainOp, parentAccessChainOp.getBasePtr(),
indices);
119void spirv::AccessChainOp::getCanonicalizationPatterns(
121 results.
add<CombineChainedAccessChain>(context);
138 Type constituentType =
lhs.getType();
167 [](
const APInt &a,
const APInt &
b) {
return a +
b; });
172 ArrayRef{adds, lhsAttr}, [](
const APInt &a,
const APInt &
b) {
173 APInt zero = APInt::getZero(a.getBitWidth());
174 return a.ult(
b) ? (zero + 1) : zero;
181 spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
184 spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
187 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
190 spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
198void spirv::IAddCarryOp::getCanonicalizationPatterns(
209template <
typename MulOp,
bool IsSigned>
218 Type constituentType =
lhs.getType();
222 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
223 Value constituents[2] = {zero, zero};
245 [](
const APInt &a,
const APInt &
b) {
return a *
b; });
251 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
253 return llvm::APIntOps::mulhs(a,
b);
255 return llvm::APIntOps::mulhu(a,
b);
262 spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
265 spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
268 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
271 spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
280void spirv::SMulExtendedOp::getCanonicalizationPatterns(
293 Type constituentType =
lhs.getType();
297 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
298 Value constituents[2] = {
lhs, zero};
309void spirv::UMulExtendedOp::getCanonicalizationPatterns(
332 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
344 bool isApplicable =
false;
345 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
346 auto currInt = cast<IntegerAttr>(currValue);
347 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
348 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
349 auto currVec = cast<DenseElementsAttr>(currValue);
350 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
351 currVec.getValues<APInt>()),
352 [](
const auto &pair) {
353 auto &[prev, curr] = pair;
354 return prev.urem(curr) == 0;
364 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
380 Value curInput = getOperand();
385 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
386 Value prevInput = prevCast.getOperand();
390 getOperandMutable().assign(prevInput);
402OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
403 Value compositeOp = getComposite();
405 while (
auto insertOp =
408 return insertOp.getObject();
409 compositeOp = insertOp.getComposite();
412 if (
auto constructOp =
414 auto type = cast<spirv::CompositeType>(constructOp.getType());
416 constructOp.getConstituents().size() == type.getNumElements()) {
417 auto i = cast<IntegerAttr>(*
getIndices().begin());
418 if (i.getValue().getSExtValue() <
419 static_cast<int64_t>(constructOp.getConstituents().size()))
420 return constructOp.getConstituents()[i.getValue().getSExtValue()];
425 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
445 return getOperand1();
453 adaptor.getOperands(),
454 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
464 return getOperand2();
467 return getOperand1();
475 adaptor.getOperands(),
476 [](
const APInt &a,
const APInt &
b) { return a * b; });
485 if (getOperand1() == getOperand2())
494 adaptor.getOperands(),
495 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
505 return getOperand1();
515 bool div0OrOverflow =
false;
517 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
518 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
519 div0OrOverflow = true;
524 return div0OrOverflow ?
Attribute() : res;
546 bool div0OrOverflow =
false;
548 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
549 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
550 div0OrOverflow = true;
553 APInt c = a.abs().urem(
b.abs());
556 if (
b.isNegative()) {
557 APInt zero = APInt::getZero(c.getBitWidth());
558 return a.isNegative() ? (zero - c) : (b + c);
560 return a.isNegative() ? (
b - c) : c;
562 return div0OrOverflow ?
Attribute() : res;
584 bool div0OrOverflow =
false;
586 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
587 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
588 div0OrOverflow = true;
593 return div0OrOverflow ?
Attribute() : res;
603 return getOperand1();
613 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
614 if (div0 || b.isZero()) {
640 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
641 if (div0 || b.isZero()) {
654OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
656 auto op = getOperand();
657 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
658 return negateOp->getOperand(0);
664 adaptor.getOperands(), [](
const APInt &a) {
665 APInt zero = APInt::getZero(a.getBitWidth());
674OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
676 auto op = getOperand();
677 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
678 return notOp->getOperand(0);
693OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
694 if (std::optional<bool>
rhs =
698 return getOperand1();
702 return adaptor.getOperand2();
713spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
715 if (getOperand1() == getOperand2()) {
717 if (isa<IntegerType>(
getType()))
719 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
724 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
725 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
733OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
734 if (std::optional<bool>
rhs =
738 return getOperand1();
742 if (getOperand1() == getOperand2()) {
744 if (isa<IntegerType>(
getType()))
746 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
751 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
752 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
760OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
762 auto op = getOperand();
763 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
764 return notOp->getOperand(0);
771 APInt zero = APInt::getZero(1);
772 return a == 1 ? zero : (zero + 1);
776void spirv::LogicalNotOp::getCanonicalizationPatterns(
779 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
780 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
788OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
792 return adaptor.getOperand2();
797 return getOperand1();
808OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
810 Value trueVals = getTrueValue();
811 Value falseVals = getFalseValue();
812 if (trueVals == falseVals)
820 return *boolAttr ? trueVals : falseVals;
823 if (!operands[0] || !operands[1] || !operands[2])
829 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
830 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
831 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
832 if (!condAttrs || !trueAttrs || !falseAttrs)
835 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
836 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
838 for (
auto [
result, cond, falseRes] : iters) {
839 if (!cond.getValue())
843 auto resultType = trueAttrs.getType();
851OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
853 if (getOperand1() == getOperand2()) {
855 if (isa<IntegerType>(
getType()))
857 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
862 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
863 return a ==
b ? APInt::getAllOnes(1) : APInt::
getZero(1);
871OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
873 if (getOperand1() == getOperand2()) {
875 if (isa<IntegerType>(
getType()))
877 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
882 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
883 return a ==
b ? APInt::getZero(1) : APInt::getAllOnes(1);
892spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
894 if (getOperand1() == getOperand2()) {
896 if (isa<IntegerType>(
getType()))
898 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
903 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
904 return a.sgt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
913 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
915 if (getOperand1() == getOperand2()) {
917 if (isa<IntegerType>(
getType()))
919 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
924 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
925 return a.sge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
934spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
936 if (getOperand1() == getOperand2()) {
938 if (isa<IntegerType>(
getType()))
940 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
945 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
946 return a.ugt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
955 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
957 if (getOperand1() == getOperand2()) {
959 if (isa<IntegerType>(
getType()))
961 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
966 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
967 return a.uge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
975OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
977 if (getOperand1() == getOperand2()) {
979 if (isa<IntegerType>(
getType()))
981 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
986 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
987 return a.slt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
996spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
998 if (getOperand1() == getOperand2()) {
1000 if (isa<IntegerType>(
getType()))
1002 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1007 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1008 return a.sle(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1016OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1018 if (getOperand1() == getOperand2()) {
1020 if (isa<IntegerType>(
getType()))
1022 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1027 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1028 return a.ult(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1037spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1039 if (getOperand1() == getOperand2()) {
1041 if (isa<IntegerType>(
getType()))
1043 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1048 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1049 return a.ule(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1058 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1061 return getOperand1();
1072 bool shiftToLarge =
false;
1074 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1075 if (shiftToLarge || b.uge(a.getBitWidth())) {
1076 shiftToLarge = true;
1081 return shiftToLarge ?
Attribute() : res;
1089 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1092 return getOperand1();
1103 bool shiftToLarge =
false;
1105 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1106 if (shiftToLarge || b.uge(a.getBitWidth())) {
1107 shiftToLarge = true;
1112 return shiftToLarge ?
Attribute() : res;
1120 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1123 return getOperand1();
1134 bool shiftToLarge =
false;
1136 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1137 if (shiftToLarge || b.uge(a.getBitWidth())) {
1138 shiftToLarge = true;
1143 return shiftToLarge ?
Attribute() : res;
1151spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1153 if (getOperand1() == getOperand2()) {
1154 return getOperand1();
1160 if (rhsMask.isZero())
1161 return getOperand2();
1164 if (rhsMask.isAllOnes())
1165 return getOperand1();
1168 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1171 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1172 return getOperand1();
1182 adaptor.getOperands(),
1183 [](
const APInt &a,
const APInt &
b) { return a & b; });
1190OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1192 if (getOperand1() == getOperand2()) {
1193 return getOperand1();
1199 if (rhsMask.isZero())
1200 return getOperand1();
1203 if (rhsMask.isAllOnes())
1204 return getOperand2();
1213 adaptor.getOperands(),
1214 [](
const APInt &a,
const APInt &
b) { return a | b; });
1222spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1225 return getOperand1();
1229 if (getOperand1() == getOperand2())
1238 adaptor.getOperands(),
1239 [](
const APInt &a,
const APInt &
b) { return a ^ b; });
1272struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1275 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1276 PatternRewriter &rewriter)
const override {
1277 Operation *op = selectionOp.getOperation();
1286 if (llvm::range_size(body) != 4) {
1290 Block *headerBlock = selectionOp.getHeaderBlock();
1291 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1295 auto brConditionalOp =
1296 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1298 Block *trueBlock = brConditionalOp.getSuccessor(0);
1299 Block *falseBlock = brConditionalOp.getSuccessor(1);
1300 Block *mergeBlock = selectionOp.getMergeBlock();
1302 if (
failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1305 Value trueValue = getSrcValue(trueBlock);
1306 Value falseValue = getSrcValue(falseBlock);
1307 Value ptrValue = getDstPtr(trueBlock);
1308 auto storeOpAttributes =
1309 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1311 auto selectOp = spirv::SelectOp::create(
1312 rewriter, selectionOp.getLoc(), trueValue.
getType(),
1313 brConditionalOp.getCondition(), trueValue, falseValue);
1314 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1315 selectOp.getResult(), storeOpAttributes);
1329 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1330 Block *mergeBlock)
const;
1332 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1333 return llvm::hasSingleElement(*block) &&
1334 isa<spirv::BranchConditionalOp>(block->
front());
1337 bool isSameAttrList(spirv::StoreOp
lhs, spirv::StoreOp
rhs)
const {
1338 return lhs->getDiscardableAttrDictionary() ==
1339 rhs->getDiscardableAttrDictionary() &&
1340 lhs.getProperties() ==
rhs.getProperties();
1344 Value getSrcValue(
Block *block)
const {
1345 auto storeOp = cast<spirv::StoreOp>(block->
front());
1346 return storeOp.getValue();
1350 Value getDstPtr(
Block *block)
const {
1351 auto storeOp = cast<spirv::StoreOp>(block->
front());
1352 return storeOp.getPtr();
1356LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1359 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1363 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1364 auto trueBrBranchOp =
1365 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1366 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1367 auto falseBrBranchOp =
1368 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1370 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1380 bool isScalarOrVector =
1381 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1382 .isScalarOrVector();
1386 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1387 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1391 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1392 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1400void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1401 MLIRContext *context) {
1402 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})