24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
39 if (
auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40 return boolAttr.getValue();
41 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42 if (splatAttr.getElementType().isInteger(1))
43 return splatAttr.getSplatValue<
bool>();
58 if (
auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59 assert(indices.size() == 1 &&
"must have exactly one index for a vector");
60 return vector.getValues<
Attribute>()[indices[0]];
63 if (
auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64 assert(!indices.empty() &&
"must have at least one index for an array");
66 indices.drop_front());
73 bool div0 = b.isZero();
74 bool overflow = a.isMinSignedValue() && b.isAllOnes();
76 return div0 || overflow;
84 #include "SPIRVCanonicalization.inc"
95 struct CombineChainedAccessChain final
99 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
101 auto parentAccessChainOp =
102 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
104 if (!parentAccessChainOp) {
110 llvm::append_range(indices, accessChainOp.getIndices());
113 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
120 void spirv::AccessChainOp::getCanonicalizationPatterns(
122 results.
add<CombineChainedAccessChain>(context);
137 Value lhs = op.getOperand1();
138 Value rhs = op.getOperand2();
143 Value constituents[2] = {rhs, lhs};
166 auto adds = constFoldBinaryOp<IntegerAttr>(
168 [](
const APInt &a,
const APInt &b) {
return a + b; });
172 auto carrys = constFoldBinaryOp<IntegerAttr>(
173 ArrayRef{adds, lhsAttr}, [](
const APInt &a,
const APInt &b) {
175 return a.ult(b) ? (zero + 1) : zero;
182 rewriter.
create<spirv::ConstantOp>(loc, constituentType, adds);
185 rewriter.
create<spirv::ConstantOp>(loc, constituentType, carrys);
188 Value undef = rewriter.
create<spirv::UndefOp>(loc, op.getType());
191 rewriter.
create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
210 template <
typename MulOp,
bool IsSigned>
217 Value lhs = op.getOperand1();
218 Value rhs = op.getOperand2();
224 Value constituents[2] = {zero, zero};
244 auto lowBits = constFoldBinaryOp<IntegerAttr>(
246 [](
const APInt &a,
const APInt &b) {
return a * b; });
251 auto highBits = constFoldBinaryOp<IntegerAttr>(
252 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &b) {
254 return llvm::APIntOps::mulhs(a, b);
256 return llvm::APIntOps::mulhu(a, b);
264 rewriter.
create<spirv::ConstantOp>(loc, constituentType, lowBits);
267 rewriter.
create<spirv::ConstantOp>(loc, constituentType, highBits);
270 Value undef = rewriter.
create<spirv::UndefOp>(loc, op.getType());
273 rewriter.
create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
293 Value lhs = op.getOperand1();
294 Value rhs = op.getOperand2();
300 Value constituents[2] = {lhs, zero};
311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
334 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
346 bool isApplicable =
false;
347 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
348 auto currInt = cast<IntegerAttr>(currValue);
349 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
350 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
351 auto currVec = cast<DenseElementsAttr>(currValue);
352 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
353 currVec.getValues<APInt>()),
354 [](
const auto &pair) {
355 auto &[prev, curr] = pair;
356 return prev.urem(curr) == 0;
366 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
382 Value curInput = getOperand();
387 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
388 Value prevInput = prevCast.getOperand();
392 getOperandMutable().assign(prevInput);
404 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
405 Value compositeOp = getComposite();
407 while (
auto insertOp =
410 return insertOp.getObject();
411 compositeOp = insertOp.getComposite();
414 if (
auto constructOp =
416 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
418 constructOp.getConstituents().size() == type.getNumElements()) {
419 auto i = llvm::cast<IntegerAttr>(*
getIndices().begin());
420 if (i.getValue().getSExtValue() <
421 static_cast<int64_t
>(constructOp.getConstituents().size()))
422 return constructOp.getConstituents()[i.getValue().getSExtValue()];
427 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
447 return getOperand1();
454 return constFoldBinaryOp<IntegerAttr>(
455 adaptor.getOperands(),
456 [](APInt a,
const APInt &b) { return std::move(a) + b; });
466 return getOperand2();
469 return getOperand1();
476 return constFoldBinaryOp<IntegerAttr>(
477 adaptor.getOperands(),
478 [](
const APInt &a,
const APInt &b) { return a * b; });
487 if (getOperand1() == getOperand2())
495 return constFoldBinaryOp<IntegerAttr>(
496 adaptor.getOperands(),
497 [](APInt a,
const APInt &b) { return std::move(a) - b; });
507 return getOperand1();
517 bool div0OrOverflow =
false;
518 auto res = constFoldBinaryOp<IntegerAttr>(
519 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
520 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
521 div0OrOverflow = true;
526 return div0OrOverflow ?
Attribute() : res;
548 bool div0OrOverflow =
false;
549 auto res = constFoldBinaryOp<IntegerAttr>(
550 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
551 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
552 div0OrOverflow = true;
555 APInt c = a.abs().urem(b.abs());
558 if (b.isNegative()) {
559 APInt zero = APInt::getZero(c.getBitWidth());
560 return a.isNegative() ? (zero - c) : (b + c);
562 return a.isNegative() ? (b - c) : c;
564 return div0OrOverflow ?
Attribute() : res;
586 bool div0OrOverflow =
false;
587 auto res = constFoldBinaryOp<IntegerAttr>(
588 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
589 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
590 div0OrOverflow = true;
595 return div0OrOverflow ?
Attribute() : res;
605 return getOperand1();
614 auto res = constFoldBinaryOp<IntegerAttr>(
615 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
616 if (div0 || b.isZero()) {
641 auto res = constFoldBinaryOp<IntegerAttr>(
642 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
643 if (div0 || b.isZero()) {
656 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
658 auto op = getOperand();
659 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
660 return negateOp->getOperand(0);
665 return constFoldUnaryOp<IntegerAttr>(
666 adaptor.getOperands(), [](
const APInt &a) {
667 APInt zero = APInt::getZero(a.getBitWidth());
676 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
678 auto op = getOperand();
679 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
680 return notOp->getOperand(0);
685 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
695 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
696 if (std::optional<bool> rhs =
700 return getOperand1();
704 return adaptor.getOperand2();
715 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
717 if (getOperand1() == getOperand2()) {
719 if (isa<IntegerType>(
getType()))
721 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
725 return constFoldBinaryOp<IntegerAttr>(
726 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
727 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
735 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
736 if (std::optional<bool> rhs =
740 return getOperand1();
744 if (getOperand1() == getOperand2()) {
746 if (isa<IntegerType>(
getType()))
748 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
752 return constFoldBinaryOp<IntegerAttr>(
753 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
754 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
762 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
764 auto op = getOperand();
765 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
766 return notOp->getOperand(0);
771 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
773 APInt zero = APInt::getZero(1);
774 return a == 1 ? zero : (zero + 1);
778 void spirv::LogicalNotOp::getCanonicalizationPatterns(
781 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
782 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
790 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
794 return adaptor.getOperand2();
799 return getOperand1();
810 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
812 Value trueVals = getTrueValue();
813 Value falseVals = getFalseValue();
814 if (trueVals == falseVals)
822 return *boolAttr ? trueVals : falseVals;
825 if (!operands[0] || !operands[1] || !operands[2])
831 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
832 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
833 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
834 if (!condAttrs || !trueAttrs || !falseAttrs)
837 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
838 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
840 for (
auto [result, cond, falseRes] : iters) {
841 if (!cond.getValue())
845 auto resultType = trueAttrs.getType();
853 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
855 if (getOperand1() == getOperand2()) {
857 if (isa<IntegerType>(
getType()))
859 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
863 return constFoldBinaryOp<IntegerAttr>(
864 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
865 return a == b ? APInt::getAllOnes(1) : APInt::
getZero(1);
873 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
875 if (getOperand1() == getOperand2()) {
877 if (isa<IntegerType>(
getType()))
879 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
883 return constFoldBinaryOp<IntegerAttr>(
884 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
894 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
896 if (getOperand1() == getOperand2()) {
898 if (isa<IntegerType>(
getType()))
900 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
904 return constFoldBinaryOp<IntegerAttr>(
905 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
915 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
917 if (getOperand1() == getOperand2()) {
919 if (isa<IntegerType>(
getType()))
921 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
925 return constFoldBinaryOp<IntegerAttr>(
926 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
936 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
938 if (getOperand1() == getOperand2()) {
940 if (isa<IntegerType>(
getType()))
942 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
946 return constFoldBinaryOp<IntegerAttr>(
947 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
957 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
959 if (getOperand1() == getOperand2()) {
961 if (isa<IntegerType>(
getType()))
963 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
967 return constFoldBinaryOp<IntegerAttr>(
968 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
977 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
979 if (getOperand1() == getOperand2()) {
981 if (isa<IntegerType>(
getType()))
983 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
987 return constFoldBinaryOp<IntegerAttr>(
988 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
998 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
1000 if (getOperand1() == getOperand2()) {
1002 if (isa<IntegerType>(
getType()))
1004 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1008 return constFoldBinaryOp<IntegerAttr>(
1009 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1018 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1020 if (getOperand1() == getOperand2()) {
1022 if (isa<IntegerType>(
getType()))
1024 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1028 return constFoldBinaryOp<IntegerAttr>(
1029 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1039 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1041 if (getOperand1() == getOperand2()) {
1043 if (isa<IntegerType>(
getType()))
1045 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1049 return constFoldBinaryOp<IntegerAttr>(
1050 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1060 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1063 return getOperand1();
1074 bool shiftToLarge =
false;
1075 auto res = constFoldBinaryOp<IntegerAttr>(
1076 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1077 if (shiftToLarge || b.uge(a.getBitWidth())) {
1078 shiftToLarge = true;
1083 return shiftToLarge ?
Attribute() : res;
1091 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1094 return getOperand1();
1105 bool shiftToLarge =
false;
1106 auto res = constFoldBinaryOp<IntegerAttr>(
1107 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1108 if (shiftToLarge || b.uge(a.getBitWidth())) {
1109 shiftToLarge = true;
1114 return shiftToLarge ?
Attribute() : res;
1122 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1125 return getOperand1();
1136 bool shiftToLarge =
false;
1137 auto res = constFoldBinaryOp<IntegerAttr>(
1138 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1139 if (shiftToLarge || b.uge(a.getBitWidth())) {
1140 shiftToLarge = true;
1145 return shiftToLarge ?
Attribute() : res;
1153 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1155 if (getOperand1() == getOperand2()) {
1156 return getOperand1();
1162 if (rhsMask.isZero())
1163 return getOperand2();
1166 if (rhsMask.isAllOnes())
1167 return getOperand1();
1170 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1173 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1174 return getOperand1();
1183 return constFoldBinaryOp<IntegerAttr>(
1184 adaptor.getOperands(),
1185 [](
const APInt &a,
const APInt &b) { return a & b; });
1192 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1194 if (getOperand1() == getOperand2()) {
1195 return getOperand1();
1201 if (rhsMask.isZero())
1202 return getOperand1();
1205 if (rhsMask.isAllOnes())
1206 return getOperand2();
1214 return constFoldBinaryOp<IntegerAttr>(
1215 adaptor.getOperands(),
1216 [](
const APInt &a,
const APInt &b) { return a | b; });
1224 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1227 return getOperand1();
1231 if (getOperand1() == getOperand2())
1239 return constFoldBinaryOp<IntegerAttr>(
1240 adaptor.getOperands(),
1241 [](
const APInt &a,
const APInt &b) { return a ^ b; });
1274 struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1279 Operation *op = selectionOp.getOperation();
1288 if (llvm::range_size(body) != 4) {
1292 Block *headerBlock = selectionOp.getHeaderBlock();
1293 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1297 auto brConditionalOp =
1298 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1302 Block *mergeBlock = selectionOp.getMergeBlock();
1304 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1307 Value trueValue = getSrcValue(trueBlock);
1308 Value falseValue = getSrcValue(falseBlock);
1309 Value ptrValue = getDstPtr(trueBlock);
1310 auto storeOpAttributes =
1311 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1313 auto selectOp = rewriter.
create<spirv::SelectOp>(
1314 selectionOp.getLoc(), trueValue.
getType(),
1315 brConditionalOp.getCondition(), trueValue, falseValue);
1316 rewriter.
create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1317 selectOp.getResult(), storeOpAttributes);
1331 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1332 Block *mergeBlock)
const;
1334 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1335 return llvm::hasSingleElement(*block) &&
1336 isa<spirv::BranchConditionalOp>(block->
front());
1339 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs)
const {
1340 return lhs->getDiscardableAttrDictionary() ==
1341 rhs->getDiscardableAttrDictionary() &&
1342 lhs.getProperties() == rhs.getProperties();
1347 auto storeOp = cast<spirv::StoreOp>(block->
front());
1348 return storeOp.getValue();
1353 auto storeOp = cast<spirv::StoreOp>(block->
front());
1354 return storeOp.getPtr();
1358 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1361 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1365 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1366 auto trueBrBranchOp =
1367 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1368 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1369 auto falseBrBranchOp =
1370 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1372 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1382 bool isScalarOrVector =
1383 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1384 .isScalarOrVector();
1388 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1389 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1393 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1394 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1402 void spirv::SelectionOp::getCanonicalizationPatterns(
RewritePatternSet &results,
1404 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool constant.
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Block * getSuccessor(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing op).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.