26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
45 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
68 static IntegerOverflowFlagsAttr
70 IntegerOverflowFlagsAttr val2) {
72 val1.getValue() & val2.getValue());
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
99 llvm_unreachable(
"unknown cmpi predicate kind");
108 static llvm::RoundingMode
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode::upward:
120 return llvm::RoundingMode::TowardPositive;
122 llvm_unreachable(
"Unhandled rounding mode");
152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
163 #include "ArithCanonicalization.inc"
173 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa<UnrankedTensorType>(type))
184 void arith::ConstantOp::getAsmResultNames(
187 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188 auto intType = llvm::dyn_cast<IntegerType>(type);
191 if (intType && intType.getWidth() == 1)
192 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
196 llvm::raw_svector_ostream specialName(specialNameBuffer);
197 specialName <<
'c' << intCst.getValue();
199 specialName <<
'_' << type;
200 setNameFn(getResult(), specialName.str());
202 setNameFn(getResult(),
"cst");
211 if (llvm::isa<IntegerType>(type) &&
212 !llvm::cast<IntegerType>(type).isSignless())
213 return emitOpError(
"integer return type must be signless");
215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
217 "value must be an integer, float, or elements attribute");
223 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
225 "intializing scalable vectors with elements attribute is not supported"
226 " unless it's a vector splat");
230 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
232 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233 if (!typedAttr || typedAttr.getType() != type)
236 if (llvm::isa<IntegerType>(type) &&
237 !llvm::cast<IntegerType>(type).isSignless())
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
245 if (isBuildableWith(value, type))
246 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
253 int64_t value,
unsigned width) {
255 arith::ConstantOp::build(builder, result, type,
260 int64_t value,
Type type) {
262 "ConstantIntOp can only have signless integer type values");
263 arith::ConstantOp::build(builder, result, type,
268 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269 return constOp.getType().isSignlessInteger();
274 const APFloat &value, FloatType type) {
275 arith::ConstantOp::build(builder, result, type,
280 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281 return llvm::isa<FloatType>(constOp.getType());
287 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
292 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293 return constOp.getType().isIndex();
307 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
308 if (getRhs() == sub.getRhs())
312 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
313 if (getLhs() == sub.getRhs())
316 return constFoldBinaryOp<IntegerAttr>(
317 adaptor.getOperands(),
318 [](APInt a,
const APInt &b) { return std::move(a) + b; });
323 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
334 return llvm::to_vector<4>(vt.getShape());
341 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
347 Type overflowTy = getOverflow().getType();
353 results.push_back(getLhs());
354 results.push_back(falseValue);
362 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
363 adaptor.getOperands(),
364 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
365 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
366 ArrayRef({sumAttr, adaptor.getLhs()}),
372 results.push_back(sumAttr);
373 results.push_back(overflowAttr);
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
382 patterns.add<AddUIExtendedToAddI>(context);
391 if (getOperand(0) == getOperand(1)) {
392 auto shapedType = dyn_cast<ShapedType>(
getType());
394 if (!shapedType || shapedType.hasStaticShape())
401 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
403 if (getRhs() == add.getRhs())
406 if (getRhs() == add.getLhs())
410 return constFoldBinaryOp<IntegerAttr>(
411 adaptor.getOperands(),
412 [](APInt a,
const APInt &b) { return std::move(a) - b; });
417 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
436 return constFoldBinaryOp<IntegerAttr>(
437 adaptor.getOperands(),
438 [](
const APInt &a,
const APInt &b) { return a * b; });
441 void arith::MulIOp::getAsmResultNames(
443 if (!isa<IndexType>(
getType()))
449 return op && op->getName().getStringRef() ==
"vector.vscale";
452 IntegerAttr baseValue;
455 isVscale(b.getDefiningOp());
458 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
463 llvm::raw_svector_ostream specialName(specialNameBuffer);
464 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
465 setNameFn(getResult(), specialName.str());
470 patterns.add<MulIMulIConstant>(context);
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
480 return llvm::to_vector<4>(vt.getShape());
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
490 results.push_back(zero);
491 results.push_back(zero);
496 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
497 adaptor.getOperands(),
498 [](
const APInt &a,
const APInt &b) { return a * b; })) {
500 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
501 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
502 return llvm::APIntOps::mulhs(a, b);
504 assert(highAttr &&
"Unexpected constant-folding failure");
506 results.push_back(lowAttr);
507 results.push_back(highAttr);
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
526 return llvm::to_vector<4>(vt.getShape());
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
536 results.push_back(zero);
537 results.push_back(zero);
545 results.push_back(getLhs());
546 results.push_back(zero);
551 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
552 adaptor.getOperands(),
553 [](
const APInt &a,
const APInt &b) { return a * b; })) {
555 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
556 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
557 return llvm::APIntOps::mulhu(a, b);
559 assert(highAttr &&
"Unexpected constant-folding failure");
561 results.push_back(lowAttr);
562 results.push_back(highAttr);
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571 patterns.add<MulUIExtendedToMulI>(context);
580 arith::IntegerOverflowFlags ovfFlags) {
582 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
585 if (mul.getLhs() == rhs)
588 if (mul.getRhs() == rhs)
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
600 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
605 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
606 [&](APInt a,
const APInt &b) {
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
640 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
644 bool overflowOrDiv0 =
false;
645 auto result = constFoldBinaryOp<IntegerAttr>(
646 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
647 if (overflowOrDiv0 || !b) {
648 overflowOrDiv0 = true;
651 return a.sdiv_ov(b, overflowOrDiv0);
654 return overflowOrDiv0 ?
Attribute() : result;
681 APInt one(a.getBitWidth(), 1,
true);
682 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683 return val.sadd_ov(one, overflow);
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
695 bool overflowOrDiv0 =
false;
696 auto result = constFoldBinaryOp<IntegerAttr>(
697 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
698 if (overflowOrDiv0 || !b) {
699 overflowOrDiv0 = true;
702 APInt quotient = a.udiv(b);
705 APInt one(a.getBitWidth(), 1,
true);
706 return quotient.uadd_ov(one, overflowOrDiv0);
709 return overflowOrDiv0 ?
Attribute() : result;
720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
728 bool overflowOrDiv0 =
false;
729 auto result = constFoldBinaryOp<IntegerAttr>(
730 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
731 if (overflowOrDiv0 || !b) {
732 overflowOrDiv0 = true;
738 unsigned bits = a.getBitWidth();
740 bool aGtZero = a.sgt(zero);
741 bool bGtZero = b.sgt(zero);
742 if (aGtZero && bGtZero) {
749 bool overflowNegA =
false;
750 bool overflowNegB =
false;
751 bool overflowDiv =
false;
752 bool overflowNegRes =
false;
753 if (!aGtZero && !bGtZero) {
755 APInt posA = zero.ssub_ov(a, overflowNegA);
756 APInt posB = zero.ssub_ov(b, overflowNegB);
758 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
761 if (!aGtZero && bGtZero) {
763 APInt posA = zero.ssub_ov(a, overflowNegA);
764 APInt div = posA.sdiv_ov(b, overflowDiv);
765 APInt res = zero.ssub_ov(div, overflowNegRes);
766 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
770 APInt posB = zero.ssub_ov(b, overflowNegB);
771 APInt div = a.sdiv_ov(posB, overflowDiv);
772 APInt res = zero.ssub_ov(div, overflowNegRes);
774 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
778 return overflowOrDiv0 ?
Attribute() : result;
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
795 bool overflowOrDiv =
false;
796 auto result = constFoldBinaryOp<IntegerAttr>(
797 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
799 overflowOrDiv = true;
802 return a.sfloordiv_ov(b, overflowOrDiv);
805 return overflowOrDiv ?
Attribute() : result;
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
819 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820 [&](APInt a,
const APInt &b) {
821 if (div0 || b.isZero()) {
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
842 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
843 [&](APInt a,
const APInt &b) {
844 if (div0 || b.isZero()) {
860 for (
bool reversePrev : {
false,
true}) {
861 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862 .getDefiningOp<arith::AndIOp>();
866 Value other = (reversePrev ? op.getLhs() : op.getRhs());
867 if (other != prev.getLhs() && other != prev.getRhs())
870 return prev.getResult();
882 intValue.isAllOnes())
887 intValue.isAllOnes())
892 intValue.isAllOnes())
899 return constFoldBinaryOp<IntegerAttr>(
900 adaptor.getOperands(),
901 [](APInt a,
const APInt &b) { return std::move(a) & b; });
914 if (rhsVal.isAllOnes())
915 return adaptor.getRhs();
922 intValue.isAllOnes())
923 return getRhs().getDefiningOp<XOrIOp>().getRhs();
927 intValue.isAllOnes())
928 return getLhs().getDefiningOp<XOrIOp>().getRhs();
930 return constFoldBinaryOp<IntegerAttr>(
931 adaptor.getOperands(),
932 [](APInt a,
const APInt &b) { return std::move(a) | b; });
944 if (getLhs() == getRhs())
948 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949 if (prev.getRhs() == getRhs())
950 return prev.getLhs();
951 if (prev.getLhs() == getRhs())
952 return prev.getRhs();
956 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957 if (prev.getRhs() == getLhs())
958 return prev.getLhs();
959 if (prev.getLhs() == getLhs())
960 return prev.getRhs();
963 return constFoldBinaryOp<IntegerAttr>(
964 adaptor.getOperands(),
965 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
970 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
979 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980 return op.getOperand();
981 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982 [](
const APFloat &a) { return -a; });
994 return constFoldBinaryOp<FloatAttr>(
995 adaptor.getOperands(),
996 [](
const APFloat &a,
const APFloat &b) { return a + b; });
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1008 return constFoldBinaryOp<FloatAttr>(
1009 adaptor.getOperands(),
1010 [](
const APFloat &a,
const APFloat &b) { return a - b; });
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1019 if (getLhs() == getRhs())
1026 return constFoldBinaryOp<FloatAttr>(
1027 adaptor.getOperands(),
1028 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1037 if (getLhs() == getRhs())
1044 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1053 if (getLhs() == getRhs())
1059 if (intValue.isMaxSignedValue())
1062 if (intValue.isMinSignedValue())
1066 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067 [](
const APInt &a,
const APInt &b) {
1068 return llvm::APIntOps::smax(a, b);
1078 if (getLhs() == getRhs())
1084 if (intValue.isMaxValue())
1087 if (intValue.isMinValue())
1091 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092 [](
const APInt &a,
const APInt &b) {
1093 return llvm::APIntOps::umax(a, b);
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1103 if (getLhs() == getRhs())
1110 return constFoldBinaryOp<FloatAttr>(
1111 adaptor.getOperands(),
1112 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1121 if (getLhs() == getRhs())
1128 return constFoldBinaryOp<FloatAttr>(
1129 adaptor.getOperands(),
1130 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1139 if (getLhs() == getRhs())
1145 if (intValue.isMinSignedValue())
1148 if (intValue.isMaxSignedValue())
1152 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153 [](
const APInt &a,
const APInt &b) {
1154 return llvm::APIntOps::smin(a, b);
1164 if (getLhs() == getRhs())
1170 if (intValue.isMinValue())
1173 if (intValue.isMaxValue())
1177 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178 [](
const APInt &a,
const APInt &b) {
1179 return llvm::APIntOps::umin(a, b);
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1192 return constFoldBinaryOp<FloatAttr>(
1193 adaptor.getOperands(),
1194 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1211 return constFoldBinaryOp<FloatAttr>(
1212 adaptor.getOperands(),
1213 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1227 [](
const APFloat &a,
const APFloat &b) {
1232 (void)result.mod(b);
1241 template <
typename... Types>
1247 template <
typename... ShapedTypes,
typename... ElementTypes>
1250 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1254 if (!llvm::isa<ElementTypes...>(underlyingType))
1257 return underlyingType;
1261 template <
typename... ElementTypes>
1268 template <
typename... ElementTypes>
1277 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279 if (!rankedTensorA || !rankedTensorB)
1281 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1285 if (inputs.size() != 1 || outputs.size() != 1)
1297 template <
typename ValType,
typename Op>
1302 if (llvm::cast<ValType>(srcType).getWidth() >=
1303 llvm::cast<ValType>(dstType).getWidth())
1305 << dstType <<
" must be wider than operand type " << srcType;
1311 template <
typename ValType,
typename Op>
1316 if (llvm::cast<ValType>(srcType).getWidth() <=
1317 llvm::cast<ValType>(dstType).getWidth())
1319 << dstType <<
" must be shorter than operand type " << srcType;
1325 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1330 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1331 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1332 if (!srcType || !dstType)
1335 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1336 srcType.getIntOrFloatBitWidth());
1342 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1343 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344 bool losesInfo =
false;
1345 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346 if (losesInfo || status != APFloat::opOK)
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1358 getInMutable().assign(lhs.getIn());
1363 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1364 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365 adaptor.getOperands(),
getType(),
1366 [bitWidth](
const APInt &a,
bool &castStatus) {
1367 return a.zext(bitWidth);
1372 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1376 return verifyExtOp<IntegerType>(*
this);
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1385 getInMutable().assign(lhs.getIn());
1390 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1391 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392 adaptor.getOperands(),
getType(),
1393 [bitWidth](
const APInt &a,
bool &castStatus) {
1394 return a.sext(bitWidth);
1399 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1404 patterns.add<ExtSIOfExtUI>(context);
1408 return verifyExtOp<IntegerType>(*
this);
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1419 if (truncFOp.getOperand().getType() ==
getType()) {
1420 arith::FastMathFlags truncFMF =
1421 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422 bool isTruncContract =
1424 arith::FastMathFlags extFMF =
1425 getFastmath().value_or(arith::FastMathFlags::none);
1426 bool isExtContract =
1428 if (isTruncContract && isExtContract) {
1429 return truncFOp.getOperand();
1435 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436 return constFoldCastOp<FloatAttr, FloatAttr>(
1437 adaptor.getOperands(),
getType(),
1438 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1440 if (failed(result)) {
1449 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1458 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1459 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1466 if (llvm::cast<IntegerType>(srcType).getWidth() >
1467 llvm::cast<IntegerType>(dstType).getWidth()) {
1474 if (srcType == dstType)
1479 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1480 setOperand(getOperand().getDefiningOp()->getOperand(0));
1485 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1486 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1487 adaptor.getOperands(),
getType(),
1488 [bitWidth](
const APInt &a,
bool &castStatus) {
1489 return a.trunc(bitWidth);
1494 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1499 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1500 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1505 return verifyTruncateOp<IntegerType>(*
this);
1514 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1516 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1517 Value src = extOp.getIn();
1519 auto intermediateType =
1522 if (llvm::APFloatBase::isRepresentableBy(
1523 srcType.getFloatSemantics(),
1524 intermediateType.getFloatSemantics())) {
1526 if (srcType.getWidth() > resElemType.getWidth()) {
1532 if (srcType == resElemType)
1537 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1538 return constFoldCastOp<FloatAttr, FloatAttr>(
1539 adaptor.getOperands(),
getType(),
1540 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1541 RoundingMode roundingMode =
1542 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1543 llvm::RoundingMode llvmRoundingMode =
1545 FailureOr<APFloat> result =
1547 if (failed(result)) {
1557 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1561 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1565 return verifyTruncateOp<FloatType>(*
this);
1574 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1583 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1590 template <
typename From,
typename To>
1595 auto srcType = getTypeIfLike<From>(inputs.front());
1596 auto dstType = getTypeIfLike<To>(outputs.back());
1598 return srcType && dstType;
1606 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1609 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1611 return constFoldCastOp<IntegerAttr, FloatAttr>(
1612 adaptor.getOperands(),
getType(),
1613 [&resEleType](
const APInt &a,
bool &castStatus) {
1614 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1615 APFloat apf(floatTy.getFloatSemantics(),
1617 apf.convertFromAPInt(a,
false,
1618 APFloat::rmNearestTiesToEven);
1628 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1631 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1633 return constFoldCastOp<IntegerAttr, FloatAttr>(
1634 adaptor.getOperands(),
getType(),
1635 [&resEleType](
const APInt &a,
bool &castStatus) {
1636 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1637 APFloat apf(floatTy.getFloatSemantics(),
1639 apf.convertFromAPInt(a,
true,
1640 APFloat::rmNearestTiesToEven);
1650 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1653 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1655 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1656 return constFoldCastOp<FloatAttr, IntegerAttr>(
1657 adaptor.getOperands(),
getType(),
1658 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1660 APSInt api(bitWidth,
true);
1661 castStatus = APFloat::opInvalidOp !=
1662 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1672 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1675 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1677 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1678 return constFoldCastOp<FloatAttr, IntegerAttr>(
1679 adaptor.getOperands(),
getType(),
1680 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1682 APSInt api(bitWidth,
false);
1683 castStatus = APFloat::opInvalidOp !=
1684 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1697 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1698 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1699 if (!srcType || !dstType)
1706 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1711 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1713 unsigned resultBitwidth = 64;
1715 resultBitwidth = intTy.getWidth();
1717 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1718 adaptor.getOperands(),
getType(),
1719 [resultBitwidth](
const APInt &a,
bool & ) {
1720 return a.sextOrTrunc(resultBitwidth);
1724 void arith::IndexCastOp::getCanonicalizationPatterns(
1726 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1733 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1738 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1740 unsigned resultBitwidth = 64;
1742 resultBitwidth = intTy.getWidth();
1744 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1745 adaptor.getOperands(),
getType(),
1746 [resultBitwidth](
const APInt &a,
bool & ) {
1747 return a.zextOrTrunc(resultBitwidth);
1751 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1753 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1764 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1765 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1766 if (!srcType || !dstType)
1772 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1774 auto operand = adaptor.getIn();
1779 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1780 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1782 if (llvm::isa<ShapedType>(resType))
1786 if (llvm::isa<ub::PoisonAttr>(operand))
1790 APInt bits = llvm::isa<FloatAttr>(operand)
1791 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1792 : llvm::cast<IntegerAttr>(operand).getValue();
1794 "trying to fold on broken IR: operands have incompatible types");
1796 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1798 APFloat(resFloatType.getFloatSemantics(), bits));
1804 patterns.add<BitcastOfBitcast>(context);
1814 const APInt &lhs,
const APInt &rhs) {
1815 switch (predicate) {
1816 case arith::CmpIPredicate::eq:
1818 case arith::CmpIPredicate::ne:
1820 case arith::CmpIPredicate::slt:
1821 return lhs.slt(rhs);
1822 case arith::CmpIPredicate::sle:
1823 return lhs.sle(rhs);
1824 case arith::CmpIPredicate::sgt:
1825 return lhs.sgt(rhs);
1826 case arith::CmpIPredicate::sge:
1827 return lhs.sge(rhs);
1828 case arith::CmpIPredicate::ult:
1829 return lhs.ult(rhs);
1830 case arith::CmpIPredicate::ule:
1831 return lhs.ule(rhs);
1832 case arith::CmpIPredicate::ugt:
1833 return lhs.ugt(rhs);
1834 case arith::CmpIPredicate::uge:
1835 return lhs.uge(rhs);
1837 llvm_unreachable(
"unknown cmpi predicate kind");
1842 switch (predicate) {
1843 case arith::CmpIPredicate::eq:
1844 case arith::CmpIPredicate::sle:
1845 case arith::CmpIPredicate::sge:
1846 case arith::CmpIPredicate::ule:
1847 case arith::CmpIPredicate::uge:
1849 case arith::CmpIPredicate::ne:
1850 case arith::CmpIPredicate::slt:
1851 case arith::CmpIPredicate::sgt:
1852 case arith::CmpIPredicate::ult:
1853 case arith::CmpIPredicate::ugt:
1856 llvm_unreachable(
"unknown cmpi predicate kind");
1860 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1861 return intType.getWidth();
1863 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1864 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1866 return std::nullopt;
1869 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1871 if (getLhs() == getRhs()) {
1877 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1879 std::optional<int64_t> integerWidth =
1881 if (integerWidth && integerWidth.value() == 1 &&
1882 getPredicate() == arith::CmpIPredicate::ne)
1883 return extOp.getOperand();
1885 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1887 std::optional<int64_t> integerWidth =
1889 if (integerWidth && integerWidth.value() == 1 &&
1890 getPredicate() == arith::CmpIPredicate::ne)
1891 return extOp.getOperand();
1896 getPredicate() == arith::CmpIPredicate::ne)
1903 getPredicate() == arith::CmpIPredicate::eq)
1908 if (adaptor.getLhs() && !adaptor.getRhs()) {
1910 using Pred = CmpIPredicate;
1911 const std::pair<Pred, Pred> invPreds[] = {
1912 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1913 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1914 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1915 {Pred::ne, Pred::ne},
1917 Pred origPred = getPredicate();
1918 for (
auto pred : invPreds) {
1919 if (origPred == pred.first) {
1920 setPredicate(pred.second);
1921 Value lhs = getLhs();
1922 Value rhs = getRhs();
1923 getLhsMutable().assign(rhs);
1924 getRhsMutable().assign(lhs);
1928 llvm_unreachable(
"unknown cmpi predicate kind");
1933 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1934 return constFoldBinaryOp<IntegerAttr>(
1936 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1947 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1957 const APFloat &lhs,
const APFloat &rhs) {
1958 auto cmpResult = lhs.compare(rhs);
1959 switch (predicate) {
1960 case arith::CmpFPredicate::AlwaysFalse:
1962 case arith::CmpFPredicate::OEQ:
1963 return cmpResult == APFloat::cmpEqual;
1964 case arith::CmpFPredicate::OGT:
1965 return cmpResult == APFloat::cmpGreaterThan;
1966 case arith::CmpFPredicate::OGE:
1967 return cmpResult == APFloat::cmpGreaterThan ||
1968 cmpResult == APFloat::cmpEqual;
1969 case arith::CmpFPredicate::OLT:
1970 return cmpResult == APFloat::cmpLessThan;
1971 case arith::CmpFPredicate::OLE:
1972 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1973 case arith::CmpFPredicate::ONE:
1974 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1975 case arith::CmpFPredicate::ORD:
1976 return cmpResult != APFloat::cmpUnordered;
1977 case arith::CmpFPredicate::UEQ:
1978 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1979 case arith::CmpFPredicate::UGT:
1980 return cmpResult == APFloat::cmpUnordered ||
1981 cmpResult == APFloat::cmpGreaterThan;
1982 case arith::CmpFPredicate::UGE:
1983 return cmpResult == APFloat::cmpUnordered ||
1984 cmpResult == APFloat::cmpGreaterThan ||
1985 cmpResult == APFloat::cmpEqual;
1986 case arith::CmpFPredicate::ULT:
1987 return cmpResult == APFloat::cmpUnordered ||
1988 cmpResult == APFloat::cmpLessThan;
1989 case arith::CmpFPredicate::ULE:
1990 return cmpResult == APFloat::cmpUnordered ||
1991 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1992 case arith::CmpFPredicate::UNE:
1993 return cmpResult != APFloat::cmpEqual;
1994 case arith::CmpFPredicate::UNO:
1995 return cmpResult == APFloat::cmpUnordered;
1996 case arith::CmpFPredicate::AlwaysTrue:
1999 llvm_unreachable(
"unknown cmpf predicate kind");
2002 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2003 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2004 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2007 if (lhs && lhs.getValue().isNaN())
2009 if (rhs && rhs.getValue().isNaN())
2025 using namespace arith;
2027 case CmpFPredicate::UEQ:
2028 case CmpFPredicate::OEQ:
2029 return CmpIPredicate::eq;
2030 case CmpFPredicate::UGT:
2031 case CmpFPredicate::OGT:
2032 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2033 case CmpFPredicate::UGE:
2034 case CmpFPredicate::OGE:
2035 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2036 case CmpFPredicate::ULT:
2037 case CmpFPredicate::OLT:
2038 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2039 case CmpFPredicate::ULE:
2040 case CmpFPredicate::OLE:
2041 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2042 case CmpFPredicate::UNE:
2043 case CmpFPredicate::ONE:
2044 return CmpIPredicate::ne;
2046 llvm_unreachable(
"Unexpected predicate!");
2056 const APFloat &rhs = flt.getValue();
2064 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2065 int mantissaWidth = floatTy.getFPMantissaWidth();
2066 if (mantissaWidth <= 0)
2072 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2074 intVal = si.getIn();
2075 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2077 intVal = ui.getIn();
2084 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2085 auto intWidth = intTy.getWidth();
2088 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2093 if ((
int)intWidth > mantissaWidth) {
2095 int exponent = ilogb(rhs);
2096 if (exponent == APFloat::IEK_Inf) {
2097 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2098 if (maxExponent < (
int)valueBits) {
2105 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2114 switch (op.getPredicate()) {
2115 case CmpFPredicate::ORD:
2120 case CmpFPredicate::UNO:
2133 APFloat signedMax(rhs.getSemantics());
2134 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2135 APFloat::rmNearestTiesToEven);
2136 if (signedMax < rhs) {
2137 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2138 pred == CmpIPredicate::sle)
2149 APFloat unsignedMax(rhs.getSemantics());
2150 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2151 APFloat::rmNearestTiesToEven);
2152 if (unsignedMax < rhs) {
2153 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2154 pred == CmpIPredicate::ule)
2166 APFloat signedMin(rhs.getSemantics());
2167 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2168 APFloat::rmNearestTiesToEven);
2169 if (signedMin > rhs) {
2170 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2171 pred == CmpIPredicate::sge)
2181 APFloat unsignedMin(rhs.getSemantics());
2182 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2183 APFloat::rmNearestTiesToEven);
2184 if (unsignedMin > rhs) {
2185 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2186 pred == CmpIPredicate::uge)
2201 APSInt rhsInt(intWidth, isUnsigned);
2202 if (APFloat::opInvalidOp ==
2203 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2209 if (!rhs.isZero()) {
2210 APFloat apf(floatTy.getFloatSemantics(),
2212 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2214 bool equal = apf == rhs;
2220 case CmpIPredicate::ne:
2224 case CmpIPredicate::eq:
2228 case CmpIPredicate::ule:
2231 if (rhs.isNegative()) {
2237 case CmpIPredicate::sle:
2240 if (rhs.isNegative())
2241 pred = CmpIPredicate::slt;
2243 case CmpIPredicate::ult:
2246 if (rhs.isNegative()) {
2251 pred = CmpIPredicate::ule;
2253 case CmpIPredicate::slt:
2256 if (!rhs.isNegative())
2257 pred = CmpIPredicate::sle;
2259 case CmpIPredicate::ugt:
2262 if (rhs.isNegative()) {
2268 case CmpIPredicate::sgt:
2271 if (rhs.isNegative())
2272 pred = CmpIPredicate::sge;
2274 case CmpIPredicate::uge:
2277 if (rhs.isNegative()) {
2282 pred = CmpIPredicate::ugt;
2284 case CmpIPredicate::sge:
2287 if (!rhs.isNegative())
2288 pred = CmpIPredicate::sgt;
2298 rewriter.
create<ConstantOp>(
2299 op.getLoc(), intVal.
getType(),
2321 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2337 rewriter.
create<arith::XOrIOp>(
2338 op.getLoc(), op.getCondition(),
2339 rewriter.
create<arith::ConstantIntOp>(
2340 op.getLoc(), 1, op.getCondition().getType())));
2350 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2354 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2355 Value trueVal = getTrueValue();
2356 Value falseVal = getFalseValue();
2357 if (trueVal == falseVal)
2360 Value condition = getCondition();
2371 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2374 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2378 if (
getType().isSignlessInteger(1) &&
2383 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2384 auto pred = cmp.getPredicate();
2385 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2386 auto cmpLhs = cmp.getLhs();
2387 auto cmpRhs = cmp.getRhs();
2395 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2396 (cmpRhs == trueVal && cmpLhs == falseVal))
2397 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2404 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2406 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2408 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2410 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2411 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2413 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2415 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2418 for (
auto [condVal, lhsVal, rhsVal] :
2419 llvm::zip_equal(condVals, lhsVals, rhsVals))
2420 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2431 Type conditionType, resultType;
2440 conditionType = resultType;
2449 {conditionType, resultType, resultType},
2454 p <<
" " << getOperands();
2457 if (ShapedType condType =
2458 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2459 p << condType <<
", ";
2464 Type conditionType = getCondition().getType();
2471 if (!llvm::isa<TensorType, VectorType>(resultType))
2472 return emitOpError() <<
"expected condition to be a signless i1, but got "
2475 if (conditionType != shapedConditionType) {
2476 return emitOpError() <<
"expected condition type to have the same shape "
2477 "as the result type, expected "
2478 << shapedConditionType <<
", but got "
2487 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2492 bool bounded =
false;
2493 auto result = constFoldBinaryOp<IntegerAttr>(
2494 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2495 bounded = b.ult(b.getBitWidth());
2505 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2510 bool bounded =
false;
2511 auto result = constFoldBinaryOp<IntegerAttr>(
2512 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2513 bounded = b.ult(b.getBitWidth());
2523 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2528 bool bounded =
false;
2529 auto result = constFoldBinaryOp<IntegerAttr>(
2530 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2531 bounded = b.ult(b.getBitWidth());
2544 bool useOnlyFiniteValue) {
2546 case AtomicRMWKind::maximumf: {
2547 const llvm::fltSemantics &semantic =
2548 llvm::cast<FloatType>(resultType).getFloatSemantics();
2549 APFloat identity = useOnlyFiniteValue
2550 ? APFloat::getLargest(semantic,
true)
2551 : APFloat::getInf(semantic,
true);
2554 case AtomicRMWKind::maxnumf: {
2555 const llvm::fltSemantics &semantic =
2556 llvm::cast<FloatType>(resultType).getFloatSemantics();
2557 APFloat identity = APFloat::getNaN(semantic,
true);
2560 case AtomicRMWKind::addf:
2561 case AtomicRMWKind::addi:
2562 case AtomicRMWKind::maxu:
2563 case AtomicRMWKind::ori:
2565 case AtomicRMWKind::andi:
2568 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2569 case AtomicRMWKind::maxs:
2571 resultType, APInt::getSignedMinValue(
2572 llvm::cast<IntegerType>(resultType).getWidth()));
2573 case AtomicRMWKind::minimumf: {
2574 const llvm::fltSemantics &semantic =
2575 llvm::cast<FloatType>(resultType).getFloatSemantics();
2576 APFloat identity = useOnlyFiniteValue
2577 ? APFloat::getLargest(semantic,
false)
2578 : APFloat::getInf(semantic,
false);
2582 case AtomicRMWKind::minnumf: {
2583 const llvm::fltSemantics &semantic =
2584 llvm::cast<FloatType>(resultType).getFloatSemantics();
2585 APFloat identity = APFloat::getNaN(semantic,
false);
2588 case AtomicRMWKind::mins:
2590 resultType, APInt::getSignedMaxValue(
2591 llvm::cast<IntegerType>(resultType).getWidth()));
2592 case AtomicRMWKind::minu:
2595 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2596 case AtomicRMWKind::muli:
2598 case AtomicRMWKind::mulf:
2610 std::optional<AtomicRMWKind> maybeKind =
2613 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2614 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2615 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2616 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2617 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2618 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2620 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2621 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2622 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2623 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2624 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2625 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2626 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2627 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2628 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2629 .Default([](
Operation *op) {
return std::nullopt; });
2631 return std::nullopt;
2634 bool useOnlyFiniteValue =
false;
2635 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2636 if (fmfOpInterface) {
2637 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2638 useOnlyFiniteValue =
2639 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2647 useOnlyFiniteValue);
2653 bool useOnlyFiniteValue) {
2656 return builder.
create<arith::ConstantOp>(loc, attr);
2664 case AtomicRMWKind::addf:
2665 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2666 case AtomicRMWKind::addi:
2667 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2668 case AtomicRMWKind::mulf:
2669 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2670 case AtomicRMWKind::muli:
2671 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2672 case AtomicRMWKind::maximumf:
2673 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2674 case AtomicRMWKind::minimumf:
2675 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2676 case AtomicRMWKind::maxnumf:
2677 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2678 case AtomicRMWKind::minnumf:
2679 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2680 case AtomicRMWKind::maxs:
2681 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2682 case AtomicRMWKind::mins:
2683 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2684 case AtomicRMWKind::maxu:
2685 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2686 case AtomicRMWKind::minu:
2687 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2688 case AtomicRMWKind::ori:
2689 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2690 case AtomicRMWKind::andi:
2691 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2704 #define GET_OP_CLASSES
2705 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2711 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)