26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
39 llvm::RoundingMode::NearestTiesToEven;
48 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.
getType(), value);
71static IntegerOverflowFlagsAttr
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
102 llvm_unreachable(
"unknown cmpi predicate kind");
111static llvm::RoundingMode
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
127 llvm_unreachable(
"Unhandled rounding mode");
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
161 if (!shapedType.hasStaticShape())
171#include "ArithCanonicalization.inc"
180 auto i1Type = IntegerType::get(type.
getContext(), 1);
181 if (
auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
192void arith::ConstantOp::getAsmResultNames(
195 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName <<
'c' << intCst.getValue();
207 specialName <<
'_' << type;
208 setNameFn(getResult(), specialName.str());
210 setNameFn(getResult(),
"cst");
216LogicalResult arith::ConstantOp::verify() {
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError(
"integer return type must be signless");
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
225 "value must be an integer, float, or elements attribute");
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
245 if (!intType.isSignless())
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
264 arith::ConstantOp::build(builder,
result, type,
274 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
275 assert(
result &&
"builder didn't return the right type");
287 arith::ConstantOp::build(builder,
result, type,
296 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
297 assert(
result &&
"builder didn't return the right type");
308 arith::ConstantOp::build(builder,
result, type,
314 const APInt &
value) {
317 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
318 assert(
result &&
"builder didn't return the right type");
324 const APInt &
value) {
329 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
335 FloatType type,
const APFloat &
value) {
336 arith::ConstantOp::build(builder,
result, type,
343 const APFloat &
value) {
346 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
347 assert(
result &&
"builder didn't return the right type");
353 const APFloat &
value) {
358 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
374 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
375 assert(
result &&
"builder didn't return the right type");
385 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
394 "type doesn't have a zero representation");
396 assert(zeroAttr &&
"unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
410 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
415 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
420 adaptor.getOperands(),
421 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
426 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
456 results.push_back(getLhs());
457 results.push_back(falseValue);
466 adaptor.getOperands(),
467 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
475 ArrayRef({sumAttr, adaptor.getLhs()}),
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.
add<AddUIExtendedToAddI>(context);
498std::optional<SmallVector<int64_t, 4>>
499arith::SubUIExtendedOp::getShapeForUnroll() {
500 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
501 return llvm::to_vector<4>(vt.getShape());
508 return lhs.ult(
rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
512arith::SubUIExtendedOp::fold(FoldAdaptor adaptor,
513 SmallVectorImpl<OpFoldResult> &results) {
514 Type borrowTy = getBorrow().getType();
520 results.push_back(getLhs());
521 results.push_back(falseValue);
526 if (getLhs() == getRhs()) {
533 results.push_back(zeroDiff);
534 results.push_back(falseValue);
540 adaptor.getOperands(),
541 [](APInt a,
const APInt &
b) { return std::move(a) - b; })) {
544 results.push_back(diffAttr);
545 results.push_back(diffAttr);
549 adaptor.getOperands(),
555 results.push_back(diffAttr);
556 results.push_back(borrowAttr);
563void arith::SubUIExtendedOp::getCanonicalizationPatterns(
564 RewritePatternSet &patterns, MLIRContext *context) {
565 patterns.
add<SubUIExtendedToSubI>(context);
572OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
574 if (getOperand(0) == getOperand(1)) {
575 auto shapedType = dyn_cast<ShapedType>(
getType());
577 if (!shapedType || shapedType.hasStaticShape())
584 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
586 if (getRhs() ==
add.getRhs())
589 if (getRhs() ==
add.getLhs())
594 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
595 if (getLhs() == sub.getLhs())
599 adaptor.getOperands(),
600 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
603void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604 MLIRContext *context) {
605 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
606 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
607 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
614OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
625 adaptor.getOperands(),
626 [](
const APInt &a,
const APInt &
b) { return a * b; });
629void arith::MulIOp::getAsmResultNames(
631 if (!isa<IndexType>(
getType()))
636 auto isVscale = [](Operation *op) {
637 return op && op->getName().getStringRef() ==
"vector.vscale";
640 IntegerAttr baseValue;
641 auto isVscaleExpr = [&](Value a, Value
b) {
643 isVscale(
b.getDefiningOp());
646 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
650 SmallString<32> specialNameBuffer;
651 llvm::raw_svector_ostream specialName(specialNameBuffer);
652 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
653 setNameFn(getResult(), specialName.str());
656void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
657 MLIRContext *context) {
658 patterns.
add<MulIMulIConstant>(context);
665std::optional<SmallVector<int64_t, 4>>
666arith::MulSIExtendedOp::getShapeForUnroll() {
667 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
668 return llvm::to_vector<4>(vt.getShape());
673arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
674 SmallVectorImpl<OpFoldResult> &results) {
677 Attribute zero = adaptor.getRhs();
678 results.push_back(zero);
679 results.push_back(zero);
685 adaptor.getOperands(),
686 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
689 llvm::APIntOps::mulhs);
690 assert(highAttr &&
"Unexpected constant-folding failure");
692 results.push_back(lowAttr);
693 results.push_back(highAttr);
700void arith::MulSIExtendedOp::getCanonicalizationPatterns(
701 RewritePatternSet &patterns, MLIRContext *context) {
702 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
709std::optional<SmallVector<int64_t, 4>>
710arith::MulUIExtendedOp::getShapeForUnroll() {
711 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
712 return llvm::to_vector<4>(vt.getShape());
717arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
718 SmallVectorImpl<OpFoldResult> &results) {
721 Attribute zero = adaptor.getRhs();
722 results.push_back(zero);
723 results.push_back(zero);
731 results.push_back(getLhs());
732 results.push_back(zero);
738 adaptor.getOperands(),
739 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
742 llvm::APIntOps::mulhu);
743 assert(highAttr &&
"Unexpected constant-folding failure");
745 results.push_back(lowAttr);
746 results.push_back(highAttr);
753void arith::MulUIExtendedOp::getCanonicalizationPatterns(
754 RewritePatternSet &patterns, MLIRContext *context) {
755 patterns.
add<MulUIExtendedToMulI>(context);
764 arith::IntegerOverflowFlags ovfFlags) {
765 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
766 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
778OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
784 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
790 [&](APInt a,
const APInt &
b) {
798 return div0 ? Attribute() :
result;
818OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
824 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
828 bool overflowOrDiv0 =
false;
830 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
831 if (overflowOrDiv0 || !b) {
832 overflowOrDiv0 = true;
835 return a.sdiv_ov(
b, overflowOrDiv0);
838 return overflowOrDiv0 ? Attribute() :
result;
865 APInt one(a.getBitWidth(), 1,
true);
866 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
867 return val.sadd_ov(one, overflow);
874OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
879 bool overflowOrDiv0 =
false;
881 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
882 if (overflowOrDiv0 || !b) {
883 overflowOrDiv0 = true;
886 APInt quotient = a.udiv(
b);
889 APInt one(a.getBitWidth(), 1,
true);
890 return quotient.uadd_ov(one, overflowOrDiv0);
893 return overflowOrDiv0 ? Attribute() :
result;
904OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
912 bool overflowOrDiv0 =
false;
914 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
915 if (overflowOrDiv0 || !b) {
916 overflowOrDiv0 = true;
922 unsigned bits = a.getBitWidth();
923 APInt zero = APInt::getZero(bits);
924 bool aGtZero = a.sgt(zero);
925 bool bGtZero =
b.sgt(zero);
926 if (aGtZero && bGtZero) {
933 bool overflowNegA =
false;
934 bool overflowNegB =
false;
935 bool overflowDiv =
false;
936 bool overflowNegRes =
false;
937 if (!aGtZero && !bGtZero) {
939 APInt posA = zero.ssub_ov(a, overflowNegA);
940 APInt posB = zero.ssub_ov(
b, overflowNegB);
942 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
945 if (!aGtZero && bGtZero) {
947 APInt posA = zero.ssub_ov(a, overflowNegA);
948 APInt
div = posA.sdiv_ov(
b, overflowDiv);
949 APInt res = zero.ssub_ov(
div, overflowNegRes);
950 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
954 APInt posB = zero.ssub_ov(
b, overflowNegB);
955 APInt
div = a.sdiv_ov(posB, overflowDiv);
956 APInt res = zero.ssub_ov(
div, overflowNegRes);
958 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
962 return overflowOrDiv0 ? Attribute() :
result;
973OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
979 bool overflowOrDiv =
false;
981 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
983 overflowOrDiv = true;
986 return a.sfloordiv_ov(
b, overflowOrDiv);
989 return overflowOrDiv ? Attribute() :
result;
996OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
1004 [&](APInt a,
const APInt &
b) {
1005 if (div0 || b.isZero()) {
1012 return div0 ? Attribute() :
result;
1023OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
1031 [&](APInt a,
const APInt &
b) {
1032 if (div0 || b.isZero()) {
1039 return div0 ? Attribute() :
result;
1057 for (
bool reversePrev : {
false,
true}) {
1058 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
1059 .getDefiningOp<arith::AndIOp>();
1063 Value other = (reversePrev ? op.getLhs() : op.getRhs());
1064 if (other != prev.getLhs() && other != prev.getRhs())
1067 return prev.getResult();
1072OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1079 intValue.isAllOnes())
1084 intValue.isAllOnes())
1089 intValue.isAllOnes())
1097 adaptor.getOperands(),
1098 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1105OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1108 if (rhsVal.isZero())
1111 if (rhsVal.isAllOnes())
1112 return adaptor.getRhs();
1119 intValue.isAllOnes())
1120 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1124 intValue.isAllOnes())
1125 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1128 adaptor.getOperands(),
1129 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
1136OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1141 if (getLhs() == getRhs())
1145 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1146 if (prev.getRhs() == getRhs())
1147 return prev.getLhs();
1148 if (prev.getLhs() == getRhs())
1149 return prev.getRhs();
1153 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1154 if (prev.getRhs() == getLhs())
1155 return prev.getLhs();
1156 if (prev.getLhs() == getLhs())
1157 return prev.getRhs();
1161 adaptor.getOperands(),
1162 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
1165void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1166 MLIRContext *context) {
1167 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1174OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1176 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1177 return op.getOperand();
1179 [](
const APFloat &a) { return -a; });
1186OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1192 if (
auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1193 return op.getResult();
1197 adaptor.getOperands(), [](
const APFloat &a) {
1199 return APFloat::getZero(a.getSemantics(), a.isNegative());
1208OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1213 auto rm = getRoundingmode();
1215 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1217 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1226OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1231 auto rm = getRoundingmode();
1233 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1235 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1240void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1241 MLIRContext *context) {
1242 patterns.
add<SubFOfNegZero>(context);
1249OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1251 if (getLhs() == getRhs())
1265OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1267 if (getLhs() == getRhs())
1281OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1283 if (getLhs() == getRhs())
1289 if (intValue.isMaxSignedValue())
1292 if (intValue.isMinSignedValue())
1297 llvm::APIntOps::smax);
1304OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1306 if (getLhs() == getRhs())
1312 if (intValue.isMaxValue())
1315 if (intValue.isMinValue())
1320 llvm::APIntOps::umax);
1327OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1329 if (getLhs() == getRhs())
1343OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1345 if (getLhs() == getRhs())
1359OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1361 if (getLhs() == getRhs())
1367 if (intValue.isMinSignedValue())
1370 if (intValue.isMaxSignedValue())
1375 llvm::APIntOps::smin);
1382OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1384 if (getLhs() == getRhs())
1390 if (intValue.isMinValue())
1393 if (intValue.isMaxValue())
1398 llvm::APIntOps::umin);
1405OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1410 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1411 arith::FastMathFlags::nsz)) {
1417 auto rm = getRoundingmode();
1419 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1421 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1426void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1427 MLIRContext *context) {
1428 patterns.
add<MulFOfNegF>(context);
1435OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1440 auto rm = getRoundingmode();
1442 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1444 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1449void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1450 MLIRContext *context) {
1451 patterns.
add<DivFOfNegF>(context);
1458OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1460 [](
const APFloat &a,
const APFloat &
b) {
1465 (void)result.mod(b);
1474template <
typename... Types>
1480template <
typename... ShapedTypes,
typename... ElementTypes>
1483 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1487 if (!llvm::isa<ElementTypes...>(underlyingType))
1490 return underlyingType;
1494template <
typename... ElementTypes>
1501template <
typename... ElementTypes>
1510 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1511 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1512 if (!rankedTensorA || !rankedTensorB)
1514 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1518 if (inputs.size() != 1 || outputs.size() != 1)
1530template <
typename ValType,
typename Op>
1535 if (llvm::cast<ValType>(srcType).getWidth() >=
1536 llvm::cast<ValType>(dstType).getWidth())
1538 << dstType <<
" must be wider than operand type " << srcType;
1544template <
typename ValType,
typename Op>
1549 if (llvm::cast<ValType>(srcType).getWidth() <=
1550 llvm::cast<ValType>(dstType).getWidth())
1552 << dstType <<
" must be shorter than operand type " << srcType;
1558template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1563 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1564 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1565 if (!srcType || !dstType)
1568 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1569 srcType.getIntOrFloatBitWidth());
1574static FailureOr<APFloat>
1576 const llvm::fltSemantics &targetSemantics,
1580 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1581 if (sourceValue.isInfinity() &&
1582 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1583 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1585 if (sourceValue.isNaN() &&
1586 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1589 bool losesInfo =
false;
1590 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1591 if (losesInfo || status != APFloat::opOK)
1601OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1602 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1603 getInMutable().assign(
lhs.getIn());
1608 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1610 adaptor.getOperands(),
getType(),
1611 [bitWidth](
const APInt &a,
bool &castStatus) {
1612 return a.zext(bitWidth);
1620LogicalResult arith::ExtUIOp::verify() {
1628OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1629 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1630 getInMutable().assign(
lhs.getIn());
1635 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1637 adaptor.getOperands(),
getType(),
1638 [bitWidth](
const APInt &a,
bool &castStatus) {
1639 return a.sext(bitWidth);
1647void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1648 MLIRContext *context) {
1649 patterns.
add<ExtSIOfExtUI>(context);
1652LogicalResult arith::ExtSIOp::verify() {
1662OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1663 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1664 if (truncFOp.getOperand().getType() ==
getType()) {
1665 arith::FastMathFlags truncFMF =
1666 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1667 bool isTruncContract =
1668 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1669 arith::FastMathFlags extFMF =
1670 getFastmath().value_or(arith::FastMathFlags::none);
1671 bool isExtContract =
1672 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1673 if (isTruncContract && isExtContract) {
1674 return truncFOp.getOperand();
1680 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1682 adaptor.getOperands(),
getType(),
1683 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1703bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1708LogicalResult arith::ScalingExtFOp::verify() {
1716OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1719 Value src = getOperand().getDefiningOp()->getOperand(0);
1724 if (llvm::cast<IntegerType>(srcType).getWidth() >
1725 llvm::cast<IntegerType>(dstType).getWidth()) {
1732 if (srcType == dstType)
1738 setOperand(getOperand().getDefiningOp()->getOperand(0));
1743 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1745 adaptor.getOperands(),
getType(),
1746 [bitWidth](
const APInt &a,
bool &castStatus) {
1747 return a.trunc(bitWidth);
1755void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1756 MLIRContext *context) {
1758 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1762LogicalResult arith::TruncIOp::verify() {
1772OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1774 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1775 Value src = extOp.getIn();
1777 auto intermediateType =
1780 if (llvm::APFloatBase::isRepresentableBy(
1781 srcType.getFloatSemantics(),
1782 intermediateType.getFloatSemantics())) {
1784 if (srcType.getWidth() > resElemType.getWidth()) {
1790 if (srcType == resElemType)
1795 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1797 adaptor.getOperands(),
getType(),
1798 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1799 llvm::RoundingMode llvmRoundingMode =
1801 FailureOr<APFloat>
result =
1811void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1812 MLIRContext *context) {
1813 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1820LogicalResult arith::TruncFOp::verify() {
1828OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1830 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1832 adaptor.getOperands(),
getType(),
1833 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1834 llvm::RoundingMode llvmRoundingMode =
1836 FailureOr<APFloat>
result =
1851 if (!srcType || !dstType)
1853 return srcType != dstType &&
1857LogicalResult arith::ConvertFOp::verify() {
1860 if (srcType == dstType)
1861 return emitError(
"result element type ")
1862 << dstType <<
" must be different from operand element type "
1864 if (srcType.getWidth() != dstType.getWidth())
1865 return emitError(
"result element type ")
1866 << dstType <<
" must have the same bitwidth as operand element type "
1875bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1880LogicalResult arith::ScalingTruncFOp::verify() {
1888void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1889 MLIRContext *context) {
1890 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1897void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1898 MLIRContext *context) {
1899 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1906template <
typename From,
typename To>
1914 return srcType && dstType;
1925OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1928 adaptor.getOperands(),
getType(),
1929 [&resEleType](
const APInt &a,
bool &castStatus) {
1930 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1931 APFloat apf(floatTy.getFloatSemantics(),
1932 APInt::getZero(floatTy.getWidth()));
1933 apf.convertFromAPInt(a,
false,
1934 APFloat::rmNearestTiesToEven);
1939void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1940 MLIRContext *context) {
1941 patterns.
add<UIToFPOfExtUI>(context);
1952OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1955 adaptor.getOperands(),
getType(),
1956 [&resEleType](
const APInt &a,
bool &castStatus) {
1957 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1958 APFloat apf(floatTy.getFloatSemantics(),
1959 APInt::getZero(floatTy.getWidth()));
1960 apf.convertFromAPInt(a,
true,
1961 APFloat::rmNearestTiesToEven);
1966void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1967 MLIRContext *context) {
1968 patterns.
add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1979OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1981 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1983 adaptor.getOperands(),
getType(),
1984 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1986 APSInt api(bitWidth,
true);
1987 castStatus = APFloat::opInvalidOp !=
1988 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2001OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
2003 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
2005 adaptor.getOperands(),
getType(),
2006 [&bitWidth](
const APFloat &a,
bool &castStatus) {
2008 APSInt api(bitWidth,
false);
2009 castStatus = APFloat::opInvalidOp !=
2010 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2024 return intTy.getWidth();
2025 return IndexType::kInternalStorageBitWidth;
2034 if (!srcType || !dstType)
2038 (srcType.isSignlessInteger() && dstType.
isIndex());
2041bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
2046OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
2048 unsigned resultBitwidth = 64;
2050 resultBitwidth = intTy.getWidth();
2053 adaptor.getOperands(),
getType(),
2054 [resultBitwidth](
const APInt &a,
bool & ) {
2055 return a.sextOrTrunc(resultBitwidth);
2062 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
2063 Value x = inner.getOperand();
2072void arith::IndexCastOp::getCanonicalizationPatterns(
2073 RewritePatternSet &patterns, MLIRContext *context) {
2074 patterns.
add<IndexCastOfExtSI>(context);
2081bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
2086OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2088 unsigned resultBitwidth = 64;
2090 resultBitwidth = intTy.getWidth();
2093 adaptor.getOperands(),
getType(),
2094 [resultBitwidth](
const APInt &a,
bool & ) {
2095 return a.zextOrTrunc(resultBitwidth);
2102 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2103 Value x = inner.getOperand();
2112void arith::IndexCastUIOp::getCanonicalizationPatterns(
2113 RewritePatternSet &patterns, MLIRContext *context) {
2114 patterns.
add<IndexCastUIOfExtUI>(context);
2127 if (!srcType || !dstType)
2133OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2135 auto operand = adaptor.getIn();
2140 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2141 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
2143 if (llvm::isa<ShapedType>(resType))
2151 APInt bits = llvm::isa<FloatAttr>(operand)
2152 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2153 : llvm::cast<IntegerAttr>(operand).getValue();
2155 "trying to fold on broken IR: operands have incompatible types");
2157 if (
auto resFloatType = dyn_cast<FloatType>(resType))
2158 return FloatAttr::get(resType,
2159 APFloat(resFloatType.getFloatSemantics(), bits));
2160 return IntegerAttr::get(resType, bits);
2163void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2164 MLIRContext *context) {
2165 patterns.
add<BitcastOfBitcast>(context);
2175 const APInt &
lhs,
const APInt &
rhs) {
2176 switch (predicate) {
2177 case arith::CmpIPredicate::eq:
2179 case arith::CmpIPredicate::ne:
2181 case arith::CmpIPredicate::slt:
2183 case arith::CmpIPredicate::sle:
2185 case arith::CmpIPredicate::sgt:
2187 case arith::CmpIPredicate::sge:
2189 case arith::CmpIPredicate::ult:
2191 case arith::CmpIPredicate::ule:
2193 case arith::CmpIPredicate::ugt:
2195 case arith::CmpIPredicate::uge:
2198 llvm_unreachable(
"unknown cmpi predicate kind");
2203 switch (predicate) {
2204 case arith::CmpIPredicate::eq:
2205 case arith::CmpIPredicate::sle:
2206 case arith::CmpIPredicate::sge:
2207 case arith::CmpIPredicate::ule:
2208 case arith::CmpIPredicate::uge:
2210 case arith::CmpIPredicate::ne:
2211 case arith::CmpIPredicate::slt:
2212 case arith::CmpIPredicate::sgt:
2213 case arith::CmpIPredicate::ult:
2214 case arith::CmpIPredicate::ugt:
2217 llvm_unreachable(
"unknown cmpi predicate kind");
2221 if (
auto intType = dyn_cast<IntegerType>(t)) {
2222 return intType.getWidth();
2224 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2225 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2227 return std::nullopt;
2230OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2232 if (getLhs() == getRhs()) {
2238 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2240 std::optional<int64_t> integerWidth =
2242 if (integerWidth && integerWidth.value() == 1 &&
2243 getPredicate() == arith::CmpIPredicate::ne)
2244 return extOp.getOperand();
2246 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2248 std::optional<int64_t> integerWidth =
2250 if (integerWidth && integerWidth.value() == 1 &&
2251 getPredicate() == arith::CmpIPredicate::ne)
2252 return extOp.getOperand();
2257 getPredicate() == arith::CmpIPredicate::ne)
2264 getPredicate() == arith::CmpIPredicate::eq)
2269 if (adaptor.getLhs() && !adaptor.getRhs()) {
2271 using Pred = CmpIPredicate;
2272 const std::pair<Pred, Pred> invPreds[] = {
2273 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2274 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2275 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2276 {Pred::ne, Pred::ne},
2278 Pred origPred = getPredicate();
2279 for (
auto pred : invPreds) {
2280 if (origPred == pred.first) {
2281 setPredicate(pred.second);
2282 Value
lhs = getLhs();
2283 Value
rhs = getRhs();
2284 getLhsMutable().assign(
rhs);
2285 getRhsMutable().assign(
lhs);
2289 llvm_unreachable(
"unknown cmpi predicate kind");
2294 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2297 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
2306void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2307 MLIRContext *context) {
2308 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2318 const APFloat &
lhs,
const APFloat &
rhs) {
2319 auto cmpResult =
lhs.compare(
rhs);
2320 switch (predicate) {
2321 case arith::CmpFPredicate::AlwaysFalse:
2323 case arith::CmpFPredicate::OEQ:
2324 return cmpResult == APFloat::cmpEqual;
2325 case arith::CmpFPredicate::OGT:
2326 return cmpResult == APFloat::cmpGreaterThan;
2327 case arith::CmpFPredicate::OGE:
2328 return cmpResult == APFloat::cmpGreaterThan ||
2329 cmpResult == APFloat::cmpEqual;
2330 case arith::CmpFPredicate::OLT:
2331 return cmpResult == APFloat::cmpLessThan;
2332 case arith::CmpFPredicate::OLE:
2333 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2334 case arith::CmpFPredicate::ONE:
2335 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2336 case arith::CmpFPredicate::ORD:
2337 return cmpResult != APFloat::cmpUnordered;
2338 case arith::CmpFPredicate::UEQ:
2339 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2340 case arith::CmpFPredicate::UGT:
2341 return cmpResult == APFloat::cmpUnordered ||
2342 cmpResult == APFloat::cmpGreaterThan;
2343 case arith::CmpFPredicate::UGE:
2344 return cmpResult == APFloat::cmpUnordered ||
2345 cmpResult == APFloat::cmpGreaterThan ||
2346 cmpResult == APFloat::cmpEqual;
2347 case arith::CmpFPredicate::ULT:
2348 return cmpResult == APFloat::cmpUnordered ||
2349 cmpResult == APFloat::cmpLessThan;
2350 case arith::CmpFPredicate::ULE:
2351 return cmpResult == APFloat::cmpUnordered ||
2352 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2353 case arith::CmpFPredicate::UNE:
2354 return cmpResult != APFloat::cmpEqual;
2355 case arith::CmpFPredicate::UNO:
2356 return cmpResult == APFloat::cmpUnordered;
2357 case arith::CmpFPredicate::AlwaysTrue:
2360 llvm_unreachable(
"unknown cmpf predicate kind");
2364 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2365 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2368 if (
lhs &&
lhs.getValue().isNaN())
2370 if (
rhs &&
rhs.getValue().isNaN())
2386 using namespace arith;
2388 case CmpFPredicate::UEQ:
2389 case CmpFPredicate::OEQ:
2390 return CmpIPredicate::eq;
2391 case CmpFPredicate::UGT:
2392 case CmpFPredicate::OGT:
2393 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2394 case CmpFPredicate::UGE:
2395 case CmpFPredicate::OGE:
2396 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2397 case CmpFPredicate::ULT:
2398 case CmpFPredicate::OLT:
2399 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2400 case CmpFPredicate::ULE:
2401 case CmpFPredicate::OLE:
2402 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2403 case CmpFPredicate::UNE:
2404 case CmpFPredicate::ONE:
2405 return CmpIPredicate::ne;
2407 llvm_unreachable(
"Unexpected predicate!");
2417 const APFloat &
rhs = flt.getValue();
2425 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2426 int mantissaWidth = floatTy.getFPMantissaWidth();
2427 if (mantissaWidth <= 0)
2433 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2435 intVal = si.getIn();
2436 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2438 intVal = ui.getIn();
2445 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2446 auto intWidth = intTy.getWidth();
2449 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2454 if ((
int)intWidth > mantissaWidth) {
2456 int exponent = ilogb(
rhs);
2457 if (exponent == APFloat::IEK_Inf) {
2458 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2459 if (maxExponent < (
int)valueBits) {
2466 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2475 switch (op.getPredicate()) {
2476 case CmpFPredicate::ORD:
2481 case CmpFPredicate::UNO:
2494 APFloat signedMax(
rhs.getSemantics());
2495 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2496 APFloat::rmNearestTiesToEven);
2497 if (signedMax <
rhs) {
2498 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2499 pred == CmpIPredicate::sle)
2510 APFloat unsignedMax(
rhs.getSemantics());
2511 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2512 APFloat::rmNearestTiesToEven);
2513 if (unsignedMax <
rhs) {
2514 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2515 pred == CmpIPredicate::ule)
2527 APFloat signedMin(
rhs.getSemantics());
2528 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2529 APFloat::rmNearestTiesToEven);
2530 if (signedMin >
rhs) {
2531 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2532 pred == CmpIPredicate::sge)
2542 APFloat unsignedMin(
rhs.getSemantics());
2543 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2544 APFloat::rmNearestTiesToEven);
2545 if (unsignedMin >
rhs) {
2546 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2547 pred == CmpIPredicate::uge)
2562 APSInt rhsInt(intWidth, isUnsigned);
2563 if (APFloat::opInvalidOp ==
2564 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2570 if (!
rhs.isZero()) {
2571 APFloat apf(floatTy.getFloatSemantics(),
2572 APInt::getZero(floatTy.getWidth()));
2573 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2575 bool equal = apf ==
rhs;
2581 case CmpIPredicate::ne:
2585 case CmpIPredicate::eq:
2589 case CmpIPredicate::ule:
2592 if (
rhs.isNegative()) {
2598 case CmpIPredicate::sle:
2601 if (
rhs.isNegative())
2602 pred = CmpIPredicate::slt;
2604 case CmpIPredicate::ult:
2607 if (
rhs.isNegative()) {
2612 pred = CmpIPredicate::ule;
2614 case CmpIPredicate::slt:
2617 if (!
rhs.isNegative())
2618 pred = CmpIPredicate::sle;
2620 case CmpIPredicate::ugt:
2623 if (
rhs.isNegative()) {
2629 case CmpIPredicate::sgt:
2632 if (
rhs.isNegative())
2633 pred = CmpIPredicate::sge;
2635 case CmpIPredicate::uge:
2638 if (
rhs.isNegative()) {
2643 pred = CmpIPredicate::ugt;
2645 case CmpIPredicate::sge:
2648 if (!
rhs.isNegative())
2649 pred = CmpIPredicate::sgt;
2659 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2665void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2666 MLIRContext *context) {
2667 patterns.
insert<CmpFIntToFPConst>(context);
2681 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2697 arith::XOrIOp::create(
2698 rewriter, op.getLoc(), op.getCondition(),
2700 op.getCondition().
getType(), 1)));
2708void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2709 MLIRContext *context) {
2710 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2711 SelectI1ToNot, SelectToExtUI>(context);
2714OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2715 Value trueVal = getTrueValue();
2716 Value falseVal = getFalseValue();
2717 if (trueVal == falseVal)
2720 Value condition = getCondition();
2738 if (
getType().isSignlessInteger(1) &&
2744 auto pred = cmp.getPredicate();
2745 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2746 auto cmpLhs = cmp.getLhs();
2747 auto cmpRhs = cmp.getRhs();
2755 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2756 (cmpRhs == trueVal && cmpLhs == falseVal))
2757 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2764 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2766 assert(cond.getType().hasStaticShape() &&
2767 "DenseElementsAttr must have static shape");
2769 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2771 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2772 SmallVector<Attribute> results;
2773 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2774 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2775 cond.value_end<BoolAttr>());
2776 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2777 lhs.value_end<Attribute>());
2778 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2779 rhs.value_end<Attribute>());
2781 for (
auto [condVal, lhsVal, rhsVal] :
2782 llvm::zip_equal(condVals, lhsVals, rhsVals))
2783 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2793ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2794 Type conditionType, resultType;
2795 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2803 conditionType = resultType;
2810 result.addTypes(resultType);
2812 {conditionType, resultType, resultType},
2816void arith::SelectOp::print(OpAsmPrinter &p) {
2817 p <<
" " << getOperands();
2820 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2821 p << condType <<
", ";
2825LogicalResult arith::SelectOp::verify() {
2826 Type conditionType = getCondition().getType();
2833 if (!llvm::isa<TensorType, VectorType>(resultType))
2834 return emitOpError() <<
"expected condition to be a signless i1, but got "
2837 if (conditionType != shapedConditionType) {
2838 return emitOpError() <<
"expected condition type to have the same shape "
2839 "as the result type, expected "
2840 << shapedConditionType <<
", but got "
2849OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2854 bool bounded =
false;
2856 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2857 bounded = b.ult(b.getBitWidth());
2860 return bounded ?
result : Attribute();
2867OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2872 bool bounded =
false;
2874 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2875 bounded = b.ult(b.getBitWidth());
2878 return bounded ?
result : Attribute();
2885OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2890 bool bounded =
false;
2892 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2893 bounded = b.ult(b.getBitWidth());
2896 return bounded ?
result : Attribute();
2906 bool useOnlyFiniteValue) {
2908 case AtomicRMWKind::maximumf: {
2909 const llvm::fltSemantics &semantic =
2910 llvm::cast<FloatType>(resultType).getFloatSemantics();
2911 APFloat identity = useOnlyFiniteValue
2912 ? APFloat::getLargest(semantic,
true)
2913 : APFloat::getInf(semantic,
true);
2916 case AtomicRMWKind::maxnumf: {
2917 const llvm::fltSemantics &semantic =
2918 llvm::cast<FloatType>(resultType).getFloatSemantics();
2919 APFloat identity = APFloat::getNaN(semantic,
true);
2922 case AtomicRMWKind::addf:
2923 case AtomicRMWKind::addi:
2924 case AtomicRMWKind::maxu:
2925 case AtomicRMWKind::ori:
2926 case AtomicRMWKind::xori:
2928 case AtomicRMWKind::andi:
2931 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2932 case AtomicRMWKind::maxs:
2934 resultType, APInt::getSignedMinValue(
2935 llvm::cast<IntegerType>(resultType).getWidth()));
2936 case AtomicRMWKind::minimumf: {
2937 const llvm::fltSemantics &semantic =
2938 llvm::cast<FloatType>(resultType).getFloatSemantics();
2939 APFloat identity = useOnlyFiniteValue
2940 ? APFloat::getLargest(semantic,
false)
2941 : APFloat::getInf(semantic,
false);
2945 case AtomicRMWKind::minnumf: {
2946 const llvm::fltSemantics &semantic =
2947 llvm::cast<FloatType>(resultType).getFloatSemantics();
2948 APFloat identity = APFloat::getNaN(semantic,
false);
2951 case AtomicRMWKind::mins:
2953 resultType, APInt::getSignedMaxValue(
2954 llvm::cast<IntegerType>(resultType).getWidth()));
2955 case AtomicRMWKind::minu:
2958 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2959 case AtomicRMWKind::muli:
2961 case AtomicRMWKind::mulf:
2973 std::optional<AtomicRMWKind> maybeKind =
2976 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2977 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2978 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2979 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2980 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2981 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2983 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2984 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2985 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2986 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2987 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2988 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2989 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2990 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2991 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2992 .Default(std::nullopt);
2994 return std::nullopt;
2997 bool useOnlyFiniteValue =
false;
2998 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2999 if (fmfOpInterface) {
3000 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
3001 useOnlyFiniteValue =
3002 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
3010 useOnlyFiniteValue);
3016 bool useOnlyFiniteValue) {
3018 useOnlyFiniteValue))
3019 return arith::ConstantOp::create(builder, loc, attr);
3028 case AtomicRMWKind::addf:
3029 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
3030 case AtomicRMWKind::addi:
3031 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
3032 case AtomicRMWKind::mulf:
3033 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
3034 case AtomicRMWKind::muli:
3035 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
3036 case AtomicRMWKind::maximumf:
3037 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
3038 case AtomicRMWKind::minimumf:
3039 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
3040 case AtomicRMWKind::maxnumf:
3041 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
3042 case AtomicRMWKind::minnumf:
3043 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
3044 case AtomicRMWKind::maxs:
3045 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
3046 case AtomicRMWKind::mins:
3047 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
3048 case AtomicRMWKind::maxu:
3049 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
3050 case AtomicRMWKind::minu:
3051 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
3052 case AtomicRMWKind::ori:
3053 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
3054 case AtomicRMWKind::andi:
3055 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
3056 case AtomicRMWKind::xori:
3057 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
3070#define GET_OP_CLASSES
3071#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3077#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.