26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
39 llvm::RoundingMode::NearestTiesToEven;
48 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.
getType(), value);
86static IntegerOverflowFlagsAttr
88 IntegerOverflowFlagsAttr val2) {
89 return IntegerOverflowFlagsAttr::get(val1.getContext(),
90 val1.getValue() & val2.getValue());
96 case arith::CmpIPredicate::eq:
97 return arith::CmpIPredicate::ne;
98 case arith::CmpIPredicate::ne:
99 return arith::CmpIPredicate::eq;
100 case arith::CmpIPredicate::slt:
101 return arith::CmpIPredicate::sge;
102 case arith::CmpIPredicate::sle:
103 return arith::CmpIPredicate::sgt;
104 case arith::CmpIPredicate::sgt:
105 return arith::CmpIPredicate::sle;
106 case arith::CmpIPredicate::sge:
107 return arith::CmpIPredicate::slt;
108 case arith::CmpIPredicate::ult:
109 return arith::CmpIPredicate::uge;
110 case arith::CmpIPredicate::ule:
111 return arith::CmpIPredicate::ugt;
112 case arith::CmpIPredicate::ugt:
113 return arith::CmpIPredicate::ule;
114 case arith::CmpIPredicate::uge:
115 return arith::CmpIPredicate::ult;
117 llvm_unreachable(
"unknown cmpi predicate kind");
126static llvm::RoundingMode
130 switch (*roundingMode) {
131 case RoundingMode::downward:
132 return llvm::RoundingMode::TowardNegative;
133 case RoundingMode::to_nearest_away:
134 return llvm::RoundingMode::NearestTiesToAway;
135 case RoundingMode::to_nearest_even:
136 return llvm::RoundingMode::NearestTiesToEven;
137 case RoundingMode::toward_zero:
138 return llvm::RoundingMode::TowardZero;
139 case RoundingMode::upward:
140 return llvm::RoundingMode::TowardPositive;
142 llvm_unreachable(
"Unhandled rounding mode");
146 return arith::CmpIPredicateAttr::get(pred.getContext(),
172 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
176 if (!shapedType.hasStaticShape())
186#include "ArithCanonicalization.inc"
195 auto i1Type = IntegerType::get(type.
getContext(), 1);
196 if (
auto shapedType = dyn_cast<ShapedType>(type))
197 return shapedType.cloneWith(std::nullopt, i1Type);
198 if (llvm::isa<UnrankedTensorType>(type))
199 return UnrankedTensorType::get(i1Type);
207void arith::ConstantOp::getAsmResultNames(
210 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
211 auto intType = dyn_cast<IntegerType>(type);
214 if (intType && intType.getWidth() == 1)
215 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
218 SmallString<32> specialNameBuffer;
219 llvm::raw_svector_ostream specialName(specialNameBuffer);
220 specialName <<
'c' << intCst.getValue();
222 specialName <<
'_' << type;
223 setNameFn(getResult(), specialName.str());
225 setNameFn(getResult(),
"cst");
231LogicalResult arith::ConstantOp::verify() {
234 if (llvm::isa<IntegerType>(type) &&
235 !llvm::cast<IntegerType>(type).isSignless())
236 return emitOpError(
"integer return type must be signless");
238 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
240 "value must be an integer, float, or elements attribute");
246 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
248 "initializing scalable vectors with elements attribute is not supported"
249 " unless it's a vector splat");
253bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
255 auto typedAttr = dyn_cast<TypedAttr>(value);
256 if (!typedAttr || typedAttr.getType() != type)
260 if (!intType.isSignless())
264 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
267ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
268 Type type, Location loc) {
269 if (isBuildableWith(value, type))
270 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
274OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
279 arith::ConstantOp::build(builder,
result, type,
289 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
290 assert(
result &&
"builder didn't return the right type");
302 arith::ConstantOp::build(builder,
result, type,
311 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
312 assert(
result &&
"builder didn't return the right type");
323 arith::ConstantOp::build(builder,
result, type,
329 const APInt &
value) {
332 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
333 assert(
result &&
"builder didn't return the right type");
339 const APInt &
value) {
344 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
345 return constOp.getType().isSignlessInteger();
350 FloatType type,
const APFloat &
value) {
351 arith::ConstantOp::build(builder,
result, type,
358 const APFloat &
value) {
361 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
362 assert(
result &&
"builder didn't return the right type");
368 const APFloat &
value) {
373 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
374 return llvm::isa<FloatType>(constOp.getType());
389 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
390 assert(
result &&
"builder didn't return the right type");
400 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
401 return constOp.getType().isIndex();
409 "type doesn't have a zero representation");
411 assert(zeroAttr &&
"unsupported type for zero attribute");
412 return arith::ConstantOp::create(builder, loc, zeroAttr);
425 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
426 if (getRhs() == sub.getRhs())
430 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
431 if (getLhs() == sub.getRhs())
435 adaptor.getOperands(),
436 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
441 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
442 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
449std::optional<SmallVector<int64_t, 4>>
450arith::AddUIExtendedOp::getShapeForUnroll() {
451 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
452 return llvm::to_vector<4>(vt.getShape());
459 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
463arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
464 SmallVectorImpl<OpFoldResult> &results) {
465 Type overflowTy = getOverflow().getType();
471 results.push_back(getLhs());
472 results.push_back(falseValue);
481 adaptor.getOperands(),
482 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
485 results.push_back(sumAttr);
486 results.push_back(sumAttr);
490 ArrayRef({sumAttr, adaptor.getLhs()}),
496 results.push_back(sumAttr);
497 results.push_back(overflowAttr);
504void arith::AddUIExtendedOp::getCanonicalizationPatterns(
505 RewritePatternSet &patterns, MLIRContext *context) {
506 patterns.
add<AddUIExtendedToAddI>(context);
513std::optional<SmallVector<int64_t, 4>>
514arith::SubUIExtendedOp::getShapeForUnroll() {
515 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
516 return llvm::to_vector<4>(vt.getShape());
523 return lhs.ult(
rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
527arith::SubUIExtendedOp::fold(FoldAdaptor adaptor,
528 SmallVectorImpl<OpFoldResult> &results) {
529 Type borrowTy = getBorrow().getType();
535 results.push_back(getLhs());
536 results.push_back(falseValue);
541 if (getLhs() == getRhs()) {
548 results.push_back(zeroDiff);
549 results.push_back(falseValue);
555 adaptor.getOperands(),
556 [](APInt a,
const APInt &
b) { return std::move(a) - b; })) {
559 results.push_back(diffAttr);
560 results.push_back(diffAttr);
564 adaptor.getOperands(),
570 results.push_back(diffAttr);
571 results.push_back(borrowAttr);
578void arith::SubUIExtendedOp::getCanonicalizationPatterns(
579 RewritePatternSet &patterns, MLIRContext *context) {
580 patterns.
add<SubUIExtendedToSubI>(context);
587OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
589 if (getOperand(0) == getOperand(1)) {
590 auto shapedType = dyn_cast<ShapedType>(
getType());
592 if (!shapedType || shapedType.hasStaticShape())
599 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
601 if (getRhs() ==
add.getRhs())
604 if (getRhs() ==
add.getLhs())
609 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
610 if (getLhs() == sub.getLhs())
614 adaptor.getOperands(),
615 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
618void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
619 MLIRContext *context) {
620 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
621 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
622 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
629OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
640 adaptor.getOperands(),
641 [](
const APInt &a,
const APInt &
b) { return a * b; });
644void arith::MulIOp::getAsmResultNames(
646 if (!isa<IndexType>(
getType()))
651 auto isVscale = [](Operation *op) {
652 return op && op->getName().getStringRef() ==
"vector.vscale";
655 IntegerAttr baseValue;
656 auto isVscaleExpr = [&](Value a, Value
b) {
658 isVscale(
b.getDefiningOp());
661 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
665 SmallString<32> specialNameBuffer;
666 llvm::raw_svector_ostream specialName(specialNameBuffer);
667 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
668 setNameFn(getResult(), specialName.str());
671void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
672 MLIRContext *context) {
673 patterns.
add<MulIMulIConstant>(context);
680std::optional<SmallVector<int64_t, 4>>
681arith::MulSIExtendedOp::getShapeForUnroll() {
682 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
683 return llvm::to_vector<4>(vt.getShape());
688arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
689 SmallVectorImpl<OpFoldResult> &results) {
692 Attribute zero = adaptor.getRhs();
693 results.push_back(zero);
694 results.push_back(zero);
700 adaptor.getOperands(),
701 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
704 llvm::APIntOps::mulhs);
705 assert(highAttr &&
"Unexpected constant-folding failure");
707 results.push_back(lowAttr);
708 results.push_back(highAttr);
715void arith::MulSIExtendedOp::getCanonicalizationPatterns(
716 RewritePatternSet &patterns, MLIRContext *context) {
717 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
724std::optional<SmallVector<int64_t, 4>>
725arith::MulUIExtendedOp::getShapeForUnroll() {
726 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
727 return llvm::to_vector<4>(vt.getShape());
732arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
733 SmallVectorImpl<OpFoldResult> &results) {
736 Attribute zero = adaptor.getRhs();
737 results.push_back(zero);
738 results.push_back(zero);
746 results.push_back(getLhs());
747 results.push_back(zero);
753 adaptor.getOperands(),
754 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
757 llvm::APIntOps::mulhu);
758 assert(highAttr &&
"Unexpected constant-folding failure");
760 results.push_back(lowAttr);
761 results.push_back(highAttr);
768void arith::MulUIExtendedOp::getCanonicalizationPatterns(
769 RewritePatternSet &patterns, MLIRContext *context) {
770 patterns.
add<MulUIExtendedToMulI>(context);
779 arith::IntegerOverflowFlags ovfFlags) {
780 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
781 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
793OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
799 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
805 [&](APInt a,
const APInt &
b) {
813 return div0 ? Attribute() :
result;
833OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
839 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
843 bool overflowOrDiv0 =
false;
845 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
846 if (overflowOrDiv0 || !b) {
847 overflowOrDiv0 = true;
850 return a.sdiv_ov(
b, overflowOrDiv0);
853 return overflowOrDiv0 ? Attribute() :
result;
880 APInt one(a.getBitWidth(), 1,
true);
881 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
882 return val.sadd_ov(one, overflow);
889OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
894 bool overflowOrDiv0 =
false;
896 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
897 if (overflowOrDiv0 || !b) {
898 overflowOrDiv0 = true;
901 APInt quotient = a.udiv(
b);
904 APInt one(a.getBitWidth(), 1,
true);
905 return quotient.uadd_ov(one, overflowOrDiv0);
908 return overflowOrDiv0 ? Attribute() :
result;
919OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
927 bool overflowOrDiv0 =
false;
929 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
930 if (overflowOrDiv0 || !b) {
931 overflowOrDiv0 = true;
937 unsigned bits = a.getBitWidth();
938 APInt zero = APInt::getZero(bits);
939 bool aGtZero = a.sgt(zero);
940 bool bGtZero =
b.sgt(zero);
941 if (aGtZero && bGtZero) {
948 bool overflowNegA =
false;
949 bool overflowNegB =
false;
950 bool overflowDiv =
false;
951 bool overflowNegRes =
false;
952 if (!aGtZero && !bGtZero) {
954 APInt posA = zero.ssub_ov(a, overflowNegA);
955 APInt posB = zero.ssub_ov(
b, overflowNegB);
957 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
960 if (!aGtZero && bGtZero) {
962 APInt posA = zero.ssub_ov(a, overflowNegA);
963 APInt
div = posA.sdiv_ov(
b, overflowDiv);
964 APInt res = zero.ssub_ov(
div, overflowNegRes);
965 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
969 APInt posB = zero.ssub_ov(
b, overflowNegB);
970 APInt
div = a.sdiv_ov(posB, overflowDiv);
971 APInt res = zero.ssub_ov(
div, overflowNegRes);
973 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
977 return overflowOrDiv0 ? Attribute() :
result;
988OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
994 bool overflowOrDiv =
false;
996 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
998 overflowOrDiv = true;
1001 return a.sfloordiv_ov(
b, overflowOrDiv);
1004 return overflowOrDiv ? Attribute() :
result;
1011OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
1019 [&](APInt a,
const APInt &
b) {
1020 if (div0 || b.isZero()) {
1027 return div0 ? Attribute() :
result;
1038OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
1046 [&](APInt a,
const APInt &
b) {
1047 if (div0 || b.isZero()) {
1054 return div0 ? Attribute() :
result;
1072 for (
bool reversePrev : {
false,
true}) {
1073 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
1074 .getDefiningOp<arith::AndIOp>();
1078 Value other = (reversePrev ? op.getLhs() : op.getRhs());
1079 if (other != prev.getLhs() && other != prev.getRhs())
1082 return prev.getResult();
1087OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1094 intValue.isAllOnes())
1099 intValue.isAllOnes())
1104 intValue.isAllOnes())
1112 adaptor.getOperands(),
1113 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1120OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1123 if (rhsVal.isZero())
1126 if (rhsVal.isAllOnes())
1127 return adaptor.getRhs();
1134 intValue.isAllOnes())
1135 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1139 intValue.isAllOnes())
1140 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1143 adaptor.getOperands(),
1144 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
1151OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1156 if (getLhs() == getRhs())
1160 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1161 if (prev.getRhs() == getRhs())
1162 return prev.getLhs();
1163 if (prev.getLhs() == getRhs())
1164 return prev.getRhs();
1168 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1169 if (prev.getRhs() == getLhs())
1170 return prev.getLhs();
1171 if (prev.getLhs() == getLhs())
1172 return prev.getRhs();
1176 adaptor.getOperands(),
1177 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
1180void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1181 MLIRContext *context) {
1182 patterns.
add<XOrIXOrIConstant, XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(
1190OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1192 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1193 return op.getOperand();
1195 [](
const APFloat &a) { return -a; });
1202OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1208 if (
auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1209 return op.getResult();
1213 adaptor.getOperands(), [](
const APFloat &a) {
1215 return APFloat::getZero(a.getSemantics(), a.isNegative());
1224OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1229 auto rm = getRoundingmode();
1231 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1233 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1242OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1247 auto rm = getRoundingmode();
1249 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1251 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1256void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1257 MLIRContext *context) {
1258 patterns.
add<SubFOfNegZero>(context);
1265OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1267 if (getLhs() == getRhs())
1281OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1283 if (getLhs() == getRhs())
1297OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1299 if (getLhs() == getRhs())
1305 if (intValue.isMaxSignedValue())
1308 if (intValue.isMinSignedValue())
1313 llvm::APIntOps::smax);
1320OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1322 if (getLhs() == getRhs())
1328 if (intValue.isMaxValue())
1331 if (intValue.isMinValue())
1336 llvm::APIntOps::umax);
1343OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1345 if (getLhs() == getRhs())
1359OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1361 if (getLhs() == getRhs())
1375OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1377 if (getLhs() == getRhs())
1383 if (intValue.isMinSignedValue())
1386 if (intValue.isMaxSignedValue())
1391 llvm::APIntOps::smin);
1398OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1400 if (getLhs() == getRhs())
1406 if (intValue.isMinValue())
1409 if (intValue.isMaxValue())
1414 llvm::APIntOps::umin);
1421OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1426 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1427 arith::FastMathFlags::nsz)) {
1433 auto rm = getRoundingmode();
1435 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1437 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1442void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1443 MLIRContext *context) {
1444 patterns.
add<MulFOfNegF>(context);
1451OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1456 auto rm = getRoundingmode();
1458 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1460 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1465void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1466 MLIRContext *context) {
1467 patterns.
add<DivFOfNegF>(context);
1474OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1476 [](
const APFloat &a,
const APFloat &
b) {
1481 (void)result.mod(b);
1490template <
typename... Types>
1496template <
typename... ShapedTypes,
typename... ElementTypes>
1499 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1503 if (!llvm::isa<ElementTypes...>(underlyingType))
1506 return underlyingType;
1510template <
typename... ElementTypes>
1517template <
typename... ElementTypes>
1526 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1527 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1528 if (!rankedTensorA || !rankedTensorB)
1530 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1534 if (inputs.size() != 1 || outputs.size() != 1)
1546template <
typename ValType,
typename Op>
1551 if (llvm::cast<ValType>(srcType).getWidth() >=
1552 llvm::cast<ValType>(dstType).getWidth())
1554 << dstType <<
" must be wider than operand type " << srcType;
1560template <
typename ValType,
typename Op>
1565 if (llvm::cast<ValType>(srcType).getWidth() <=
1566 llvm::cast<ValType>(dstType).getWidth())
1568 << dstType <<
" must be shorter than operand type " << srcType;
1574template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1579 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1580 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1581 if (!srcType || !dstType)
1584 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1585 srcType.getIntOrFloatBitWidth());
1590static FailureOr<APFloat>
1592 const llvm::fltSemantics &targetSemantics,
1596 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1597 if (sourceValue.isInfinity() &&
1598 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1599 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1601 if (sourceValue.isNaN() &&
1602 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1605 bool losesInfo =
false;
1606 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1607 if (losesInfo || status != APFloat::opOK)
1617OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1618 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1619 getInMutable().assign(
lhs.getIn());
1624 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1626 adaptor.getOperands(),
getType(),
1627 [bitWidth](
const APInt &a,
bool &castStatus) {
1628 return a.zext(bitWidth);
1636LogicalResult arith::ExtUIOp::verify() {
1644OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1645 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1646 getInMutable().assign(
lhs.getIn());
1651 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1653 adaptor.getOperands(),
getType(),
1654 [bitWidth](
const APInt &a,
bool &castStatus) {
1655 return a.sext(bitWidth);
1663void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1664 MLIRContext *context) {
1665 patterns.
add<ExtSIOfExtUI>(context);
1668LogicalResult arith::ExtSIOp::verify() {
1678OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1679 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1680 if (truncFOp.getOperand().getType() ==
getType()) {
1681 arith::FastMathFlags truncFMF =
1682 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1683 bool isTruncContract =
1684 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1685 arith::FastMathFlags extFMF =
1686 getFastmath().value_or(arith::FastMathFlags::none);
1687 bool isExtContract =
1688 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1689 if (isTruncContract && isExtContract) {
1690 return truncFOp.getOperand();
1696 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1698 adaptor.getOperands(),
getType(),
1699 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1719bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1724LogicalResult arith::ScalingExtFOp::verify() {
1732OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1735 Value src = getOperand().getDefiningOp()->getOperand(0);
1740 if (llvm::cast<IntegerType>(srcType).getWidth() >
1741 llvm::cast<IntegerType>(dstType).getWidth()) {
1748 if (srcType == dstType)
1754 setOperand(getOperand().getDefiningOp()->getOperand(0));
1759 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1761 adaptor.getOperands(),
getType(),
1762 [bitWidth](
const APInt &a,
bool &castStatus) {
1763 return a.trunc(bitWidth);
1771void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1772 MLIRContext *context) {
1774 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1778LogicalResult arith::TruncIOp::verify() {
1788OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1790 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1791 Value src = extOp.getIn();
1793 auto intermediateType =
1796 if (llvm::APFloatBase::isRepresentableBy(
1797 srcType.getFloatSemantics(),
1798 intermediateType.getFloatSemantics())) {
1800 if (srcType.getWidth() > resElemType.getWidth()) {
1806 if (srcType == resElemType)
1811 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1813 adaptor.getOperands(),
getType(),
1814 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1815 llvm::RoundingMode llvmRoundingMode =
1817 FailureOr<APFloat>
result =
1827void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1828 MLIRContext *context) {
1829 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1836LogicalResult arith::TruncFOp::verify() {
1844OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1846 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1848 adaptor.getOperands(),
getType(),
1849 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1850 llvm::RoundingMode llvmRoundingMode =
1852 FailureOr<APFloat>
result =
1867 if (!srcType || !dstType)
1869 return srcType != dstType &&
1873LogicalResult arith::ConvertFOp::verify() {
1876 if (srcType == dstType)
1877 return emitError(
"result element type ")
1878 << dstType <<
" must be different from operand element type "
1880 if (srcType.getWidth() != dstType.getWidth())
1881 return emitError(
"result element type ")
1882 << dstType <<
" must have the same bitwidth as operand element type "
1891bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1896LogicalResult arith::ScalingTruncFOp::verify() {
1904void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1905 MLIRContext *context) {
1906 patterns.
add<AndIAndIConstant, AndOfExtUI, AndOfExtSI>(context);
1913void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1914 MLIRContext *context) {
1915 patterns.
add<OrIOrIConstant, OrOfExtUI, OrOfExtSI>(context);
1922template <
typename From,
typename To>
1930 return srcType && dstType;
1941OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1944 adaptor.getOperands(),
getType(),
1945 [&resEleType](
const APInt &a,
bool &castStatus) {
1946 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1947 APFloat apf(floatTy.getFloatSemantics(),
1948 APInt::getZero(floatTy.getWidth()));
1949 apf.convertFromAPInt(a,
false,
1950 APFloat::rmNearestTiesToEven);
1955void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1956 MLIRContext *context) {
1957 patterns.
add<UIToFPOfExtUI>(context);
1968OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1971 adaptor.getOperands(),
getType(),
1972 [&resEleType](
const APInt &a,
bool &castStatus) {
1973 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1974 APFloat apf(floatTy.getFloatSemantics(),
1975 APInt::getZero(floatTy.getWidth()));
1976 apf.convertFromAPInt(a,
true,
1977 APFloat::rmNearestTiesToEven);
1982void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1983 MLIRContext *context) {
1984 patterns.
add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1995OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1997 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1999 adaptor.getOperands(),
getType(),
2000 [&bitWidth](
const APFloat &a,
bool &castStatus) {
2002 APSInt api(bitWidth,
true);
2003 castStatus = APFloat::opInvalidOp !=
2004 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2017OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
2019 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
2021 adaptor.getOperands(),
getType(),
2022 [&bitWidth](
const APFloat &a,
bool &castStatus) {
2024 APSInt api(bitWidth,
false);
2025 castStatus = APFloat::opInvalidOp !=
2026 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2040 return intTy.getWidth();
2041 return IndexType::kInternalStorageBitWidth;
2050 if (!srcType || !dstType)
2054 (srcType.isSignlessInteger() && dstType.
isIndex());
2057bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
2062OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
2064 unsigned resultBitwidth = 64;
2066 resultBitwidth = intTy.getWidth();
2069 adaptor.getOperands(),
getType(),
2070 [resultBitwidth](
const APInt &a,
bool & ) {
2071 return a.sextOrTrunc(resultBitwidth);
2078 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
2079 Value x = inner.getOperand();
2088void arith::IndexCastOp::getCanonicalizationPatterns(
2089 RewritePatternSet &patterns, MLIRContext *context) {
2090 patterns.
add<IndexCastOfExtSI>(context);
2097bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
2102OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2104 unsigned resultBitwidth = 64;
2106 resultBitwidth = intTy.getWidth();
2109 adaptor.getOperands(),
getType(),
2110 [resultBitwidth](
const APInt &a,
bool & ) {
2111 return a.zextOrTrunc(resultBitwidth);
2118 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2119 Value x = inner.getOperand();
2128void arith::IndexCastUIOp::getCanonicalizationPatterns(
2129 RewritePatternSet &patterns, MLIRContext *context) {
2130 patterns.
add<IndexCastUIOfExtUI>(context);
2143 if (!srcType || !dstType)
2149OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2151 auto operand = adaptor.getIn();
2156 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2157 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
2159 if (llvm::isa<ShapedType>(resType))
2167 APInt bits = llvm::isa<FloatAttr>(operand)
2168 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2169 : llvm::cast<IntegerAttr>(operand).getValue();
2171 "trying to fold on broken IR: operands have incompatible types");
2173 if (
auto resFloatType = dyn_cast<FloatType>(resType))
2174 return FloatAttr::get(resType,
2175 APFloat(resFloatType.getFloatSemantics(), bits));
2176 return IntegerAttr::get(resType, bits);
2179void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2180 MLIRContext *context) {
2181 patterns.
add<BitcastOfBitcast>(context);
2191 const APInt &
lhs,
const APInt &
rhs) {
2192 switch (predicate) {
2193 case arith::CmpIPredicate::eq:
2195 case arith::CmpIPredicate::ne:
2197 case arith::CmpIPredicate::slt:
2199 case arith::CmpIPredicate::sle:
2201 case arith::CmpIPredicate::sgt:
2203 case arith::CmpIPredicate::sge:
2205 case arith::CmpIPredicate::ult:
2207 case arith::CmpIPredicate::ule:
2209 case arith::CmpIPredicate::ugt:
2211 case arith::CmpIPredicate::uge:
2214 llvm_unreachable(
"unknown cmpi predicate kind");
2219 switch (predicate) {
2220 case arith::CmpIPredicate::eq:
2221 case arith::CmpIPredicate::sle:
2222 case arith::CmpIPredicate::sge:
2223 case arith::CmpIPredicate::ule:
2224 case arith::CmpIPredicate::uge:
2226 case arith::CmpIPredicate::ne:
2227 case arith::CmpIPredicate::slt:
2228 case arith::CmpIPredicate::sgt:
2229 case arith::CmpIPredicate::ult:
2230 case arith::CmpIPredicate::ugt:
2233 llvm_unreachable(
"unknown cmpi predicate kind");
2237 if (
auto intType = dyn_cast<IntegerType>(t)) {
2238 return intType.getWidth();
2240 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2241 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2243 return std::nullopt;
2246OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2248 if (getLhs() == getRhs()) {
2254 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2256 std::optional<int64_t> integerWidth =
2258 if (integerWidth && integerWidth.value() == 1 &&
2259 getPredicate() == arith::CmpIPredicate::ne)
2260 return extOp.getOperand();
2262 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2264 std::optional<int64_t> integerWidth =
2266 if (integerWidth && integerWidth.value() == 1 &&
2267 getPredicate() == arith::CmpIPredicate::ne)
2268 return extOp.getOperand();
2273 getPredicate() == arith::CmpIPredicate::ne)
2280 getPredicate() == arith::CmpIPredicate::eq)
2285 if (adaptor.getLhs() && !adaptor.getRhs()) {
2287 using Pred = CmpIPredicate;
2288 const std::pair<Pred, Pred> invPreds[] = {
2289 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2290 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2291 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2292 {Pred::ne, Pred::ne},
2294 Pred origPred = getPredicate();
2295 for (
auto pred : invPreds) {
2296 if (origPred == pred.first) {
2297 setPredicate(pred.second);
2298 Value
lhs = getLhs();
2299 Value
rhs = getRhs();
2300 getLhsMutable().assign(
rhs);
2301 getRhsMutable().assign(
lhs);
2305 llvm_unreachable(
"unknown cmpi predicate kind");
2310 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2313 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
2322void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2323 MLIRContext *context) {
2324 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2334 const APFloat &
lhs,
const APFloat &
rhs) {
2335 auto cmpResult =
lhs.compare(
rhs);
2336 switch (predicate) {
2337 case arith::CmpFPredicate::AlwaysFalse:
2339 case arith::CmpFPredicate::OEQ:
2340 return cmpResult == APFloat::cmpEqual;
2341 case arith::CmpFPredicate::OGT:
2342 return cmpResult == APFloat::cmpGreaterThan;
2343 case arith::CmpFPredicate::OGE:
2344 return cmpResult == APFloat::cmpGreaterThan ||
2345 cmpResult == APFloat::cmpEqual;
2346 case arith::CmpFPredicate::OLT:
2347 return cmpResult == APFloat::cmpLessThan;
2348 case arith::CmpFPredicate::OLE:
2349 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2350 case arith::CmpFPredicate::ONE:
2351 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2352 case arith::CmpFPredicate::ORD:
2353 return cmpResult != APFloat::cmpUnordered;
2354 case arith::CmpFPredicate::UEQ:
2355 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2356 case arith::CmpFPredicate::UGT:
2357 return cmpResult == APFloat::cmpUnordered ||
2358 cmpResult == APFloat::cmpGreaterThan;
2359 case arith::CmpFPredicate::UGE:
2360 return cmpResult == APFloat::cmpUnordered ||
2361 cmpResult == APFloat::cmpGreaterThan ||
2362 cmpResult == APFloat::cmpEqual;
2363 case arith::CmpFPredicate::ULT:
2364 return cmpResult == APFloat::cmpUnordered ||
2365 cmpResult == APFloat::cmpLessThan;
2366 case arith::CmpFPredicate::ULE:
2367 return cmpResult == APFloat::cmpUnordered ||
2368 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2369 case arith::CmpFPredicate::UNE:
2370 return cmpResult != APFloat::cmpEqual;
2371 case arith::CmpFPredicate::UNO:
2372 return cmpResult == APFloat::cmpUnordered;
2373 case arith::CmpFPredicate::AlwaysTrue:
2376 llvm_unreachable(
"unknown cmpf predicate kind");
2380 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2381 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2384 if (
lhs &&
lhs.getValue().isNaN())
2386 if (
rhs &&
rhs.getValue().isNaN())
2402 using namespace arith;
2404 case CmpFPredicate::UEQ:
2405 case CmpFPredicate::OEQ:
2406 return CmpIPredicate::eq;
2407 case CmpFPredicate::UGT:
2408 case CmpFPredicate::OGT:
2409 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2410 case CmpFPredicate::UGE:
2411 case CmpFPredicate::OGE:
2412 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2413 case CmpFPredicate::ULT:
2414 case CmpFPredicate::OLT:
2415 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2416 case CmpFPredicate::ULE:
2417 case CmpFPredicate::OLE:
2418 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2419 case CmpFPredicate::UNE:
2420 case CmpFPredicate::ONE:
2421 return CmpIPredicate::ne;
2423 llvm_unreachable(
"Unexpected predicate!");
2433 const APFloat &
rhs = flt.getValue();
2441 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2442 int mantissaWidth = floatTy.getFPMantissaWidth();
2443 if (mantissaWidth <= 0)
2449 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2451 intVal = si.getIn();
2452 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2454 intVal = ui.getIn();
2461 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2462 auto intWidth = intTy.getWidth();
2465 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2470 if ((
int)intWidth > mantissaWidth) {
2472 int exponent = ilogb(
rhs);
2473 if (exponent == APFloat::IEK_Inf) {
2474 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2475 if (maxExponent < (
int)valueBits) {
2482 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2491 switch (op.getPredicate()) {
2492 case CmpFPredicate::ORD:
2497 case CmpFPredicate::UNO:
2510 APFloat signedMax(
rhs.getSemantics());
2511 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2512 APFloat::rmNearestTiesToEven);
2513 if (signedMax <
rhs) {
2514 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2515 pred == CmpIPredicate::sle)
2526 APFloat unsignedMax(
rhs.getSemantics());
2527 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2528 APFloat::rmNearestTiesToEven);
2529 if (unsignedMax <
rhs) {
2530 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2531 pred == CmpIPredicate::ule)
2543 APFloat signedMin(
rhs.getSemantics());
2544 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2545 APFloat::rmNearestTiesToEven);
2546 if (signedMin >
rhs) {
2547 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2548 pred == CmpIPredicate::sge)
2558 APFloat unsignedMin(
rhs.getSemantics());
2559 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2560 APFloat::rmNearestTiesToEven);
2561 if (unsignedMin >
rhs) {
2562 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2563 pred == CmpIPredicate::uge)
2578 APSInt rhsInt(intWidth, isUnsigned);
2579 if (APFloat::opInvalidOp ==
2580 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2586 if (!
rhs.isZero()) {
2587 APFloat apf(floatTy.getFloatSemantics(),
2588 APInt::getZero(floatTy.getWidth()));
2589 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2591 bool equal = apf ==
rhs;
2597 case CmpIPredicate::ne:
2601 case CmpIPredicate::eq:
2605 case CmpIPredicate::ule:
2608 if (
rhs.isNegative()) {
2614 case CmpIPredicate::sle:
2617 if (
rhs.isNegative())
2618 pred = CmpIPredicate::slt;
2620 case CmpIPredicate::ult:
2623 if (
rhs.isNegative()) {
2628 pred = CmpIPredicate::ule;
2630 case CmpIPredicate::slt:
2633 if (!
rhs.isNegative())
2634 pred = CmpIPredicate::sle;
2636 case CmpIPredicate::ugt:
2639 if (
rhs.isNegative()) {
2645 case CmpIPredicate::sgt:
2648 if (
rhs.isNegative())
2649 pred = CmpIPredicate::sge;
2651 case CmpIPredicate::uge:
2654 if (
rhs.isNegative()) {
2659 pred = CmpIPredicate::ugt;
2661 case CmpIPredicate::sge:
2664 if (!
rhs.isNegative())
2665 pred = CmpIPredicate::sgt;
2675 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2681void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2682 MLIRContext *context) {
2683 patterns.
insert<CmpFIntToFPConst>(context);
2697 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2713 arith::XOrIOp::create(
2714 rewriter, op.getLoc(), op.getCondition(),
2716 op.getCondition().
getType(), 1)));
2724void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2725 MLIRContext *context) {
2726 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2727 SelectI1ToNot, SelectToExtUI>(context);
2730OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2731 Value trueVal = getTrueValue();
2732 Value falseVal = getFalseValue();
2733 if (trueVal == falseVal)
2736 Value condition = getCondition();
2754 if (
getType().isSignlessInteger(1) &&
2760 auto pred = cmp.getPredicate();
2761 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2762 auto cmpLhs = cmp.getLhs();
2763 auto cmpRhs = cmp.getRhs();
2771 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2772 (cmpRhs == trueVal && cmpLhs == falseVal))
2773 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2780 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2782 assert(cond.getType().hasStaticShape() &&
2783 "DenseElementsAttr must have static shape");
2785 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2787 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2788 SmallVector<Attribute> results;
2789 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2790 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2791 cond.value_end<BoolAttr>());
2792 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2793 lhs.value_end<Attribute>());
2794 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2795 rhs.value_end<Attribute>());
2797 for (
auto [condVal, lhsVal, rhsVal] :
2798 llvm::zip_equal(condVals, lhsVals, rhsVals))
2799 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2809ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2810 Type conditionType, resultType;
2811 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2819 conditionType = resultType;
2826 result.addTypes(resultType);
2828 {conditionType, resultType, resultType},
2832void arith::SelectOp::print(OpAsmPrinter &p) {
2833 p <<
" " << getOperands();
2836 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2837 p << condType <<
", ";
2841LogicalResult arith::SelectOp::verify() {
2842 Type conditionType = getCondition().getType();
2849 if (!llvm::isa<TensorType, VectorType>(resultType))
2850 return emitOpError() <<
"expected condition to be a signless i1, but got "
2853 if (conditionType != shapedConditionType) {
2854 return emitOpError() <<
"expected condition type to have the same shape "
2855 "as the result type, expected "
2856 << shapedConditionType <<
", but got "
2865OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2870 bool bounded =
false;
2872 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2873 bounded = b.ult(b.getBitWidth());
2876 return bounded ?
result : Attribute();
2883OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2888 bool bounded =
false;
2890 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2891 bounded = b.ult(b.getBitWidth());
2894 return bounded ?
result : Attribute();
2901OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2906 bool bounded =
false;
2908 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2909 bounded = b.ult(b.getBitWidth());
2912 return bounded ?
result : Attribute();
2922 bool useOnlyFiniteValue) {
2924 case AtomicRMWKind::maximumf: {
2925 const llvm::fltSemantics &semantic =
2926 llvm::cast<FloatType>(resultType).getFloatSemantics();
2927 APFloat identity = useOnlyFiniteValue
2928 ? APFloat::getLargest(semantic,
true)
2929 : APFloat::getInf(semantic,
true);
2932 case AtomicRMWKind::maxnumf: {
2933 const llvm::fltSemantics &semantic =
2934 llvm::cast<FloatType>(resultType).getFloatSemantics();
2935 APFloat identity = APFloat::getNaN(semantic,
true);
2938 case AtomicRMWKind::addf:
2939 case AtomicRMWKind::addi:
2940 case AtomicRMWKind::maxu:
2941 case AtomicRMWKind::ori:
2942 case AtomicRMWKind::xori:
2944 case AtomicRMWKind::andi:
2947 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2948 case AtomicRMWKind::maxs:
2950 resultType, APInt::getSignedMinValue(
2951 llvm::cast<IntegerType>(resultType).getWidth()));
2952 case AtomicRMWKind::minimumf: {
2953 const llvm::fltSemantics &semantic =
2954 llvm::cast<FloatType>(resultType).getFloatSemantics();
2955 APFloat identity = useOnlyFiniteValue
2956 ? APFloat::getLargest(semantic,
false)
2957 : APFloat::getInf(semantic,
false);
2961 case AtomicRMWKind::minnumf: {
2962 const llvm::fltSemantics &semantic =
2963 llvm::cast<FloatType>(resultType).getFloatSemantics();
2964 APFloat identity = APFloat::getNaN(semantic,
false);
2967 case AtomicRMWKind::mins:
2969 resultType, APInt::getSignedMaxValue(
2970 llvm::cast<IntegerType>(resultType).getWidth()));
2971 case AtomicRMWKind::minu:
2974 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2975 case AtomicRMWKind::muli:
2977 case AtomicRMWKind::mulf:
2989 std::optional<AtomicRMWKind> maybeKind =
2992 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2993 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2994 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2995 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2996 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2997 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2999 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
3000 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
3001 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
3002 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
3003 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
3004 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
3005 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
3006 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
3007 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
3008 .Default(std::nullopt);
3010 return std::nullopt;
3013 bool useOnlyFiniteValue =
false;
3014 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
3015 if (fmfOpInterface) {
3016 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
3017 useOnlyFiniteValue =
3018 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
3026 useOnlyFiniteValue);
3032 bool useOnlyFiniteValue) {
3034 useOnlyFiniteValue))
3035 return arith::ConstantOp::create(builder, loc, attr);
3044 case AtomicRMWKind::addf:
3045 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
3046 case AtomicRMWKind::addi:
3047 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
3048 case AtomicRMWKind::mulf:
3049 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
3050 case AtomicRMWKind::muli:
3051 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
3052 case AtomicRMWKind::maximumf:
3053 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
3054 case AtomicRMWKind::minimumf:
3055 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
3056 case AtomicRMWKind::maxnumf:
3057 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
3058 case AtomicRMWKind::minnumf:
3059 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
3060 case AtomicRMWKind::maxs:
3061 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
3062 case AtomicRMWKind::mins:
3063 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
3064 case AtomicRMWKind::maxu:
3065 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
3066 case AtomicRMWKind::minu:
3067 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
3068 case AtomicRMWKind::ori:
3069 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
3070 case AtomicRMWKind::andi:
3071 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
3072 case AtomicRMWKind::xori:
3073 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
3086#define GET_OP_CLASSES
3087#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3093#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr orIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerAttr andIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.