24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
39 if (
auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40 return boolAttr.getValue();
41 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42 if (splatAttr.getElementType().isInteger(1))
43 return splatAttr.getSplatValue<
bool>();
58 if (
auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59 assert(indices.size() == 1 &&
"must have exactly one index for a vector");
60 return vector.getValues<
Attribute>()[indices[0]];
63 if (
auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64 assert(!indices.empty() &&
"must have at least one index for an array");
66 indices.drop_front());
73 bool div0 = b.isZero();
74 bool overflow = a.isMinSignedValue() && b.isAllOnes();
76 return div0 || overflow;
84 #include "SPIRVCanonicalization.inc"
95 struct CombineChainedAccessChain final
99 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
101 auto parentAccessChainOp =
102 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
104 if (!parentAccessChainOp) {
110 llvm::append_range(indices, accessChainOp.getIndices());
113 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
120 void spirv::AccessChainOp::getCanonicalizationPatterns(
122 results.
add<CombineChainedAccessChain>(context);
137 Value lhs = op.getOperand1();
138 Value rhs = op.getOperand2();
143 Value constituents[2] = {rhs, lhs};
166 auto adds = constFoldBinaryOp<IntegerAttr>(
168 [](
const APInt &a,
const APInt &b) {
return a + b; });
172 auto carrys = constFoldBinaryOp<IntegerAttr>(
173 ArrayRef{adds, lhsAttr}, [](
const APInt &a,
const APInt &b) {
175 return a.ult(b) ? (zero + 1) : zero;
182 rewriter.
create<spirv::ConstantOp>(loc, constituentType, adds);
185 rewriter.
create<spirv::ConstantOp>(loc, constituentType, carrys);
188 Value undef = rewriter.
create<spirv::UndefOp>(loc, op.getType());
191 rewriter.
create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
210 template <
typename MulOp,
bool IsSigned>
217 Value lhs = op.getOperand1();
218 Value rhs = op.getOperand2();
224 Value constituents[2] = {zero, zero};
244 auto lowBits = constFoldBinaryOp<IntegerAttr>(
246 [](
const APInt &a,
const APInt &b) {
return a * b; });
251 auto highBits = constFoldBinaryOp<IntegerAttr>(
252 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &b) {
254 return llvm::APIntOps::mulhs(a, b);
256 return llvm::APIntOps::mulhu(a, b);
264 rewriter.
create<spirv::ConstantOp>(loc, constituentType, lowBits);
267 rewriter.
create<spirv::ConstantOp>(loc, constituentType, highBits);
270 Value undef = rewriter.
create<spirv::UndefOp>(loc, op.getType());
273 rewriter.
create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
293 Value lhs = op.getOperand1();
294 Value rhs = op.getOperand2();
300 Value constituents[2] = {lhs, zero};
311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
335 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
339 IntegerAttr prevValue;
340 IntegerAttr currValue;
345 APInt prevConstValue = prevValue.getValue();
346 APInt currConstValue = currValue.getValue();
350 if (prevConstValue.urem(currConstValue) != 0 &&
351 currConstValue.urem(prevConstValue) != 0)
357 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
373 Value curInput = getOperand();
378 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
379 Value prevInput = prevCast.getOperand();
383 getOperandMutable().assign(prevInput);
395 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
396 Value compositeOp = getComposite();
398 while (
auto insertOp =
401 return insertOp.getObject();
402 compositeOp = insertOp.getComposite();
405 if (
auto constructOp =
407 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
409 constructOp.getConstituents().size() == type.getNumElements()) {
410 auto i = llvm::cast<IntegerAttr>(*
getIndices().begin());
411 if (i.getValue().getSExtValue() <
412 static_cast<int64_t
>(constructOp.getConstituents().size()))
413 return constructOp.getConstituents()[i.getValue().getSExtValue()];
418 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
438 return getOperand1();
445 return constFoldBinaryOp<IntegerAttr>(
446 adaptor.getOperands(),
447 [](APInt a,
const APInt &b) { return std::move(a) + b; });
457 return getOperand2();
460 return getOperand1();
467 return constFoldBinaryOp<IntegerAttr>(
468 adaptor.getOperands(),
469 [](
const APInt &a,
const APInt &b) { return a * b; });
478 if (getOperand1() == getOperand2())
486 return constFoldBinaryOp<IntegerAttr>(
487 adaptor.getOperands(),
488 [](APInt a,
const APInt &b) { return std::move(a) - b; });
498 return getOperand1();
508 bool div0OrOverflow =
false;
509 auto res = constFoldBinaryOp<IntegerAttr>(
510 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
511 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
512 div0OrOverflow = true;
517 return div0OrOverflow ?
Attribute() : res;
539 bool div0OrOverflow =
false;
540 auto res = constFoldBinaryOp<IntegerAttr>(
541 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
542 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
543 div0OrOverflow = true;
546 APInt c = a.abs().urem(b.abs());
549 if (b.isNegative()) {
550 APInt zero = APInt::getZero(c.getBitWidth());
551 return a.isNegative() ? (zero - c) : (b + c);
553 return a.isNegative() ? (b - c) : c;
555 return div0OrOverflow ?
Attribute() : res;
577 bool div0OrOverflow =
false;
578 auto res = constFoldBinaryOp<IntegerAttr>(
579 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
580 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
581 div0OrOverflow = true;
586 return div0OrOverflow ?
Attribute() : res;
596 return getOperand1();
605 auto res = constFoldBinaryOp<IntegerAttr>(
606 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
607 if (div0 || b.isZero()) {
632 auto res = constFoldBinaryOp<IntegerAttr>(
633 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
634 if (div0 || b.isZero()) {
647 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
649 auto op = getOperand();
650 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
651 return negateOp->getOperand(0);
656 return constFoldUnaryOp<IntegerAttr>(
657 adaptor.getOperands(), [](
const APInt &a) {
658 APInt zero = APInt::getZero(a.getBitWidth());
667 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
669 auto op = getOperand();
670 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
671 return notOp->getOperand(0);
676 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
686 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
687 if (std::optional<bool> rhs =
691 return getOperand1();
695 return adaptor.getOperand2();
706 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
708 if (getOperand1() == getOperand2()) {
710 if (isa<IntegerType>(
getType()))
712 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
716 return constFoldBinaryOp<IntegerAttr>(
717 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
718 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
726 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
727 if (std::optional<bool> rhs =
731 return getOperand1();
735 if (getOperand1() == getOperand2()) {
737 if (isa<IntegerType>(
getType()))
739 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
743 return constFoldBinaryOp<IntegerAttr>(
744 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
745 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
753 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
755 auto op = getOperand();
756 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
757 return notOp->getOperand(0);
762 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
764 APInt zero = APInt::getZero(1);
765 return a == 1 ? zero : (zero + 1);
769 void spirv::LogicalNotOp::getCanonicalizationPatterns(
772 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
773 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
781 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
785 return adaptor.getOperand2();
790 return getOperand1();
801 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
803 Value trueVals = getTrueValue();
804 Value falseVals = getFalseValue();
805 if (trueVals == falseVals)
813 return *boolAttr ? trueVals : falseVals;
816 if (!operands[0] || !operands[1] || !operands[2])
822 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
823 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
824 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
825 if (!condAttrs || !trueAttrs || !falseAttrs)
828 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
829 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
831 for (
auto [result, cond, falseRes] : iters) {
832 if (!cond.getValue())
836 auto resultType = trueAttrs.getType();
844 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
846 if (getOperand1() == getOperand2()) {
848 if (isa<IntegerType>(
getType()))
850 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
854 return constFoldBinaryOp<IntegerAttr>(
855 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
856 return a == b ? APInt::getAllOnes(1) : APInt::
getZero(1);
864 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
866 if (getOperand1() == getOperand2()) {
868 if (isa<IntegerType>(
getType()))
870 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
874 return constFoldBinaryOp<IntegerAttr>(
875 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
885 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
887 if (getOperand1() == getOperand2()) {
889 if (isa<IntegerType>(
getType()))
891 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
895 return constFoldBinaryOp<IntegerAttr>(
896 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
906 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
908 if (getOperand1() == getOperand2()) {
910 if (isa<IntegerType>(
getType()))
912 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
916 return constFoldBinaryOp<IntegerAttr>(
917 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
927 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
929 if (getOperand1() == getOperand2()) {
931 if (isa<IntegerType>(
getType()))
933 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
937 return constFoldBinaryOp<IntegerAttr>(
938 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
948 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
950 if (getOperand1() == getOperand2()) {
952 if (isa<IntegerType>(
getType()))
954 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
958 return constFoldBinaryOp<IntegerAttr>(
959 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
968 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
970 if (getOperand1() == getOperand2()) {
972 if (isa<IntegerType>(
getType()))
974 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
978 return constFoldBinaryOp<IntegerAttr>(
979 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
989 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
991 if (getOperand1() == getOperand2()) {
993 if (isa<IntegerType>(
getType()))
995 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
999 return constFoldBinaryOp<IntegerAttr>(
1000 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1009 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1011 if (getOperand1() == getOperand2()) {
1013 if (isa<IntegerType>(
getType()))
1015 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1019 return constFoldBinaryOp<IntegerAttr>(
1020 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1030 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1032 if (getOperand1() == getOperand2()) {
1034 if (isa<IntegerType>(
getType()))
1036 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1040 return constFoldBinaryOp<IntegerAttr>(
1041 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1051 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1054 return getOperand1();
1065 bool shiftToLarge =
false;
1066 auto res = constFoldBinaryOp<IntegerAttr>(
1067 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1068 if (shiftToLarge || b.uge(a.getBitWidth())) {
1069 shiftToLarge = true;
1074 return shiftToLarge ?
Attribute() : res;
1082 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1085 return getOperand1();
1096 bool shiftToLarge =
false;
1097 auto res = constFoldBinaryOp<IntegerAttr>(
1098 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1099 if (shiftToLarge || b.uge(a.getBitWidth())) {
1100 shiftToLarge = true;
1105 return shiftToLarge ?
Attribute() : res;
1113 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1116 return getOperand1();
1127 bool shiftToLarge =
false;
1128 auto res = constFoldBinaryOp<IntegerAttr>(
1129 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1130 if (shiftToLarge || b.uge(a.getBitWidth())) {
1131 shiftToLarge = true;
1136 return shiftToLarge ?
Attribute() : res;
1144 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1146 if (getOperand1() == getOperand2()) {
1147 return getOperand1();
1153 if (rhsMask.isZero())
1154 return getOperand2();
1157 if (rhsMask.isAllOnes())
1158 return getOperand1();
1161 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1164 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1165 return getOperand1();
1174 return constFoldBinaryOp<IntegerAttr>(
1175 adaptor.getOperands(),
1176 [](
const APInt &a,
const APInt &b) { return a & b; });
1183 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1185 if (getOperand1() == getOperand2()) {
1186 return getOperand1();
1192 if (rhsMask.isZero())
1193 return getOperand1();
1196 if (rhsMask.isAllOnes())
1197 return getOperand2();
1205 return constFoldBinaryOp<IntegerAttr>(
1206 adaptor.getOperands(),
1207 [](
const APInt &a,
const APInt &b) { return a | b; });
1215 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1218 return getOperand1();
1222 if (getOperand1() == getOperand2())
1230 return constFoldBinaryOp<IntegerAttr>(
1231 adaptor.getOperands(),
1232 [](
const APInt &a,
const APInt &b) { return a ^ b; });
1265 struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1270 Operation *op = selectionOp.getOperation();
1279 if (llvm::range_size(body) != 4) {
1283 Block *headerBlock = selectionOp.getHeaderBlock();
1284 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1288 auto brConditionalOp =
1289 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1293 Block *mergeBlock = selectionOp.getMergeBlock();
1295 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1298 Value trueValue = getSrcValue(trueBlock);
1299 Value falseValue = getSrcValue(falseBlock);
1300 Value ptrValue = getDstPtr(trueBlock);
1301 auto storeOpAttributes =
1302 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1304 auto selectOp = rewriter.
create<spirv::SelectOp>(
1305 selectionOp.getLoc(), trueValue.
getType(),
1306 brConditionalOp.getCondition(), trueValue, falseValue);
1307 rewriter.
create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1308 selectOp.getResult(), storeOpAttributes);
1322 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1323 Block *mergeBlock)
const;
1325 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1326 return llvm::hasSingleElement(*block) &&
1327 isa<spirv::BranchConditionalOp>(block->
front());
1330 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs)
const {
1331 return lhs->getDiscardableAttrDictionary() ==
1332 rhs->getDiscardableAttrDictionary() &&
1333 lhs.getProperties() == rhs.getProperties();
1338 auto storeOp = cast<spirv::StoreOp>(block->
front());
1339 return storeOp.getValue();
1344 auto storeOp = cast<spirv::StoreOp>(block->
front());
1345 return storeOp.getPtr();
1349 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1352 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1356 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1357 auto trueBrBranchOp =
1358 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1359 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1360 auto falseBrBranchOp =
1361 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1363 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1373 bool isScalarOrVector =
1374 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1375 .isScalarOrVector();
1379 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1380 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1384 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1385 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1393 void spirv::SelectionOp::getCanonicalizationPatterns(
RewritePatternSet &results,
1395 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Block * getSuccessor(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...