23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
38 if (
auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<
bool>();
57 if (
auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
58 assert(
indices.size() == 1 &&
"must have exactly one index for a vector");
62 if (
auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
63 assert(!
indices.empty() &&
"must have at least one index for an array");
72 bool div0 =
b.isZero();
73 bool overflow = a.isMinSignedValue() &&
b.isAllOnes();
75 return div0 || overflow;
83#include "SPIRVCanonicalization.inc"
94struct CombineChainedAccessChain final
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter)
const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 if (!parentAccessChainOp) {
108 SmallVector<Value, 4>
indices(parentAccessChainOp.getIndices());
109 llvm::append_range(
indices, accessChainOp.getIndices());
112 accessChainOp, parentAccessChainOp.getBasePtr(),
indices);
119void spirv::AccessChainOp::getCanonicalizationPatterns(
121 results.
add<CombineChainedAccessChain>(context);
138 Type constituentType =
lhs.getType();
167 [](
const APInt &a,
const APInt &
b) {
return a +
b; });
172 ArrayRef{adds, lhsAttr}, [](
const APInt &a,
const APInt &
b) {
173 APInt zero = APInt::getZero(a.getBitWidth());
174 return a.ult(
b) ? (zero + 1) : zero;
181 spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
184 spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
187 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
190 spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
198void spirv::IAddCarryOp::getCanonicalizationPatterns(
209template <
typename MulOp,
bool IsSigned>
218 Type constituentType =
lhs.getType();
222 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
223 Value constituents[2] = {zero, zero};
245 [](
const APInt &a,
const APInt &
b) {
return a *
b; });
251 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &
b) {
253 return llvm::APIntOps::mulhs(a,
b);
255 return llvm::APIntOps::mulhu(a,
b);
263 spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
266 spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
269 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
272 spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
281void spirv::SMulExtendedOp::getCanonicalizationPatterns(
294 Type constituentType =
lhs.getType();
298 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
299 Value constituents[2] = {
lhs, zero};
310void spirv::UMulExtendedOp::getCanonicalizationPatterns(
333 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
345 bool isApplicable =
false;
346 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
347 auto currInt = cast<IntegerAttr>(currValue);
348 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
349 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
350 auto currVec = cast<DenseElementsAttr>(currValue);
351 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
352 currVec.getValues<APInt>()),
353 [](
const auto &pair) {
354 auto &[prev, curr] = pair;
355 return prev.urem(curr) == 0;
365 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
381 Value curInput = getOperand();
386 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
387 Value prevInput = prevCast.getOperand();
391 getOperandMutable().assign(prevInput);
403OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
404 Value compositeOp = getComposite();
406 while (
auto insertOp =
409 return insertOp.getObject();
410 compositeOp = insertOp.getComposite();
413 if (
auto constructOp =
415 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
417 constructOp.getConstituents().size() == type.getNumElements()) {
418 auto i = llvm::cast<IntegerAttr>(*
getIndices().begin());
419 if (i.getValue().getSExtValue() <
420 static_cast<int64_t>(constructOp.getConstituents().size()))
421 return constructOp.getConstituents()[i.getValue().getSExtValue()];
426 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
446 return getOperand1();
454 adaptor.getOperands(),
455 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
465 return getOperand2();
468 return getOperand1();
476 adaptor.getOperands(),
477 [](
const APInt &a,
const APInt &
b) { return a * b; });
486 if (getOperand1() == getOperand2())
495 adaptor.getOperands(),
496 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
506 return getOperand1();
516 bool div0OrOverflow =
false;
518 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
525 return div0OrOverflow ?
Attribute() : res;
547 bool div0OrOverflow =
false;
549 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
550 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
551 div0OrOverflow = true;
554 APInt c = a.abs().urem(
b.abs());
557 if (
b.isNegative()) {
558 APInt zero = APInt::getZero(c.getBitWidth());
559 return a.isNegative() ? (zero - c) : (b + c);
561 return a.isNegative() ? (
b - c) : c;
563 return div0OrOverflow ?
Attribute() : res;
585 bool div0OrOverflow =
false;
587 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
588 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
589 div0OrOverflow = true;
594 return div0OrOverflow ?
Attribute() : res;
604 return getOperand1();
614 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
615 if (div0 || b.isZero()) {
641 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
642 if (div0 || b.isZero()) {
655OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
657 auto op = getOperand();
658 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
659 return negateOp->getOperand(0);
665 adaptor.getOperands(), [](
const APInt &a) {
666 APInt zero = APInt::getZero(a.getBitWidth());
675OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
677 auto op = getOperand();
678 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
679 return notOp->getOperand(0);
694OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
695 if (std::optional<bool>
rhs =
699 return getOperand1();
703 return adaptor.getOperand2();
714spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
716 if (getOperand1() == getOperand2()) {
718 if (isa<IntegerType>(
getType()))
720 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
725 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
726 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
734OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
735 if (std::optional<bool>
rhs =
739 return getOperand1();
743 if (getOperand1() == getOperand2()) {
745 if (isa<IntegerType>(
getType()))
747 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
752 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
753 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
761OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
763 auto op = getOperand();
764 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
765 return notOp->getOperand(0);
772 APInt zero = APInt::getZero(1);
773 return a == 1 ? zero : (zero + 1);
777void spirv::LogicalNotOp::getCanonicalizationPatterns(
780 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
781 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
789OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
793 return adaptor.getOperand2();
798 return getOperand1();
809OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
811 Value trueVals = getTrueValue();
812 Value falseVals = getFalseValue();
813 if (trueVals == falseVals)
821 return *boolAttr ? trueVals : falseVals;
824 if (!operands[0] || !operands[1] || !operands[2])
830 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
831 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
832 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
833 if (!condAttrs || !trueAttrs || !falseAttrs)
836 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
837 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
839 for (
auto [
result, cond, falseRes] : iters) {
840 if (!cond.getValue())
844 auto resultType = trueAttrs.getType();
852OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
854 if (getOperand1() == getOperand2()) {
856 if (isa<IntegerType>(
getType()))
858 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
863 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
864 return a ==
b ? APInt::getAllOnes(1) : APInt::
getZero(1);
872OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
874 if (getOperand1() == getOperand2()) {
876 if (isa<IntegerType>(
getType()))
878 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
883 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
884 return a ==
b ? APInt::getZero(1) : APInt::getAllOnes(1);
893spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
895 if (getOperand1() == getOperand2()) {
897 if (isa<IntegerType>(
getType()))
899 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
904 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
905 return a.sgt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
914 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
916 if (getOperand1() == getOperand2()) {
918 if (isa<IntegerType>(
getType()))
920 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
925 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
926 return a.sge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
935spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
937 if (getOperand1() == getOperand2()) {
939 if (isa<IntegerType>(
getType()))
941 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
946 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
947 return a.ugt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
956 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
958 if (getOperand1() == getOperand2()) {
960 if (isa<IntegerType>(
getType()))
962 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
967 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
968 return a.uge(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
976OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
978 if (getOperand1() == getOperand2()) {
980 if (isa<IntegerType>(
getType()))
982 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
987 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
988 return a.slt(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
997spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
999 if (getOperand1() == getOperand2()) {
1001 if (isa<IntegerType>(
getType()))
1003 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1008 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1009 return a.sle(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1017OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1019 if (getOperand1() == getOperand2()) {
1021 if (isa<IntegerType>(
getType()))
1023 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1028 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1029 return a.ult(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1038spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1040 if (getOperand1() == getOperand2()) {
1042 if (isa<IntegerType>(
getType()))
1044 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1049 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &
b) {
1050 return a.ule(
b) ? APInt::getAllOnes(1) : APInt::
getZero(1);
1059 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1062 return getOperand1();
1073 bool shiftToLarge =
false;
1075 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1076 if (shiftToLarge || b.uge(a.getBitWidth())) {
1077 shiftToLarge = true;
1082 return shiftToLarge ?
Attribute() : res;
1090 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1093 return getOperand1();
1104 bool shiftToLarge =
false;
1106 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1107 if (shiftToLarge || b.uge(a.getBitWidth())) {
1108 shiftToLarge = true;
1113 return shiftToLarge ?
Attribute() : res;
1121 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1124 return getOperand1();
1135 bool shiftToLarge =
false;
1137 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
1138 if (shiftToLarge || b.uge(a.getBitWidth())) {
1139 shiftToLarge = true;
1144 return shiftToLarge ?
Attribute() : res;
1152spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1154 if (getOperand1() == getOperand2()) {
1155 return getOperand1();
1161 if (rhsMask.isZero())
1162 return getOperand2();
1165 if (rhsMask.isAllOnes())
1166 return getOperand1();
1169 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1172 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1173 return getOperand1();
1183 adaptor.getOperands(),
1184 [](
const APInt &a,
const APInt &
b) { return a & b; });
1191OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1193 if (getOperand1() == getOperand2()) {
1194 return getOperand1();
1200 if (rhsMask.isZero())
1201 return getOperand1();
1204 if (rhsMask.isAllOnes())
1205 return getOperand2();
1214 adaptor.getOperands(),
1215 [](
const APInt &a,
const APInt &
b) { return a | b; });
1223spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1226 return getOperand1();
1230 if (getOperand1() == getOperand2())
1239 adaptor.getOperands(),
1240 [](
const APInt &a,
const APInt &
b) { return a ^ b; });
1273struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1276 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1277 PatternRewriter &rewriter)
const override {
1278 Operation *op = selectionOp.getOperation();
1287 if (llvm::range_size(body) != 4) {
1291 Block *headerBlock = selectionOp.getHeaderBlock();
1292 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1296 auto brConditionalOp =
1297 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1299 Block *trueBlock = brConditionalOp.getSuccessor(0);
1300 Block *falseBlock = brConditionalOp.getSuccessor(1);
1301 Block *mergeBlock = selectionOp.getMergeBlock();
1303 if (
failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1306 Value trueValue = getSrcValue(trueBlock);
1307 Value falseValue = getSrcValue(falseBlock);
1308 Value ptrValue = getDstPtr(trueBlock);
1309 auto storeOpAttributes =
1310 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1312 auto selectOp = spirv::SelectOp::create(
1313 rewriter, selectionOp.getLoc(), trueValue.
getType(),
1314 brConditionalOp.getCondition(), trueValue, falseValue);
1315 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1316 selectOp.getResult(), storeOpAttributes);
1330 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1331 Block *mergeBlock)
const;
1333 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1334 return llvm::hasSingleElement(*block) &&
1335 isa<spirv::BranchConditionalOp>(block->
front());
1338 bool isSameAttrList(spirv::StoreOp
lhs, spirv::StoreOp
rhs)
const {
1339 return lhs->getDiscardableAttrDictionary() ==
1340 rhs->getDiscardableAttrDictionary() &&
1341 lhs.getProperties() ==
rhs.getProperties();
1345 Value getSrcValue(
Block *block)
const {
1346 auto storeOp = cast<spirv::StoreOp>(block->
front());
1347 return storeOp.getValue();
1351 Value getDstPtr(
Block *block)
const {
1352 auto storeOp = cast<spirv::StoreOp>(block->
front());
1353 return storeOp.getPtr();
1357LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1360 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1364 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1365 auto trueBrBranchOp =
1366 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1367 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1368 auto falseBrBranchOp =
1369 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1371 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1381 bool isScalarOrVector =
1382 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1383 .isScalarOrVector();
1387 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1388 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1392 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1393 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1401void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1402 MLIRContext *context) {
1403 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})