#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                    Attribute rhs,
                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
  APInt value = binFn(lhsVal, rhsVal);
  return IntegerAttr::get(res.getType(), value);
}
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {
  return IntegerOverflowFlagsAttr::get(val1.getContext(),
                                       val1.getValue() & val2.getValue());
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
/// Equivalent to
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static llvm::RoundingMode
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
  switch (roundingMode) {
  case RoundingMode::downward:
    return llvm::RoundingMode::TowardNegative;
  case RoundingMode::to_nearest_away:
    return llvm::RoundingMode::NearestTiesToAway;
  case RoundingMode::to_nearest_even:
    return llvm::RoundingMode::NearestTiesToEven;
  case RoundingMode::toward_zero:
    return llvm::RoundingMode::TowardZero;
  case RoundingMode::upward:
    return llvm::RoundingMode::TowardPositive;
  }
  llvm_unreachable("Unhandled rounding mode");
}

static Attribute getBoolAttribute(Type type, bool value) {
  auto boolAttr = BoolAttr::get(type.getContext(), value);
  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
  if (!shapedType)
    return boolAttr;
  return DenseElementsAttr::get(shapedType, boolAttr);
}

#include "ArithCanonicalization.inc"
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
    return shapedType.cloneWith(std::nullopt, i1Type);
  if (llvm::isa<UnrankedTensorType>(type))
    return UnrankedTensorType::get(i1Type);
  return i1Type;
}

void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
    auto intType = llvm::dyn_cast<IntegerType>(type);

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a name that includes the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getValue();
    if (intType)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    setNameFn(getResult(), "cst");
  }
}
LogicalResult arith::ConstantOp::verify() {
  auto type = getType();
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return emitOpError("integer return type must be signless");
  // The value must be an integer, float, or elements attribute.
  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
    return emitOpError(
        "value must be an integer, float, or elements attribute");
  }

  // Scalable vectors can only be initialized through splats.
  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not"
        " supported unless it's a vector splat");
  return success();
}

bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must match the provided type.
  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
  if (!typedAttr || typedAttr.getType() != type)
    return false;
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return false;
  // Integer, float, and elements attributes are buildable.
  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}

ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  if (isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
  return nullptr;
}

OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto type = builder.getIntegerType(width);
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

bool arith::ConstantIntOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isSignlessInteger();
  return false;
}

void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getFloatAttr(type, value));
}

bool arith::ConstantFloatOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return llvm::isa<FloatType>(constOp.getType());
  return false;
}

void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  arith::ConstantOp::build(builder, result, builder.getIndexType(),
                           builder.getIndexAttr(value));
}

bool arith::ConstantIndexOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isIndex();
  return false;
}
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
  // addi(subi(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubIOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // addi(b, subi(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubIOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}
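// Example: given %0 = arith.subi %a, %b, the fold above rewrites
//   %1 = arith.addi %0, %b   ==>   %1 is replaced by %a,
// and two constant operands fold directly, e.g. i32 constants 3 and 4 fold
// to the constant 7.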
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}

std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

/// Returns the overflow bit, assuming `sum` is the unsigned sum of `operand`
/// and another value.
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}

LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    Builder builder(getContext());
    auto falseValue = builder.getZeroAttr(overflowTy);

    results.push_back(getLhs());
    results.push_back(falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        ArrayRef({sumAttr, adaptor.getLhs()}), overflowTy,
        calculateUnsignedOverflow);
    if (!overflowAttr)
      return failure();

    results.push_back(sumAttr);
    results.push_back(overflowAttr);
    return success();
  }

  return failure();
}

void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x, x) -> 0
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    // We can't generate a constant for a dynamically shaped tensor.
    if (!shapedType || shapedType.hasStaticShape())
      return Builder(getContext()).getZeroAttr(getType());
  }

  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    // subi(addi(a, b), b) -> a
    if (getRhs() == add.getRhs())
      return add.getLhs();
    // subi(addi(a, b), a) -> b
    if (getRhs() == add.getLhs())
      return add.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}
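// Example: given %0 = arith.addi %a, %b, the folds above rewrite both
//   arith.subi %0, %b  ==>  %a      and      arith.subi %0, %a  ==>  %b,
// while arith.subi %x, %x folds to a zero constant of the result type.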
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}

OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!isa<IndexType>(getType()))
    return;

  // Match vector.vscale by name to avoid depending on the vector dialect.
  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";
  };

  IntegerAttr baseValue;
  auto isVscaleExpr = [&](Value a, Value b) {
    return matchPattern(a, m_Constant(&baseValue)) &&
           isVscale(b.getDefiningOp());
  };

  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
    return;

  // Name `base * vscale` results cN_vscale.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());
}

void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to compute the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
}

std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(adaptor.getRhs(), m_One())) {
    Builder builder(getContext());
    Attribute zero = builder.getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}
/// Fold (a * b) / b -> a
static Value foldDivMul(Value lhs, Value rhs,
                        arith::IntegerOverflowFlags ovfFlags) {
  auto mul = lhs.getDefiningOp<arith::MulIOp>();
  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
    return {};

  if (mul.getLhs() == rhs)
    return mul.getRhs();

  if (mul.getRhs() == rhs)
    return mul.getLhs();

  return {};
}

OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // divui(x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
    return val;

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || !b) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.udiv(b);
                                               });

  return div0 ? Attribute() : result;
}

OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // divsi(x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
    return val;

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        return a.sdiv_ov(b, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  // Returns (a - 1) / b + 1 for non-negative a and positive b.
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
}

OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // ceildivui(x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        APInt quotient = a.udiv(b);
        if (!a.urem(b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
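// Example: ceildivui folds constant operands as ceil(a / b), e.g. 7 ceildivui 2
// folds to 4, while 6 ceildivui 2 folds to 3 (the +1 is skipped when b divides
// a evenly). A zero divisor leaves the op unfolded.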
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi(x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive: return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // No folding happens if any intermediate operation overflows.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative: return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A negative, b positive: return -(-a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A positive, b negative: return -(a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
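// Example: 7 ceildivsi -2 folds to -3 (computed as -(7 / 2)), and -7 ceildivsi
// -2 folds to 4 (computed as ceil(7 / 2)); any intermediate overflow, such as
// negating INT_MIN, leaves the op unfolded.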
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // floordivsi(x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (b.isZero()) {
          overflowOrDiv = true;
          return a;
        }
        return a.sfloordiv_ov(b, overflowOrDiv);
      });

  return overflowOrDiv ? Attribute() : result;
}

OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.urem(b);
                                               });

  return div0 ? Attribute() : result;
}

OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.srem(b);
                                               });

  return div0 ? Attribute() : result;
}

/// Fold `and(a, and(a, b))` to `and(a, b)`.
static Value foldAndIofAndI(arith::AndIOp op) {
  for (bool reversePrev : {false, true}) {
    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
                    .getDefiningOp<arith::AndIOp>();
    if (!prev)
      continue;

    Value other = (reversePrev ? op.getLhs() : op.getRhs());
    if (other != prev.getLhs() && other != prev.getRhs())
      continue;

    return prev.getResult();
  }
  return {};
}
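// Example: %0 = arith.andi %a, %b; %1 = arith.andi %a, %0  ==>  %1 is replaced
// by %0, regardless of which side of either andi carries %a.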
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(*this))
    return result;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}

OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, allOnes)) -> allOnes
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, allOnes), x) -> allOnes
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}

OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());
  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}
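// Example: %0 = arith.xori %x, %a; %1 = arith.xori %0, %a  ==>  %1 is replaced
// by %x, and arith.xori %x, %x folds to a zero constant.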
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}

OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  /// negf(negf(x)) -> x
  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
    return op.getOperand();
  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                     [](const APFloat &a) { return -a; });
}

OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}

OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}

OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}

OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maxnumf(x, NaN) -> x
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
}

OpFoldResult arith::MaxSIOp::fold(FoldAdaptor adaptor) {
  // maxsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxsi(x, MAX_INT) -> MAX_INT
    if (intValue.isMaxSignedValue())
      return getRhs();
    // maxsi(x, MIN_INT) -> x
    if (intValue.isMinSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smax(a, b);
                                        });
}
OpFoldResult arith::MaxUIOp::fold(FoldAdaptor adaptor) {
  // maxui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxui(x, UINT_MAX) -> UINT_MAX
    if (intValue.isMaxValue())
      return getRhs();
    // maxui(x, 0) -> x
    if (intValue.isMinValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umax(a, b);
                                        });
}

OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  // minimumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minimumf(x, +inf) -> x
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}

OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  // minnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minnumf(x, NaN) -> x
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
}

OpFoldResult arith::MinSIOp::fold(FoldAdaptor adaptor) {
  // minsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minsi(x, MIN_INT) -> MIN_INT
    if (intValue.isMinSignedValue())
      return getRhs();
    // minsi(x, MAX_INT) -> x
    if (intValue.isMaxSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smin(a, b);
                                        });
}

OpFoldResult arith::MinUIOp::fold(FoldAdaptor adaptor) {
  // minui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minui(x, 0) -> 0
    if (intValue.isMinValue())
      return getRhs();
    // minui(x, UINT_MAX) -> x
    if (intValue.isMaxValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umin(a, b);
                                        });
}

OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}

OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}

OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // APFloat::mod() keeps the sign of the
                                        // lhs operand, matching remf semantics.
                                        (void)result.mod(b);
                                        return result;
                                      });
}
template <typename... Types>
using type_list = std::tuple<Types...> *;

/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}

/// Get allowed underlying types for vectors and tensors.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}

/// Get allowed underlying types for vectors, tensors, and memrefs.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}

/// Return false if both types are ranked tensors with mismatching encoding.
static bool hasSameEncoding(Type typeA, Type typeB) {
  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
  if (!rankedTensorA || !rankedTensorB)
    return true;

  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
}

static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return hasSameEncoding(inputs.front(), outputs.front());
}

template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}

/// Validate a cast that changes the width of a type.
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());
}
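// Example: checkWidthChangeCast<std::greater, IntegerType> accepts i8 -> i32
// (and vector<4xi8> -> vector<4xi32>) but rejects i32 -> i8, while
// checkWidthChangeCast<std::less, FloatType> accepts f64 -> f32 truncations.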
/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
  if (losesInfo || status != APFloat::opOK)
    return failure();

  return sourceValue;
}

OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
      });
}

bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
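// Example: %0 = arith.extui %x : i8 to i16; %1 = arith.extui %0 : i16 to i32
// is folded by rebinding %1's operand to %x (a single i8 -> i32 extension),
// and a constant such as i8 255 zero-extends to i32 255.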
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);
      });
}

bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}

LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}

OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}

bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingExtFOp::verify() {
  return verifyExtOp<FloatType>(*this);
}
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}

bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
      context);
}

LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
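// Example: %0 = arith.extsi %x : i32 to i64; %1 = arith.trunci %0 : i64 to i16
// folds to %1 = arith.trunci %x : i32 to i16; when the extension source type
// already equals the truncation result type, the trunci folds away entirely.
// Constants fold bitwise, e.g. trunci of i32 0x12345678 to i16 gives 0x5678.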
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
    auto intermediateType =
        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
    // Check whether the source type is representable in the intermediate type.
    if (llvm::APFloatBase::isRepresentableBy(
            srcType.getFloatSemantics(),
            intermediateType.getFloatSemantics())) {
      // truncf(extf(a)) -> truncf(a)
      if (srcType.getWidth() > resElemType.getWidth()) {
        setOperand(src);
        return getResult();
      }

      // truncf(extf(a)) -> a
      if (srcType == resElemType)
        return src;
    }
  }

  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
}

bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}

LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}

bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingTruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}

void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}

void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<From>(inputs.front());
  auto dstType = getTypeIfLike<To>(outputs.back());

  return srcType && dstType;
}

bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}

bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
         (srcType.isSignlessInteger() && dstType.isIndex());
}

bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.sextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}

bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.zextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}

bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}

OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types are unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast poison.
  if (llvm::isa<ub::PoisonAttr>(operand))
    return operand;

  /// Bitcast integer or float to integer or float.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}

void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
    return intType.getWidth();
  }
  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
  }
  return std::nullopt;
}
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  // cmpi(ne, extsi(%x : i1), 0) -> %x
  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
    std::optional<int64_t> integerWidth =
        getIntegerWidth(extOp.getOperand().getType());
    if (integerWidth && integerWidth.value() == 1 &&
        getPredicate() == arith::CmpIPredicate::ne)
      return extOp.getOperand();
  }
  // cmpi(ne, extui(%x : i1), 0) -> %x
  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
    std::optional<int64_t> integerWidth =
        getIntegerWidth(extOp.getOperand().getType());
    if (integerWidth && integerWidth.value() == 1 &&
        getPredicate() == arith::CmpIPredicate::ne)
      return extOp.getOperand();
  }

  // cmpi(ne, %x, false) -> %x for i1 values.
  if (getElementTypeOrSelf(getLhs().getType()).isSignlessInteger(1) &&
      matchPattern(adaptor.getRhs(), m_Zero()) &&
      getPredicate() == arith::CmpIPredicate::ne)
    return getLhs();

  // cmpi(eq, %x, true) -> %x for i1 values.
  if (getElementTypeOrSelf(getLhs().getType()).isSignlessInteger(1) &&
      matchPattern(adaptor.getRhs(), m_One()) &&
      getPredicate() == arith::CmpIPredicate::eq)
    return getLhs();

  // Move constant to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // Since constants are moved to the right, if lhs is constant then rhs is
  // guaranteed to be a constant as well.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1,
                       static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
        });
  }

  return {};
}
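// Example: arith.cmpi slt, %c4, %x is canonicalized by the fold above to
// arith.cmpi sgt, %x, %c4 (constant moved to the right, predicate mirrored),
// and arith.cmpi eq, %x, %x folds to a true constant.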
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  // If one operand is NaN, making them both NaN does not change the result.
  if (lhs && lhs.getValue().isNaN())
    rhs = lhs;
  if (rhs && rhs.getValue().isNaN())
    lhs = rhs;

  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}

static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                               bool isUnsigned) {
  using namespace arith;
  switch (pred) {
  case CmpFPredicate::UEQ:
  case CmpFPredicate::OEQ:
    return CmpIPredicate::eq;
  case CmpFPredicate::UGT:
  case CmpFPredicate::OGT:
    return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
  case CmpFPredicate::UGE:
  case CmpFPredicate::OGE:
    return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
  case CmpFPredicate::ULT:
  case CmpFPredicate::OLT:
    return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
  case CmpFPredicate::ULE:
  case CmpFPredicate::OLE:
    return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
  case CmpFPredicate::UNE:
  case CmpFPredicate::ONE:
    return CmpIPredicate::ne;
  default:
    llvm_unreachable("Unexpected predicate!");
  }
}
namespace {
/// Rewrites cmpf(sitofp(x), constant) and cmpf(uitofp(x), constant) into an
/// equivalent integer comparison when the rewrite is exact.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Get the width of the mantissa; don't rewrite conversions that might lose
    // information from the integer.
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // The integer carries `valueBits` significant bits: all of them for an
    // unsigned value, all but the sign bit for a signed one.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // If the integer is wider than the mantissa, the conversion may lose
    // accuracy; check whether that can affect the comparison.
    if ((int)intWidth > mantissaWidth) {
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // The conversion could create infinity.
          return failure();
        }
      } else {
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // The conversion could affect the comparison.
          return failure();
        }
      }
    }

    // Convert to the equivalent cmpi predicate; the converted integer can
    // never be NaN, so ORD/UNO fold to constants.
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1, /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0, /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the constant is larger than the largest signed integer, fold.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the constant is larger than the largest unsigned integer, fold.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                     /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // If the constant is smaller than the smallest signed integer, fold.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the constant is smaller than the smallest unsigned integer, fold.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) {
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                     /*width=*/1);
        return success();
      }
    }

    // The constant fits in the integer range but may still be fractional.
    // Convert it to an integer and back; if the round trip is not exact,
    // adjust the predicate (and sometimes fold to a constant).
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // The destination type cannot represent the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4   --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4   --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4   --> int <= 4
          // (float)int <= -4.4  --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= -4.4  --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4   --> false
          // (float)int < 4.4    --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/0,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < 4.4    --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > -4.4   --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > -4.4   --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4  --> true
          // (float)int >= 4.4   --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/1,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= 4.4   --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        default:
          break;
        }
      }
    }

    // Lower this FP comparison into the equivalent integer comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};
} // namespace
namespace {
/// select %x, c1, c0 => extui %x (and the inverted variant via xor).
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32.
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();

    // select %x, c0, c1 => extui (xor %x, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          rewriter.create<arith::XOrIOp>(
              op.getLoc(), op.getCondition(),
              rewriter.create<arith::ConstantIntOp>(
                  op.getLoc(), 1, op.getCondition().getType())));
      return success();
    }

    return failure();
  }
};
} // namespace

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1
      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0
      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over a non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
          adaptor.getCondition())) {
    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getTrueValue())) {
      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
              adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
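// Example: arith.select %c, %x, %x folds to %x; with
// %c = arith.cmpi ne, %a, %b, the op arith.select %c, %a, %b folds to %a
// (and to %b for the eq predicate).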
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or
  // vector select.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType =
          llvm::dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the condition can be a mask of
  // the same shape.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.shl(b);
      });

  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.lshr(b);
      });

  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if shifting by at least the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.ashr(b);
      });

  return bounded ? result : Attribute();
}
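// Example: shli of i8 constants 3 and 2 folds to 12, and shrsi of i8 -16 by 2
// folds to -4; a shift amount of 8 or more (>= the bit width) leaves the op
// unfolded because `bounded` stays false.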
/// Returns the identity value attribute associated with an AtomicRMWKind op.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {
  switch (kind) {
  case AtomicRMWKind::maximumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::maxnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::minnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
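// Example: getIdentityValueAttr(AtomicRMWKind::addf, f32, ...) returns 0.0,
// maximumf returns -inf (or the most negative finite value when
// useOnlyFiniteValue is set), and andi returns an all-ones integer.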
/// Return the identity numeric value associated with the given op.
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
  std::optional<AtomicRMWKind> maybeKind =
      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
          // Floating-point operations.
          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
          // Integer operations.
          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
          .Default([](Operation *op) { return std::nullopt; });
  if (!maybeKind)
    return std::nullopt;

  bool useOnlyFiniteValue = false;
  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
  if (fmfOpInterface) {
    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
    useOnlyFiniteValue =
        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
  }

  // The builder is only used as a helper to create attributes.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();

  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
                              useOnlyFiniteValue);
}
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc,
                                    bool useOnlyFiniteValue) {
  auto attr =
      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
  return builder.create<arith::ConstantOp>(loc, attr);
}

/// Returns the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {
  switch (op) {
  case AtomicRMWKind::addf:
    return builder.create<arith::AddFOp>(loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return builder.create<arith::AddIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return builder.create<arith::MulFOp>(loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maximumf:
    return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minimumf:
    return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxnumf:
    return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minnumf:
    return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return builder.create<arith::AndIOp>(loc, lhs, rhs);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
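// Editor sketch (not from the original file): a reduction lowering can combine
// the two helpers above, seeding the accumulator with the identity value and
// folding elements in with getReductionOp. `f32Ty`, `elements`, `b`, and `loc`
// are placeholder names:
//
//   Value acc = arith::getIdentityValue(AtomicRMWKind::addf, f32Ty, b, loc);
//   for (Value v : elements)
//     acc = arith::getReductionOp(AtomicRMWKind::addf, b, loc, acc, v);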
#define GET_OP_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"