23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
38 if (
auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<
bool>();
57 if (
auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
58 assert(indices.size() == 1 &&
"must have exactly one index for a vector");
59 return vector.getValues<
Attribute>()[indices[0]];
62 if (
auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
63 assert(!indices.empty() &&
"must have at least one index for an array");
65 indices.drop_front());
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
75 return div0 || overflow;
83 #include "SPIRVCanonicalization.inc"
94 struct CombineChainedAccessChain final
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 if (!parentAccessChainOp) {
109 llvm::append_range(indices, accessChainOp.getIndices());
112 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
119 void spirv::AccessChainOp::getCanonicalizationPatterns(
121 results.
add<CombineChainedAccessChain>(context);
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
142 Value constituents[2] = {rhs, lhs};
165 auto adds = constFoldBinaryOp<IntegerAttr>(
167 [](
const APInt &a,
const APInt &b) {
return a + b; });
171 auto carrys = constFoldBinaryOp<IntegerAttr>(
172 ArrayRef{adds, lhsAttr}, [](
const APInt &a,
const APInt &b) {
174 return a.ult(b) ? (zero + 1) : zero;
181 spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
184 spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
187 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
190 spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
198 void spirv::IAddCarryOp::getCanonicalizationPatterns(
209 template <
typename MulOp,
bool IsSigned>
216 Value lhs = op.getOperand1();
217 Value rhs = op.getOperand2();
223 Value constituents[2] = {zero, zero};
243 auto lowBits = constFoldBinaryOp<IntegerAttr>(
245 [](
const APInt &a,
const APInt &b) {
return a * b; });
250 auto highBits = constFoldBinaryOp<IntegerAttr>(
251 {lhsAttr, rhsAttr}, [](
const APInt &a,
const APInt &b) {
253 return llvm::APIntOps::mulhs(a, b);
255 return llvm::APIntOps::mulhu(a, b);
263 spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
266 spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
269 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
272 spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
281 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
292 Value lhs = op.getOperand1();
293 Value rhs = op.getOperand2();
299 Value constituents[2] = {lhs, zero};
310 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
333 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
345 bool isApplicable =
false;
346 if (
auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
347 auto currInt = cast<IntegerAttr>(currValue);
348 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
349 }
else if (
auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
350 auto currVec = cast<DenseElementsAttr>(currValue);
351 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
352 currVec.getValues<APInt>()),
353 [](
const auto &pair) {
354 auto &[prev, curr] = pair;
355 return prev.urem(curr) == 0;
365 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
381 Value curInput = getOperand();
386 if (
auto prevCast = curInput.
getDefiningOp<spirv::BitcastOp>()) {
387 Value prevInput = prevCast.getOperand();
391 getOperandMutable().assign(prevInput);
403 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
404 Value compositeOp = getComposite();
406 while (
auto insertOp =
409 return insertOp.getObject();
410 compositeOp = insertOp.getComposite();
413 if (
auto constructOp =
415 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
417 constructOp.getConstituents().size() == type.getNumElements()) {
418 auto i = llvm::cast<IntegerAttr>(*
getIndices().begin());
419 if (i.getValue().getSExtValue() <
420 static_cast<int64_t
>(constructOp.getConstituents().size()))
421 return constructOp.getConstituents()[i.getValue().getSExtValue()];
426 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
446 return getOperand1();
453 return constFoldBinaryOp<IntegerAttr>(
454 adaptor.getOperands(),
455 [](APInt a,
const APInt &b) { return std::move(a) + b; });
465 return getOperand2();
468 return getOperand1();
475 return constFoldBinaryOp<IntegerAttr>(
476 adaptor.getOperands(),
477 [](
const APInt &a,
const APInt &b) { return a * b; });
486 if (getOperand1() == getOperand2())
494 return constFoldBinaryOp<IntegerAttr>(
495 adaptor.getOperands(),
496 [](APInt a,
const APInt &b) { return std::move(a) - b; });
506 return getOperand1();
516 bool div0OrOverflow =
false;
517 auto res = constFoldBinaryOp<IntegerAttr>(
518 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
525 return div0OrOverflow ?
Attribute() : res;
547 bool div0OrOverflow =
false;
548 auto res = constFoldBinaryOp<IntegerAttr>(
549 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
550 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
551 div0OrOverflow = true;
554 APInt c = a.abs().urem(b.abs());
557 if (b.isNegative()) {
558 APInt zero = APInt::getZero(c.getBitWidth());
559 return a.isNegative() ? (zero - c) : (b + c);
561 return a.isNegative() ? (b - c) : c;
563 return div0OrOverflow ?
Attribute() : res;
585 bool div0OrOverflow =
false;
586 auto res = constFoldBinaryOp<IntegerAttr>(
587 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
588 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
589 div0OrOverflow = true;
594 return div0OrOverflow ?
Attribute() : res;
604 return getOperand1();
613 auto res = constFoldBinaryOp<IntegerAttr>(
614 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
615 if (div0 || b.isZero()) {
640 auto res = constFoldBinaryOp<IntegerAttr>(
641 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
642 if (div0 || b.isZero()) {
655 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
657 auto op = getOperand();
658 if (
auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
659 return negateOp->getOperand(0);
664 return constFoldUnaryOp<IntegerAttr>(
665 adaptor.getOperands(), [](
const APInt &a) {
666 APInt zero = APInt::getZero(a.getBitWidth());
675 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
677 auto op = getOperand();
678 if (
auto notOp = op.getDefiningOp<spirv::NotOp>())
679 return notOp->getOperand(0);
684 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
694 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
695 if (std::optional<bool> rhs =
699 return getOperand1();
703 return adaptor.getOperand2();
714 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
716 if (getOperand1() == getOperand2()) {
718 if (isa<IntegerType>(
getType()))
720 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
724 return constFoldBinaryOp<IntegerAttr>(
725 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
726 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
734 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
735 if (std::optional<bool> rhs =
739 return getOperand1();
743 if (getOperand1() == getOperand2()) {
745 if (isa<IntegerType>(
getType()))
747 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
751 return constFoldBinaryOp<IntegerAttr>(
752 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
753 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
761 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
763 auto op = getOperand();
764 if (
auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
765 return notOp->getOperand(0);
770 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
772 APInt zero = APInt::getZero(1);
773 return a == 1 ? zero : (zero + 1);
777 void spirv::LogicalNotOp::getCanonicalizationPatterns(
780 .
add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
781 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
789 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
793 return adaptor.getOperand2();
798 return getOperand1();
809 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
811 Value trueVals = getTrueValue();
812 Value falseVals = getFalseValue();
813 if (trueVals == falseVals)
821 return *boolAttr ? trueVals : falseVals;
824 if (!operands[0] || !operands[1] || !operands[2])
830 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
831 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
832 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
833 if (!condAttrs || !trueAttrs || !falseAttrs)
836 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<
Attribute>());
837 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<
BoolAttr>(),
839 for (
auto [result, cond, falseRes] : iters) {
840 if (!cond.getValue())
844 auto resultType = trueAttrs.getType();
852 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
854 if (getOperand1() == getOperand2()) {
856 if (isa<IntegerType>(
getType()))
858 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
862 return constFoldBinaryOp<IntegerAttr>(
863 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
864 return a == b ? APInt::getAllOnes(1) : APInt::
getZero(1);
872 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
874 if (getOperand1() == getOperand2()) {
876 if (isa<IntegerType>(
getType()))
878 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
882 return constFoldBinaryOp<IntegerAttr>(
883 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
893 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
895 if (getOperand1() == getOperand2()) {
897 if (isa<IntegerType>(
getType()))
899 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
903 return constFoldBinaryOp<IntegerAttr>(
904 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
914 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
916 if (getOperand1() == getOperand2()) {
918 if (isa<IntegerType>(
getType()))
920 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
924 return constFoldBinaryOp<IntegerAttr>(
925 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
935 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
937 if (getOperand1() == getOperand2()) {
939 if (isa<IntegerType>(
getType()))
941 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
945 return constFoldBinaryOp<IntegerAttr>(
946 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
956 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
958 if (getOperand1() == getOperand2()) {
960 if (isa<IntegerType>(
getType()))
962 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
966 return constFoldBinaryOp<IntegerAttr>(
967 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
976 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
978 if (getOperand1() == getOperand2()) {
980 if (isa<IntegerType>(
getType()))
982 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
986 return constFoldBinaryOp<IntegerAttr>(
987 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
997 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
999 if (getOperand1() == getOperand2()) {
1001 if (isa<IntegerType>(
getType()))
1003 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1007 return constFoldBinaryOp<IntegerAttr>(
1008 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1017 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1019 if (getOperand1() == getOperand2()) {
1021 if (isa<IntegerType>(
getType()))
1023 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1027 return constFoldBinaryOp<IntegerAttr>(
1028 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1038 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1040 if (getOperand1() == getOperand2()) {
1042 if (isa<IntegerType>(
getType()))
1044 if (
auto vecTy = dyn_cast<VectorType>(
getType()))
1048 return constFoldBinaryOp<IntegerAttr>(
1049 adaptor.getOperands(),
getType(), [](
const APInt &a,
const APInt &b) {
1059 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1062 return getOperand1();
1073 bool shiftToLarge =
false;
1074 auto res = constFoldBinaryOp<IntegerAttr>(
1075 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1076 if (shiftToLarge || b.uge(a.getBitWidth())) {
1077 shiftToLarge = true;
1082 return shiftToLarge ?
Attribute() : res;
1090 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1093 return getOperand1();
1104 bool shiftToLarge =
false;
1105 auto res = constFoldBinaryOp<IntegerAttr>(
1106 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1107 if (shiftToLarge || b.uge(a.getBitWidth())) {
1108 shiftToLarge = true;
1113 return shiftToLarge ?
Attribute() : res;
1121 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1124 return getOperand1();
1135 bool shiftToLarge =
false;
1136 auto res = constFoldBinaryOp<IntegerAttr>(
1137 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
1138 if (shiftToLarge || b.uge(a.getBitWidth())) {
1139 shiftToLarge = true;
1144 return shiftToLarge ?
Attribute() : res;
1152 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1154 if (getOperand1() == getOperand2()) {
1155 return getOperand1();
1161 if (rhsMask.isZero())
1162 return getOperand2();
1165 if (rhsMask.isAllOnes())
1166 return getOperand1();
1169 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1172 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1173 return getOperand1();
1182 return constFoldBinaryOp<IntegerAttr>(
1183 adaptor.getOperands(),
1184 [](
const APInt &a,
const APInt &b) { return a & b; });
1191 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1193 if (getOperand1() == getOperand2()) {
1194 return getOperand1();
1200 if (rhsMask.isZero())
1201 return getOperand1();
1204 if (rhsMask.isAllOnes())
1205 return getOperand2();
1213 return constFoldBinaryOp<IntegerAttr>(
1214 adaptor.getOperands(),
1215 [](
const APInt &a,
const APInt &b) { return a | b; });
1223 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1226 return getOperand1();
1230 if (getOperand1() == getOperand2())
1238 return constFoldBinaryOp<IntegerAttr>(
1239 adaptor.getOperands(),
1240 [](
const APInt &a,
const APInt &b) { return a ^ b; });
1273 struct ConvertSelectionOpToSelect final :
OpRewritePattern<spirv::SelectionOp> {
1278 Operation *op = selectionOp.getOperation();
1287 if (llvm::range_size(body) != 4) {
1291 Block *headerBlock = selectionOp.getHeaderBlock();
1292 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1296 auto brConditionalOp =
1297 cast<spirv::BranchConditionalOp>(headerBlock->
front());
1301 Block *mergeBlock = selectionOp.getMergeBlock();
1303 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1306 Value trueValue = getSrcValue(trueBlock);
1307 Value falseValue = getSrcValue(falseBlock);
1308 Value ptrValue = getDstPtr(trueBlock);
1309 auto storeOpAttributes =
1310 cast<spirv::StoreOp>(trueBlock->
front())->getAttrs();
1312 auto selectOp = spirv::SelectOp::create(
1313 rewriter, selectionOp.getLoc(), trueValue.
getType(),
1314 brConditionalOp.getCondition(), trueValue, falseValue);
1315 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1316 selectOp.getResult(), storeOpAttributes);
1330 LogicalResult canCanonicalizeSelection(
Block *trueBlock,
Block *falseBlock,
1331 Block *mergeBlock)
const;
1333 bool onlyContainsBranchConditionalOp(
Block *block)
const {
1334 return llvm::hasSingleElement(*block) &&
1335 isa<spirv::BranchConditionalOp>(block->
front());
1338 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs)
const {
1339 return lhs->getDiscardableAttrDictionary() ==
1340 rhs->getDiscardableAttrDictionary() &&
1341 lhs.getProperties() == rhs.getProperties();
1346 auto storeOp = cast<spirv::StoreOp>(block->
front());
1347 return storeOp.getValue();
1352 auto storeOp = cast<spirv::StoreOp>(block->
front());
1353 return storeOp.getPtr();
1357 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1360 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1364 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->
front());
1365 auto trueBrBranchOp =
1366 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->
begin()));
1367 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->
front());
1368 auto falseBrBranchOp =
1369 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->
begin()));
1371 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1381 bool isScalarOrVector =
1382 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1383 .isScalarOrVector();
1387 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1388 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1392 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1393 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1401 void spirv::SelectionOp::getCanonicalizationPatterns(
RewritePatternSet &results,
1403 results.
add<ConvertSelectionOpToSelect>(context);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Block * getSuccessor(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.