26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
45 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
68 static IntegerOverflowFlagsAttr
70 IntegerOverflowFlagsAttr val2) {
72 val1.getValue() & val2.getValue());
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
99 llvm_unreachable(
"unknown cmpi predicate kind");
108 static llvm::RoundingMode
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode::upward:
120 return llvm::RoundingMode::TowardPositive;
122 llvm_unreachable(
"Unhandled rounding mode");
152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
163 #include "ArithCanonicalization.inc"
173 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa<UnrankedTensorType>(type))
184 void arith::ConstantOp::getAsmResultNames(
187 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188 auto intType = llvm::dyn_cast<IntegerType>(type);
191 if (intType && intType.getWidth() == 1)
192 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
196 llvm::raw_svector_ostream specialName(specialNameBuffer);
197 specialName <<
'c' << intCst.getValue();
199 specialName <<
'_' << type;
200 setNameFn(getResult(), specialName.str());
202 setNameFn(getResult(),
"cst");
211 if (llvm::isa<IntegerType>(type) &&
212 !llvm::cast<IntegerType>(type).isSignless())
213 return emitOpError(
"integer return type must be signless");
215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
217 "value must be an integer, float, or elements attribute");
223 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
225 "intializing scalable vectors with elements attribute is not supported"
226 " unless it's a vector splat");
230 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
232 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233 if (!typedAttr || typedAttr.getType() != type)
236 if (llvm::isa<IntegerType>(type) &&
237 !llvm::cast<IntegerType>(type).isSignless())
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
245 if (isBuildableWith(value, type))
246 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
// Fold hook for arith.constant: unconditionally returns the op's own value
// attribute (getValue()). The adaptor's operands are not consulted.
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
253 int64_t value,
unsigned width) {
255 arith::ConstantOp::build(builder, result, type,
260 int64_t value,
Type type) {
262 "ConstantIntOp can only have signless integer type values");
263 arith::ConstantOp::build(builder, result, type,
268 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269 return constOp.getType().isSignlessInteger();
274 const APFloat &value, FloatType type) {
275 arith::ConstantOp::build(builder, result, type,
280 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281 return llvm::isa<FloatType>(constOp.getType());
287 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
292 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293 return constOp.getType().isIndex();
307 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
308 if (getRhs() == sub.getRhs())
312 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
313 if (getLhs() == sub.getRhs())
316 return constFoldBinaryOp<IntegerAttr>(
317 adaptor.getOperands(),
318 [](APInt a,
const APInt &b) { return std::move(a) + b; });
323 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
334 return llvm::to_vector<4>(vt.getShape());
341 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
347 Type overflowTy = getOverflow().getType();
353 results.push_back(getLhs());
354 results.push_back(falseValue);
362 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
363 adaptor.getOperands(),
364 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
365 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
366 ArrayRef({sumAttr, adaptor.getLhs()}),
372 results.push_back(sumAttr);
373 results.push_back(overflowAttr);
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
382 patterns.add<AddUIExtendedToAddI>(context);
391 if (getOperand(0) == getOperand(1)) {
392 auto shapedType = dyn_cast<ShapedType>(
getType());
394 if (!shapedType || shapedType.hasStaticShape())
401 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
403 if (getRhs() == add.getRhs())
406 if (getRhs() == add.getLhs())
410 return constFoldBinaryOp<IntegerAttr>(
411 adaptor.getOperands(),
412 [](APInt a,
const APInt &b) { return std::move(a) - b; });
417 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
436 return constFoldBinaryOp<IntegerAttr>(
437 adaptor.getOperands(),
438 [](
const APInt &a,
const APInt &b) { return a * b; });
441 void arith::MulIOp::getAsmResultNames(
443 if (!isa<IndexType>(
getType()))
449 return op && op->getName().getStringRef() ==
"vector.vscale";
452 IntegerAttr baseValue;
455 isVscale(b.getDefiningOp());
458 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
463 llvm::raw_svector_ostream specialName(specialNameBuffer);
464 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
465 setNameFn(getResult(), specialName.str());
470 patterns.add<MulIMulIConstant>(context);
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
480 return llvm::to_vector<4>(vt.getShape());
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
490 results.push_back(zero);
491 results.push_back(zero);
496 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
497 adaptor.getOperands(),
498 [](
const APInt &a,
const APInt &b) { return a * b; })) {
500 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
501 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
502 return llvm::APIntOps::mulhs(a, b);
504 assert(highAttr &&
"Unexpected constant-folding failure");
506 results.push_back(lowAttr);
507 results.push_back(highAttr);
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
526 return llvm::to_vector<4>(vt.getShape());
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
536 results.push_back(zero);
537 results.push_back(zero);
545 results.push_back(getLhs());
546 results.push_back(zero);
551 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
552 adaptor.getOperands(),
553 [](
const APInt &a,
const APInt &b) { return a * b; })) {
555 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
556 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
557 return llvm::APIntOps::mulhu(a, b);
559 assert(highAttr &&
"Unexpected constant-folding failure");
561 results.push_back(lowAttr);
562 results.push_back(highAttr);
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571 patterns.add<MulUIExtendedToMulI>(context);
580 arith::IntegerOverflowFlags ovfFlags) {
582 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
585 if (mul.getLhs() == rhs)
588 if (mul.getRhs() == rhs)
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
600 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
605 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
606 [&](APInt a,
const APInt &b) {
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
640 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
644 bool overflowOrDiv0 =
false;
645 auto result = constFoldBinaryOp<IntegerAttr>(
646 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
647 if (overflowOrDiv0 || !b) {
648 overflowOrDiv0 = true;
651 return a.sdiv_ov(b, overflowOrDiv0);
654 return overflowOrDiv0 ?
Attribute() : result;
681 APInt one(a.getBitWidth(), 1,
true);
682 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683 return val.sadd_ov(one, overflow);
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
695 bool overflowOrDiv0 =
false;
696 auto result = constFoldBinaryOp<IntegerAttr>(
697 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
698 if (overflowOrDiv0 || !b) {
699 overflowOrDiv0 = true;
702 APInt quotient = a.udiv(b);
705 APInt one(a.getBitWidth(), 1,
true);
706 return quotient.uadd_ov(one, overflowOrDiv0);
709 return overflowOrDiv0 ?
Attribute() : result;
720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
728 bool overflowOrDiv0 =
false;
729 auto result = constFoldBinaryOp<IntegerAttr>(
730 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
731 if (overflowOrDiv0 || !b) {
732 overflowOrDiv0 = true;
738 unsigned bits = a.getBitWidth();
740 bool aGtZero = a.sgt(zero);
741 bool bGtZero = b.sgt(zero);
742 if (aGtZero && bGtZero) {
749 bool overflowNegA =
false;
750 bool overflowNegB =
false;
751 bool overflowDiv =
false;
752 bool overflowNegRes =
false;
753 if (!aGtZero && !bGtZero) {
755 APInt posA = zero.ssub_ov(a, overflowNegA);
756 APInt posB = zero.ssub_ov(b, overflowNegB);
758 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
761 if (!aGtZero && bGtZero) {
763 APInt posA = zero.ssub_ov(a, overflowNegA);
764 APInt div = posA.sdiv_ov(b, overflowDiv);
765 APInt res = zero.ssub_ov(div, overflowNegRes);
766 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
770 APInt posB = zero.ssub_ov(b, overflowNegB);
771 APInt div = a.sdiv_ov(posB, overflowDiv);
772 APInt res = zero.ssub_ov(div, overflowNegRes);
774 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
778 return overflowOrDiv0 ?
Attribute() : result;
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
795 bool overflowOrDiv =
false;
796 auto result = constFoldBinaryOp<IntegerAttr>(
797 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
799 overflowOrDiv = true;
802 return a.sfloordiv_ov(b, overflowOrDiv);
805 return overflowOrDiv ?
Attribute() : result;
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
819 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820 [&](APInt a,
const APInt &b) {
821 if (div0 || b.isZero()) {
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
842 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
843 [&](APInt a,
const APInt &b) {
844 if (div0 || b.isZero()) {
860 for (
bool reversePrev : {
false,
true}) {
861 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862 .getDefiningOp<arith::AndIOp>();
866 Value other = (reversePrev ? op.getLhs() : op.getRhs());
867 if (other != prev.getLhs() && other != prev.getRhs())
870 return prev.getResult();
882 intValue.isAllOnes())
887 intValue.isAllOnes())
892 intValue.isAllOnes())
899 return constFoldBinaryOp<IntegerAttr>(
900 adaptor.getOperands(),
901 [](APInt a,
const APInt &b) { return std::move(a) & b; });
914 if (rhsVal.isAllOnes())
915 return adaptor.getRhs();
922 intValue.isAllOnes())
923 return getRhs().getDefiningOp<XOrIOp>().getRhs();
927 intValue.isAllOnes())
928 return getLhs().getDefiningOp<XOrIOp>().getRhs();
930 return constFoldBinaryOp<IntegerAttr>(
931 adaptor.getOperands(),
932 [](APInt a,
const APInt &b) { return std::move(a) | b; });
944 if (getLhs() == getRhs())
948 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949 if (prev.getRhs() == getRhs())
950 return prev.getLhs();
951 if (prev.getLhs() == getRhs())
952 return prev.getRhs();
956 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957 if (prev.getRhs() == getLhs())
958 return prev.getLhs();
959 if (prev.getLhs() == getLhs())
960 return prev.getRhs();
963 return constFoldBinaryOp<IntegerAttr>(
964 adaptor.getOperands(),
965 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
970 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
979 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980 return op.getOperand();
981 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982 [](
const APFloat &a) { return -a; });
994 return constFoldBinaryOp<FloatAttr>(
995 adaptor.getOperands(),
996 [](
const APFloat &a,
const APFloat &b) { return a + b; });
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1008 return constFoldBinaryOp<FloatAttr>(
1009 adaptor.getOperands(),
1010 [](
const APFloat &a,
const APFloat &b) { return a - b; });
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1019 if (getLhs() == getRhs())
1026 return constFoldBinaryOp<FloatAttr>(
1027 adaptor.getOperands(),
1028 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1037 if (getLhs() == getRhs())
1044 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1053 if (getLhs() == getRhs())
1059 if (intValue.isMaxSignedValue())
1062 if (intValue.isMinSignedValue())
1066 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067 [](
const APInt &a,
const APInt &b) {
1068 return llvm::APIntOps::smax(a, b);
1078 if (getLhs() == getRhs())
1084 if (intValue.isMaxValue())
1087 if (intValue.isMinValue())
1091 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092 [](
const APInt &a,
const APInt &b) {
1093 return llvm::APIntOps::umax(a, b);
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1103 if (getLhs() == getRhs())
1110 return constFoldBinaryOp<FloatAttr>(
1111 adaptor.getOperands(),
1112 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1121 if (getLhs() == getRhs())
1128 return constFoldBinaryOp<FloatAttr>(
1129 adaptor.getOperands(),
1130 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1139 if (getLhs() == getRhs())
1145 if (intValue.isMinSignedValue())
1148 if (intValue.isMaxSignedValue())
1152 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153 [](
const APInt &a,
const APInt &b) {
1154 return llvm::APIntOps::smin(a, b);
1164 if (getLhs() == getRhs())
1170 if (intValue.isMinValue())
1173 if (intValue.isMaxValue())
1177 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178 [](
const APInt &a,
const APInt &b) {
1179 return llvm::APIntOps::umin(a, b);
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1192 return constFoldBinaryOp<FloatAttr>(
1193 adaptor.getOperands(),
1194 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1211 return constFoldBinaryOp<FloatAttr>(
1212 adaptor.getOperands(),
1213 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1227 [](
const APFloat &a,
const APFloat &b) {
1232 (void)result.mod(b);
1241 template <
typename... Types>
1247 template <
typename... ShapedTypes,
typename... ElementTypes>
1250 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1254 if (!llvm::isa<ElementTypes...>(underlyingType))
1257 return underlyingType;
1261 template <
typename... ElementTypes>
1268 template <
typename... ElementTypes>
1277 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279 if (!rankedTensorA || !rankedTensorB)
1281 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1285 if (inputs.size() != 1 || outputs.size() != 1)
1297 template <
typename ValType,
typename Op>
1302 if (llvm::cast<ValType>(srcType).getWidth() >=
1303 llvm::cast<ValType>(dstType).getWidth())
1305 << dstType <<
" must be wider than operand type " << srcType;
1311 template <
typename ValType,
typename Op>
1316 if (llvm::cast<ValType>(srcType).getWidth() <=
1317 llvm::cast<ValType>(dstType).getWidth())
1319 << dstType <<
" must be shorter than operand type " << srcType;
1325 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1330 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1331 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1332 if (!srcType || !dstType)
1335 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1336 srcType.getIntOrFloatBitWidth());
1342 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1343 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344 bool losesInfo =
false;
1345 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346 if (losesInfo || status != APFloat::opOK)
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1358 getInMutable().assign(lhs.getIn());
1363 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1364 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365 adaptor.getOperands(),
getType(),
1366 [bitWidth](
const APInt &a,
bool &castStatus) {
1367 return a.zext(bitWidth);
1372 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1376 return verifyExtOp<IntegerType>(*
this);
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1385 getInMutable().assign(lhs.getIn());
1390 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1391 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392 adaptor.getOperands(),
getType(),
1393 [bitWidth](
const APInt &a,
bool &castStatus) {
1394 return a.sext(bitWidth);
1399 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1404 patterns.add<ExtSIOfExtUI>(context);
1408 return verifyExtOp<IntegerType>(*
this);
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1419 if (truncFOp.getOperand().getType() ==
getType()) {
1420 arith::FastMathFlags truncFMF =
1421 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422 bool isTruncContract =
1424 arith::FastMathFlags extFMF =
1425 getFastmath().value_or(arith::FastMathFlags::none);
1426 bool isExtContract =
1428 if (isTruncContract && isExtContract) {
1429 return truncFOp.getOperand();
1435 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436 return constFoldCastOp<FloatAttr, FloatAttr>(
1437 adaptor.getOperands(),
getType(),
1438 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1440 if (failed(result)) {
1449 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1458 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1459 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1466 if (llvm::cast<IntegerType>(srcType).getWidth() >
1467 llvm::cast<IntegerType>(dstType).getWidth()) {
1474 if (srcType == dstType)
1479 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1480 setOperand(getOperand().getDefiningOp()->getOperand(0));
1485 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1486 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1487 adaptor.getOperands(),
getType(),
1488 [bitWidth](
const APInt &a,
bool &castStatus) {
1489 return a.trunc(bitWidth);
1494 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1499 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1500 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1505 return verifyTruncateOp<IntegerType>(*
this);
1514 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1516 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1517 Value src = extOp.getIn();
1519 auto intermediateType =
1522 if (llvm::APFloatBase::isRepresentableBy(
1523 srcType.getFloatSemantics(),
1524 intermediateType.getFloatSemantics())) {
1526 if (srcType.getWidth() > resElemType.getWidth()) {
1532 if (srcType == resElemType)
1537 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1538 return constFoldCastOp<FloatAttr, FloatAttr>(
1539 adaptor.getOperands(),
getType(),
1540 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1541 RoundingMode roundingMode =
1542 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1543 llvm::RoundingMode llvmRoundingMode =
1545 FailureOr<APFloat> result =
1547 if (failed(result)) {
1556 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1560 return verifyTruncateOp<FloatType>(*
this);
1569 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1578 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1585 template <
typename From,
typename To>
1590 auto srcType = getTypeIfLike<From>(inputs.front());
1591 auto dstType = getTypeIfLike<To>(outputs.back());
1593 return srcType && dstType;
1601 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1604 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1606 return constFoldCastOp<IntegerAttr, FloatAttr>(
1607 adaptor.getOperands(),
getType(),
1608 [&resEleType](
const APInt &a,
bool &castStatus) {
1609 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1610 APFloat apf(floatTy.getFloatSemantics(),
1612 apf.convertFromAPInt(a,
false,
1613 APFloat::rmNearestTiesToEven);
1623 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1626 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1628 return constFoldCastOp<IntegerAttr, FloatAttr>(
1629 adaptor.getOperands(),
getType(),
1630 [&resEleType](
const APInt &a,
bool &castStatus) {
1631 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1632 APFloat apf(floatTy.getFloatSemantics(),
1634 apf.convertFromAPInt(a,
true,
1635 APFloat::rmNearestTiesToEven);
1645 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1648 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1650 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1651 return constFoldCastOp<FloatAttr, IntegerAttr>(
1652 adaptor.getOperands(),
getType(),
1653 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1655 APSInt api(bitWidth,
true);
1656 castStatus = APFloat::opInvalidOp !=
1657 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1667 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1670 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1672 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1673 return constFoldCastOp<FloatAttr, IntegerAttr>(
1674 adaptor.getOperands(),
getType(),
1675 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1677 APSInt api(bitWidth,
false);
1678 castStatus = APFloat::opInvalidOp !=
1679 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1692 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1693 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1694 if (!srcType || !dstType)
1701 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1706 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1708 unsigned resultBitwidth = 64;
1710 resultBitwidth = intTy.getWidth();
1712 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1713 adaptor.getOperands(),
getType(),
1714 [resultBitwidth](
const APInt &a,
bool & ) {
1715 return a.sextOrTrunc(resultBitwidth);
1719 void arith::IndexCastOp::getCanonicalizationPatterns(
1721 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1728 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1733 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1735 unsigned resultBitwidth = 64;
1737 resultBitwidth = intTy.getWidth();
1739 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1740 adaptor.getOperands(),
getType(),
1741 [resultBitwidth](
const APInt &a,
bool & ) {
1742 return a.zextOrTrunc(resultBitwidth);
1746 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1748 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1759 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1760 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1761 if (!srcType || !dstType)
1767 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1769 auto operand = adaptor.getIn();
1774 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1775 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1777 if (llvm::isa<ShapedType>(resType))
1781 if (llvm::isa<ub::PoisonAttr>(operand))
1785 APInt bits = llvm::isa<FloatAttr>(operand)
1786 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1787 : llvm::cast<IntegerAttr>(operand).getValue();
1789 "trying to fold on broken IR: operands have incompatible types");
1791 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1793 APFloat(resFloatType.getFloatSemantics(), bits));
1799 patterns.add<BitcastOfBitcast>(context);
1809 const APInt &lhs,
const APInt &rhs) {
1810 switch (predicate) {
1811 case arith::CmpIPredicate::eq:
1813 case arith::CmpIPredicate::ne:
1815 case arith::CmpIPredicate::slt:
1816 return lhs.slt(rhs);
1817 case arith::CmpIPredicate::sle:
1818 return lhs.sle(rhs);
1819 case arith::CmpIPredicate::sgt:
1820 return lhs.sgt(rhs);
1821 case arith::CmpIPredicate::sge:
1822 return lhs.sge(rhs);
1823 case arith::CmpIPredicate::ult:
1824 return lhs.ult(rhs);
1825 case arith::CmpIPredicate::ule:
1826 return lhs.ule(rhs);
1827 case arith::CmpIPredicate::ugt:
1828 return lhs.ugt(rhs);
1829 case arith::CmpIPredicate::uge:
1830 return lhs.uge(rhs);
1832 llvm_unreachable(
"unknown cmpi predicate kind");
1837 switch (predicate) {
1838 case arith::CmpIPredicate::eq:
1839 case arith::CmpIPredicate::sle:
1840 case arith::CmpIPredicate::sge:
1841 case arith::CmpIPredicate::ule:
1842 case arith::CmpIPredicate::uge:
1844 case arith::CmpIPredicate::ne:
1845 case arith::CmpIPredicate::slt:
1846 case arith::CmpIPredicate::sgt:
1847 case arith::CmpIPredicate::ult:
1848 case arith::CmpIPredicate::ugt:
1851 llvm_unreachable(
"unknown cmpi predicate kind");
1855 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1856 return intType.getWidth();
1858 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1859 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1861 return std::nullopt;
1864 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1866 if (getLhs() == getRhs()) {
1872 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1874 std::optional<int64_t> integerWidth =
1876 if (integerWidth && integerWidth.value() == 1 &&
1877 getPredicate() == arith::CmpIPredicate::ne)
1878 return extOp.getOperand();
1880 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1882 std::optional<int64_t> integerWidth =
1884 if (integerWidth && integerWidth.value() == 1 &&
1885 getPredicate() == arith::CmpIPredicate::ne)
1886 return extOp.getOperand();
1891 getPredicate() == arith::CmpIPredicate::ne)
1898 getPredicate() == arith::CmpIPredicate::eq)
1903 if (adaptor.getLhs() && !adaptor.getRhs()) {
1905 using Pred = CmpIPredicate;
1906 const std::pair<Pred, Pred> invPreds[] = {
1907 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1908 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1909 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1910 {Pred::ne, Pred::ne},
1912 Pred origPred = getPredicate();
1913 for (
auto pred : invPreds) {
1914 if (origPred == pred.first) {
1915 setPredicate(pred.second);
1916 Value lhs = getLhs();
1917 Value rhs = getRhs();
1918 getLhsMutable().assign(rhs);
1919 getRhsMutable().assign(lhs);
1923 llvm_unreachable(
"unknown cmpi predicate kind");
1928 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1929 return constFoldBinaryOp<IntegerAttr>(
1931 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1942 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1952 const APFloat &lhs,
const APFloat &rhs) {
1953 auto cmpResult = lhs.compare(rhs);
1954 switch (predicate) {
1955 case arith::CmpFPredicate::AlwaysFalse:
1957 case arith::CmpFPredicate::OEQ:
1958 return cmpResult == APFloat::cmpEqual;
1959 case arith::CmpFPredicate::OGT:
1960 return cmpResult == APFloat::cmpGreaterThan;
1961 case arith::CmpFPredicate::OGE:
1962 return cmpResult == APFloat::cmpGreaterThan ||
1963 cmpResult == APFloat::cmpEqual;
1964 case arith::CmpFPredicate::OLT:
1965 return cmpResult == APFloat::cmpLessThan;
1966 case arith::CmpFPredicate::OLE:
1967 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1968 case arith::CmpFPredicate::ONE:
1969 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1970 case arith::CmpFPredicate::ORD:
1971 return cmpResult != APFloat::cmpUnordered;
1972 case arith::CmpFPredicate::UEQ:
1973 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1974 case arith::CmpFPredicate::UGT:
1975 return cmpResult == APFloat::cmpUnordered ||
1976 cmpResult == APFloat::cmpGreaterThan;
1977 case arith::CmpFPredicate::UGE:
1978 return cmpResult == APFloat::cmpUnordered ||
1979 cmpResult == APFloat::cmpGreaterThan ||
1980 cmpResult == APFloat::cmpEqual;
1981 case arith::CmpFPredicate::ULT:
1982 return cmpResult == APFloat::cmpUnordered ||
1983 cmpResult == APFloat::cmpLessThan;
1984 case arith::CmpFPredicate::ULE:
1985 return cmpResult == APFloat::cmpUnordered ||
1986 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1987 case arith::CmpFPredicate::UNE:
1988 return cmpResult != APFloat::cmpEqual;
1989 case arith::CmpFPredicate::UNO:
1990 return cmpResult == APFloat::cmpUnordered;
1991 case arith::CmpFPredicate::AlwaysTrue:
1994 llvm_unreachable(
"unknown cmpf predicate kind");
1997 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1998 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1999 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2002 if (lhs && lhs.getValue().isNaN())
2004 if (rhs && rhs.getValue().isNaN())
2020 using namespace arith;
2022 case CmpFPredicate::UEQ:
2023 case CmpFPredicate::OEQ:
2024 return CmpIPredicate::eq;
2025 case CmpFPredicate::UGT:
2026 case CmpFPredicate::OGT:
2027 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2028 case CmpFPredicate::UGE:
2029 case CmpFPredicate::OGE:
2030 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2031 case CmpFPredicate::ULT:
2032 case CmpFPredicate::OLT:
2033 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2034 case CmpFPredicate::ULE:
2035 case CmpFPredicate::OLE:
2036 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2037 case CmpFPredicate::UNE:
2038 case CmpFPredicate::ONE:
2039 return CmpIPredicate::ne;
2041 llvm_unreachable(
"Unexpected predicate!");
2051 const APFloat &rhs = flt.getValue();
2059 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2060 int mantissaWidth = floatTy.getFPMantissaWidth();
2061 if (mantissaWidth <= 0)
2067 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2069 intVal = si.getIn();
2070 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2072 intVal = ui.getIn();
2079 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2080 auto intWidth = intTy.getWidth();
2083 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2088 if ((
int)intWidth > mantissaWidth) {
2090 int exponent = ilogb(rhs);
2091 if (exponent == APFloat::IEK_Inf) {
2092 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2093 if (maxExponent < (
int)valueBits) {
2100 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2109 switch (op.getPredicate()) {
2110 case CmpFPredicate::ORD:
2115 case CmpFPredicate::UNO:
2128 APFloat signedMax(rhs.getSemantics());
2129 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2130 APFloat::rmNearestTiesToEven);
2131 if (signedMax < rhs) {
2132 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2133 pred == CmpIPredicate::sle)
2144 APFloat unsignedMax(rhs.getSemantics());
2145 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2146 APFloat::rmNearestTiesToEven);
2147 if (unsignedMax < rhs) {
2148 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2149 pred == CmpIPredicate::ule)
2161 APFloat signedMin(rhs.getSemantics());
2162 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2163 APFloat::rmNearestTiesToEven);
2164 if (signedMin > rhs) {
2165 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2166 pred == CmpIPredicate::sge)
2176 APFloat unsignedMin(rhs.getSemantics());
2177 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2178 APFloat::rmNearestTiesToEven);
2179 if (unsignedMin > rhs) {
2180 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2181 pred == CmpIPredicate::uge)
2196 APSInt rhsInt(intWidth, isUnsigned);
2197 if (APFloat::opInvalidOp ==
2198 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2204 if (!rhs.isZero()) {
2205 APFloat apf(floatTy.getFloatSemantics(),
2207 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2209 bool equal = apf == rhs;
2215 case CmpIPredicate::ne:
2219 case CmpIPredicate::eq:
2223 case CmpIPredicate::ule:
2226 if (rhs.isNegative()) {
2232 case CmpIPredicate::sle:
2235 if (rhs.isNegative())
2236 pred = CmpIPredicate::slt;
2238 case CmpIPredicate::ult:
2241 if (rhs.isNegative()) {
2246 pred = CmpIPredicate::ule;
2248 case CmpIPredicate::slt:
2251 if (!rhs.isNegative())
2252 pred = CmpIPredicate::sle;
2254 case CmpIPredicate::ugt:
2257 if (rhs.isNegative()) {
2263 case CmpIPredicate::sgt:
2266 if (rhs.isNegative())
2267 pred = CmpIPredicate::sge;
2269 case CmpIPredicate::uge:
2272 if (rhs.isNegative()) {
2277 pred = CmpIPredicate::ugt;
2279 case CmpIPredicate::sge:
2282 if (!rhs.isNegative())
2283 pred = CmpIPredicate::sgt;
2293 rewriter.
create<ConstantOp>(
2294 op.getLoc(), intVal.
getType(),
2316 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2332 rewriter.
create<arith::XOrIOp>(
2333 op.getLoc(), op.getCondition(),
2334 rewriter.
create<arith::ConstantIntOp>(
2335 op.getLoc(), 1, op.getCondition().getType())));
2345 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2349 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2350 Value trueVal = getTrueValue();
2351 Value falseVal = getFalseValue();
2352 if (trueVal == falseVal)
2355 Value condition = getCondition();
2366 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2369 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2373 if (
getType().isSignlessInteger(1) &&
2378 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2379 auto pred = cmp.getPredicate();
2380 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2381 auto cmpLhs = cmp.getLhs();
2382 auto cmpRhs = cmp.getRhs();
2390 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2391 (cmpRhs == trueVal && cmpLhs == falseVal))
2392 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2399 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2401 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2403 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2405 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2406 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2408 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2410 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2413 for (
auto [condVal, lhsVal, rhsVal] :
2414 llvm::zip_equal(condVals, lhsVals, rhsVals))
2415 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2426 Type conditionType, resultType;
2435 conditionType = resultType;
2444 {conditionType, resultType, resultType},
2449 p <<
" " << getOperands();
2452 if (ShapedType condType =
2453 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2454 p << condType <<
", ";
2459 Type conditionType = getCondition().getType();
2466 if (!llvm::isa<TensorType, VectorType>(resultType))
2467 return emitOpError() <<
"expected condition to be a signless i1, but got "
2470 if (conditionType != shapedConditionType) {
2471 return emitOpError() <<
"expected condition type to have the same shape "
2472 "as the result type, expected "
2473 << shapedConditionType <<
", but got "
2482 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2487 bool bounded =
false;
2488 auto result = constFoldBinaryOp<IntegerAttr>(
2489 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2490 bounded = b.ult(b.getBitWidth());
2500 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2505 bool bounded =
false;
2506 auto result = constFoldBinaryOp<IntegerAttr>(
2507 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2508 bounded = b.ult(b.getBitWidth());
2518 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2523 bool bounded =
false;
2524 auto result = constFoldBinaryOp<IntegerAttr>(
2525 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2526 bounded = b.ult(b.getBitWidth());
2539 bool useOnlyFiniteValue) {
2541 case AtomicRMWKind::maximumf: {
2542 const llvm::fltSemantics &semantic =
2543 llvm::cast<FloatType>(resultType).getFloatSemantics();
2544 APFloat identity = useOnlyFiniteValue
2545 ? APFloat::getLargest(semantic,
true)
2546 : APFloat::getInf(semantic,
true);
2549 case AtomicRMWKind::maxnumf: {
2550 const llvm::fltSemantics &semantic =
2551 llvm::cast<FloatType>(resultType).getFloatSemantics();
2552 APFloat identity = APFloat::getNaN(semantic,
true);
2555 case AtomicRMWKind::addf:
2556 case AtomicRMWKind::addi:
2557 case AtomicRMWKind::maxu:
2558 case AtomicRMWKind::ori:
2560 case AtomicRMWKind::andi:
2563 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2564 case AtomicRMWKind::maxs:
2566 resultType, APInt::getSignedMinValue(
2567 llvm::cast<IntegerType>(resultType).getWidth()));
2568 case AtomicRMWKind::minimumf: {
2569 const llvm::fltSemantics &semantic =
2570 llvm::cast<FloatType>(resultType).getFloatSemantics();
2571 APFloat identity = useOnlyFiniteValue
2572 ? APFloat::getLargest(semantic,
false)
2573 : APFloat::getInf(semantic,
false);
2577 case AtomicRMWKind::minnumf: {
2578 const llvm::fltSemantics &semantic =
2579 llvm::cast<FloatType>(resultType).getFloatSemantics();
2580 APFloat identity = APFloat::getNaN(semantic,
false);
2583 case AtomicRMWKind::mins:
2585 resultType, APInt::getSignedMaxValue(
2586 llvm::cast<IntegerType>(resultType).getWidth()));
2587 case AtomicRMWKind::minu:
2590 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2591 case AtomicRMWKind::muli:
2593 case AtomicRMWKind::mulf:
2605 std::optional<AtomicRMWKind> maybeKind =
2608 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2609 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2610 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2611 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2612 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2613 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2615 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2616 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2617 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2618 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2619 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2620 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2621 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2622 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2623 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2624 .Default([](
Operation *op) {
return std::nullopt; });
2626 return std::nullopt;
2629 bool useOnlyFiniteValue =
false;
2630 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2631 if (fmfOpInterface) {
2632 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2633 useOnlyFiniteValue =
2634 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2642 useOnlyFiniteValue);
2648 bool useOnlyFiniteValue) {
2651 return builder.
create<arith::ConstantOp>(loc, attr);
2659 case AtomicRMWKind::addf:
2660 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2661 case AtomicRMWKind::addi:
2662 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2663 case AtomicRMWKind::mulf:
2664 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2665 case AtomicRMWKind::muli:
2666 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2667 case AtomicRMWKind::maximumf:
2668 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2669 case AtomicRMWKind::minimumf:
2670 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2671 case AtomicRMWKind::maxnumf:
2672 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2673 case AtomicRMWKind::minnumf:
2674 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2675 case AtomicRMWKind::maxs:
2676 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2677 case AtomicRMWKind::mins:
2678 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2679 case AtomicRMWKind::maxu:
2680 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2681 case AtomicRMWKind::minu:
2682 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2683 case AtomicRMWKind::ori:
2684 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2685 case AtomicRMWKind::andi:
2686 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2699 #define GET_OP_CLASSES
2700 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2706 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)