26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
// NOTE(review): this chunk is a lossy extraction (embedded line numbers,
// wrapped lines, missing interior lines); comments below describe only what
// the visible fragments show. Recover the full source before editing logic.
//
// Fragment: tail of a constant-folding helper. Applies `binFn` to the APInt
// payloads of two operands that are assumed to be IntegerAttr (hard casts).
45 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
// Fragment: merges two overflow-flag attributes by bitwise-ANDing their
// underlying values, i.e. only flags present on both inputs survive.
68 static IntegerOverflowFlagsAttr
70 IntegerOverflowFlagsAttr val2) {
72 val1.getValue() & val2.getValue());
// Fragment: switch body mapping each arith::CmpIPredicate to its logical
// negation (eq<->ne, slt<->sge, sle<->sgt, and the unsigned counterparts).
// The enclosing function header is missing from this chunk — presumably an
// invert-predicate helper; TODO confirm against the full file.
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
// Reached only if the switch above did not return (all enumerators covered).
99 llvm_unreachable(
"unknown cmpi predicate kind");
// Fragment: translates the dialect's RoundingMode enum into the matching
// llvm::RoundingMode (downward->TowardNegative, upward->TowardPositive,
// toward_zero->TowardZero, and the two nearest-ties variants). The function
// name/parameter line is missing from this extraction.
108 static llvm::RoundingMode
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode::upward:
120 return llvm::RoundingMode::TowardPositive;
// All enumerators handled above; falling out of the switch is a bug.
122 llvm_unreachable(
"Unhandled rounding mode");
// Fragment: null-safe downcast of `type` to ShapedType (used by a helper
// whose surrounding lines are missing here).
152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
// Generated canonicalization patterns (TableGen-produced DRR rewrites).
163 #include "ArithCanonicalization.inc"
// Fragment: builds an i1-element type with the same shape as `type` — shaped
// types are cloned with an i1 element; unranked tensors get special handling
// (the corresponding return line is missing from this extraction).
173 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa<UnrankedTensorType>(type))
// Fragment of arith::ConstantOp::getAsmResultNames: picks a readable SSA
// name for the constant's result. i1 constants are named "true"/"false";
// other integer constants are named "c<value>" with "_<type>" appended
// (presumably only for non-index types — the guarding line is missing);
// anything else falls back to the generic name "cst".
184 void arith::ConstantOp::getAsmResultNames(
187 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188 auto intType = llvm::dyn_cast<IntegerType>(type);
191 if (intType && intType.getWidth() == 1)
192 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
// Build "c<value>[_<type>]" in a small stack buffer to avoid a heap string.
196 llvm::raw_svector_ostream specialName(specialNameBuffer);
197 specialName <<
'c' << intCst.getValue();
199 specialName <<
'_' << type;
200 setNameFn(getResult(), specialName.str());
202 setNameFn(getResult(),
"cst");
// Fragment of ConstantOp verification: rejects non-signless integer result
// types, requires the value attribute to be integer/float/elements, and
// forbids non-splat elements attributes for scalable vectors (a scalable
// vector's runtime length is unknown, so only splats are materializable).
// NOTE(review): "intializing" typo below is in a runtime diagnostic string —
// left untouched here since a doc pass must not alter emitted messages.
211 if (llvm::isa<IntegerType>(type) &&
212 !llvm::cast<IntegerType>(type).isSignless())
213 return emitOpError(
"integer return type must be signless");
215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
217 "value must be an integer, float, or elements attribute");
223 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
225 "intializing scalable vectors with elements attribute is not supported"
226 " unless it's a vector splat");
// isBuildableWith: a constant op can materialize `value` as `type` iff the
// attribute is a TypedAttr whose type matches exactly, the type is not a
// non-signless integer, and the attribute kind is integer/float/elements.
230 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
232 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233 if (!typedAttr || typedAttr.getType() != type)
236 if (llvm::isa<IntegerType>(type) &&
237 !llvm::cast<IntegerType>(type).isSignless())
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
// Fragment: dialect constant materialization — creates an arith.constant
// when isBuildableWith accepts the (value, type) pair.
245 if (isBuildableWith(value, type))
246 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
// Fragments: convenience builders and classof-style predicates for the typed
// constant wrappers. Each builder delegates to arith::ConstantOp::build with
// an appropriately-typed attribute; each predicate checks that a generic op
// is an arith.constant of the expected type class (signless int / float /
// index). Many connecting lines are missing from this extraction.
//
// Builder: signless integer constant from (value, width).
253 int64_t value,
unsigned width) {
255 arith::ConstantOp::build(builder, result, type,
// Builder: integer constant from an explicit type and int64_t value.
260 Type type, int64_t value) {
261 arith::ConstantOp::build(builder, result, type,
// Builder: integer constant from an explicit type and APInt value.
266 Type type,
const APInt &value) {
267 arith::ConstantOp::build(builder, result, type,
// Predicate: op is an arith.constant with a signless-integer result.
272 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
273 return constOp.getType().isSignlessInteger();
// Builder: float constant from a FloatType and APFloat value.
278 FloatType type,
const APFloat &value) {
279 arith::ConstantOp::build(builder, result, type,
// Predicate: op is an arith.constant with a float result.
284 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
285 return llvm::isa<FloatType>(constOp.getType());
// Builder: index-typed constant.
291 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
// Predicate: op is an arith.constant with an index result.
296 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
297 return constOp.getType().isIndex();
// Fragment: zero-constant materialization — asserts the type has a zero
// representation, then creates an arith.constant holding that zero attr.
305 "type doesn't have a zero representation");
307 assert(zeroAttr &&
"unsupported type for zero attribute");
308 return builder.
create<arith::ConstantOp>(loc, zeroAttr);
// Fragment of AddIOp::fold: addi(subi(a, b), b) and addi(b, subi(a, b))
// cancel to `a` (the returned value lines are missing here); otherwise the
// op constant-folds via APInt addition. `a` is taken by value so std::move
// lets the APInt sum reuse its storage.
321 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
322 if (getRhs() == sub.getRhs())
326 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
327 if (getLhs() == sub.getRhs())
330 return constFoldBinaryOp<IntegerAttr>(
331 adaptor.getOperands(),
332 [](APInt a,
const APInt &b) { return std::move(a) + b; });
// Canonicalization: registers the TableGen-generated addi rewrites.
337 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
338 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
// AddUIExtendedOp fragments: unroll shape, overflow computation, fold, and
// canonicalization registration.
//
// getShapeForUnroll: vector results unroll over the vector's shape.
345 std::optional<SmallVector<int64_t, 4>>
346 arith::AddUIExtendedOp::getShapeForUnroll() {
347 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
348 return llvm::to_vector<4>(vt.getShape());
// Fragment: unsigned-add overflow detection — the 1-bit overflow flag is
// all-ones iff the wrapped sum is less than an operand.
355 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
// fold: pushes (sum, overflow) into `results`. Visible cases: an identity
// fold that returns (lhs, false) — presumably for rhs == 0, TODO confirm —
// and full constant folding where the sum is folded first and the overflow
// bit derived by comparing the folded sum against the original lhs.
359 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
361 Type overflowTy = getOverflow().getType();
367 results.push_back(getLhs());
368 results.push_back(falseValue);
376 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
377 adaptor.getOperands(),
378 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
379 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
380 ArrayRef({sumAttr, adaptor.getLhs()}),
386 results.push_back(sumAttr);
387 results.push_back(overflowAttr);
// Canonicalization: rewrite to plain addi when the overflow bit is unused.
394 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
396 patterns.add<AddUIExtendedToAddI>(context);
// Fragment of SubIOp::fold: x - x folds to zero, but only for static shapes
// (a zero attribute cannot be built for a dynamically-shaped type);
// subi(addi(a, b), b) / subi(addi(a, b), a) cancel to the other addend;
// otherwise constant-folds via APInt subtraction.
405 if (getOperand(0) == getOperand(1)) {
406 auto shapedType = dyn_cast<ShapedType>(
getType())
408 if (!shapedType || shapedType.hasStaticShape())
415 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
417 if (getRhs() == add.getRhs())
420 if (getRhs() == add.getLhs())
424 return constFoldBinaryOp<IntegerAttr>(
425 adaptor.getOperands(),
426 [](APInt a,
const APInt &b) { return std::move(a) - b; });
// Canonicalization: registers the TableGen-generated subi rewrites.
431 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
432 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
433 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
// MulIOp fragments: constant fold, result naming, and canonicalization.
//
// fold: constant-folds via APInt multiplication.
450 return constFoldBinaryOp<IntegerAttr>(
451 adaptor.getOperands(),
452 [](
const APInt &a,
const APInt &b) { return a * b; });
// getAsmResultNames: only index-typed muls get special names. A product of
// a constant and vector.vscale (in either operand order) is named
// "c<base>_vscale" so vscale-scaled sizes read naturally in IR dumps.
455 void arith::MulIOp::getAsmResultNames(
457 if (!isa<IndexType>(
getType()))
463 return op && op->getName().getStringRef() ==
"vector.vscale";
466 IntegerAttr baseValue;
469 isVscale(b.getDefiningOp());
472 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
477 llvm::raw_svector_ostream specialName(specialNameBuffer);
478 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
479 setNameFn(getResult(), specialName.str());
// Canonicalization registration.
484 patterns.add<MulIMulIConstant>(context);
// MulSIExtendedOp fragments: unroll shape, fold, and canonicalization.
//
// getShapeForUnroll: vector results unroll over the vector's shape.
491 std::optional<SmallVector<int64_t, 4>>
492 arith::MulSIExtendedOp::getShapeForUnroll() {
493 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
494 return llvm::to_vector<4>(vt.getShape());
// fold: a zero case pushes (zero, zero) — presumably when an operand is 0,
// TODO confirm. With two constants, the low half is the plain APInt product
// and the high half is llvm::APIntOps::mulhs (signed high multiplication);
// the assert documents that if the low fold succeeded the high one must too.
499 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
504 results.push_back(zero);
505 results.push_back(zero);
510 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
511 adaptor.getOperands(),
512 [](
const APInt &a,
const APInt &b) { return a * b; })) {
514 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
515 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
516 return llvm::APIntOps::mulhs(a, b);
518 assert(highAttr &&
"Unexpected constant-folding failure");
520 results.push_back(lowAttr);
521 results.push_back(highAttr);
// Canonicalization registration.
528 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
530 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
// MulUIExtendedOp fragments: unsigned analogue of MulSIExtendedOp above.
//
// getShapeForUnroll: vector results unroll over the vector's shape.
537 std::optional<SmallVector<int64_t, 4>>
538 arith::MulUIExtendedOp::getShapeForUnroll() {
539 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
540 return llvm::to_vector<4>(vt.getShape());
// fold: one case pushes (zero, zero), another pushes (lhs, zero) — the
// guarding conditions are missing here; presumably rhs == 0 and rhs == 1
// respectively, TODO confirm. With two constants: low = APInt product,
// high = llvm::APIntOps::mulhu (unsigned high multiplication).
545 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
550 results.push_back(zero);
551 results.push_back(zero);
559 results.push_back(getLhs());
560 results.push_back(zero);
565 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
566 adaptor.getOperands(),
567 [](
const APInt &a,
const APInt &b) { return a * b; })) {
569 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
570 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
571 return llvm::APIntOps::mulhu(a, b);
573 assert(highAttr &&
"Unexpected constant-folding failure");
575 results.push_back(lowAttr);
576 results.push_back(highAttr);
// Canonicalization registration.
583 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
585 patterns.add<MulUIExtendedToMulI>(context);
// Fragment: foldDivMul-style helper — if the dividend is a mul carrying the
// required overflow flags and one mul operand equals the divisor, the
// division cancels to the other operand (return lines missing here).
594 arith::IntegerOverflowFlags ovfFlags) {
596 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
599 if (mul.getLhs() == rhs)
602 if (mul.getRhs() == rhs)
// DivUIOp::fold: uses the helper with the `nuw` flag (unsigned wrap must be
// impossible for mul/div cancellation to be sound), then constant-folds.
608 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
614 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
619 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
620 [&](APInt a,
const APInt &b) {
// DivSIOp::fold: mul/div cancellation requires the `nsw` flag (no signed
// wrap); constant folding tracks division-by-zero and signed overflow
// (INT_MIN / -1) via sdiv_ov, returning a null Attribute — i.e. "no fold" —
// when either occurred, since the op's result would be poison/UB.
648 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
654 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
658 bool overflowOrDiv0 =
false;
659 auto result = constFoldBinaryOp<IntegerAttr>(
660 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
661 if (overflowOrDiv0 || !b) {
662 overflowOrDiv0 = true;
665 return a.sdiv_ov(b, overflowOrDiv0);
668 return overflowOrDiv0 ?
Attribute() : result;
// Fragment: signed ceil-division helper for same-sign positive operands,
// computed as (a - 1) / b + 1 with overflow accumulated at each step.
695 APInt one(a.getBitWidth(), 1,
true);
696 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
697 return val.sadd_ov(one, overflow);
// CeilDivUIOp::fold: unsigned ceiling division. Division by zero aborts the
// fold; otherwise quotient = a / b, bumped by one when there is a remainder
// (the remainder check line is missing from this extraction). uadd_ov
// tracks overflow of the +1 bump; any overflow/div0 yields "no fold".
704 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
709 bool overflowOrDiv0 =
false;
710 auto result = constFoldBinaryOp<IntegerAttr>(
711 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
712 if (overflowOrDiv0 || !b) {
713 overflowOrDiv0 = true;
716 APInt quotient = a.udiv(b);
719 APInt one(a.getBitWidth(), 1,
true);
720 return quotient.uadd_ov(one, overflowOrDiv0);
723 return overflowOrDiv0 ?
Attribute() : result;
// CeilDivSIOp::fold: signed ceiling division by sign-case analysis.
// Every negation/division step uses the *_ov APInt variants so that signed
// overflow (e.g. negating INT_MIN) is detected and the fold abandoned.
734 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
742 bool overflowOrDiv0 =
false;
743 auto result = constFoldBinaryOp<IntegerAttr>(
744 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
745 if (overflowOrDiv0 || !b) {
746 overflowOrDiv0 = true;
752 unsigned bits = a.getBitWidth();
754 bool aGtZero = a.sgt(zero);
755 bool bGtZero = b.sgt(zero);
// Both positive: delegate to the positive ceil-div helper (line missing).
756 if (aGtZero && bGtZero) {
763 bool overflowNegA =
false;
764 bool overflowNegB =
false;
765 bool overflowDiv =
false;
766 bool overflowNegRes =
false;
// Both non-positive: ceil(a/b) == ceil(-a / -b) on the negated magnitudes.
767 if (!aGtZero && !bGtZero) {
769 APInt posA = zero.ssub_ov(a, overflowNegA);
770 APInt posB = zero.ssub_ov(b, overflowNegB);
772 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
// Mixed signs, a <= 0 < b: ceil(a/b) == -((-a) / b) (truncating division).
775 if (!aGtZero && bGtZero) {
777 APInt posA = zero.ssub_ov(a, overflowNegA);
778 APInt div = posA.sdiv_ov(b, overflowDiv);
779 APInt res = zero.ssub_ov(div, overflowNegRes);
780 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
// Mixed signs, b < 0 < a: ceil(a/b) == -(a / (-b)).
784 APInt posB = zero.ssub_ov(b, overflowNegB);
785 APInt div = a.sdiv_ov(posB, overflowDiv);
786 APInt res = zero.ssub_ov(div, overflowNegRes);
788 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
792 return overflowOrDiv0 ?
Attribute() : result;
// FloorDivSIOp::fold: signed floor division via APInt::sfloordiv_ov, which
// reports overflow/invalid cases through the flag; any such case (or the
// missing div-by-zero guard visible as the bare `overflowOrDiv = true`)
// turns the fold into a no-op by returning a null Attribute.
803 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
809 bool overflowOrDiv =
false;
810 auto result = constFoldBinaryOp<IntegerAttr>(
811 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
813 overflowOrDiv = true;
816 return a.sfloordiv_ov(b, overflowOrDiv);
819 return overflowOrDiv ?
Attribute() : result;
// RemUIOp::fold fragment: unsigned remainder; a zero divisor sets `div0`
// and aborts the fold (the remainder computation line is missing here).
826 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
833 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
834 [&](APInt a,
const APInt &b) {
835 if (div0 || b.isZero()) {
// RemSIOp::fold fragment: signed remainder with the same div-by-zero guard.
849 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
856 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
857 [&](APInt a,
const APInt &b) {
858 if (div0 || b.isZero()) {
// Fragment: idempotence of chained ands — for and(and(x, y), z) with z
// matching x or y (checking both operand orders via `reversePrev`), the
// outer and is redundant and folds to the inner result.
874 for (
bool reversePrev : {
false,
true}) {
875 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
876 .getDefiningOp<arith::AndIOp>();
880 Value other = (reversePrev ? op.getLhs() : op.getRhs());
881 if (other != prev.getLhs() && other != prev.getRhs())
884 return prev.getResult();
// AndIOp::fold fragments: several all-ones checks on a matched constant
// (the surrounding matchPattern lines are missing — presumably identity
// x & ~0 = x and annihilator x & ~x = 0 cases, TODO confirm), then plain
// APInt constant folding.
896 intValue.isAllOnes())
901 intValue.isAllOnes())
906 intValue.isAllOnes())
913 return constFoldBinaryOp<IntegerAttr>(
914 adaptor.getOperands(),
915 [](APInt a,
const APInt &b) { return std::move(a) & b; });
// OrIOp::fold fragments: x | ~0 folds to the all-ones rhs; or with a
// xor-of-all-ones operand (i.e. a bitwise NOT) simplifies to the other xor
// operand — checked for both lhs and rhs; otherwise APInt constant folding.
928 if (rhsVal.isAllOnes())
929 return adaptor.getRhs();
936 intValue.isAllOnes())
937 return getRhs().getDefiningOp<XOrIOp>().getRhs();
941 intValue.isAllOnes())
942 return getLhs().getDefiningOp<XOrIOp>().getRhs();
944 return constFoldBinaryOp<IntegerAttr>(
945 adaptor.getOperands(),
946 [](APInt a,
const APInt &b) { return std::move(a) | b; });
// XOrIOp::fold: x ^ x = 0 (return line missing); xor chains cancel — for
// xor(xor(p, q), r) with r matching p or q, fold to the other inner operand,
// checked with the chain on either side; otherwise APInt constant folding.
958 if (getLhs() == getRhs())
962 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
963 if (prev.getRhs() == getRhs())
964 return prev.getLhs();
965 if (prev.getLhs() == getRhs())
966 return prev.getRhs();
970 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
971 if (prev.getRhs() == getLhs())
972 return prev.getLhs();
973 if (prev.getLhs() == getLhs())
974 return prev.getRhs();
977 return constFoldBinaryOp<IntegerAttr>(
978 adaptor.getOperands(),
979 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
// Canonicalization registration for xori.
984 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
// Floating-point fold fragments.
//
// NegFOp::fold: double negation cancels (negf(negf(x)) -> x); otherwise
// constant-folds via APFloat unary minus.
993 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
994 return op.getOperand();
995 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
996 [](
const APFloat &a) { return -a; });
// AddFOp::fold: APFloat addition on constant operands.
1003 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1008 return constFoldBinaryOp<FloatAttr>(
1009 adaptor.getOperands(),
1010 [](
const APFloat &a,
const APFloat &b) { return a + b; });
// SubFOp::fold: APFloat subtraction on constant operands.
1017 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1022 return constFoldBinaryOp<FloatAttr>(
1023 adaptor.getOperands(),
1024 [](
const APFloat &a,
const APFloat &b) { return a - b; });
// MaximumFOp::fold: max(x, x) = x (return line missing); constants fold via
// llvm::maximum, which propagates NaN.
1031 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1033 if (getLhs() == getRhs())
1040 return constFoldBinaryOp<FloatAttr>(
1041 adaptor.getOperands(),
1042 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
// MaxNumFOp::fold: max(x, x) = x; constants fold via llvm::maxnum, which
// prefers the non-NaN operand.
1049 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1051 if (getLhs() == getRhs())
1058 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
// Integer min/max fold fragments. Each follows the same shape: x op x = x;
// comparison against the type's extremal constant folds to the appropriate
// operand (the returned-value lines are missing from this extraction); and
// two constants fold through the matching llvm::APIntOps helper.
//
// MaxSIOp: extremes are the signed max/min values, folded via smax.
1067 if (getLhs() == getRhs())
1073 if (intValue.isMaxSignedValue())
1076 if (intValue.isMinSignedValue())
1080 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1081 [](
const APInt &a,
const APInt &b) {
1082 return llvm::APIntOps::smax(a, b);
// MaxUIOp: extremes are the unsigned max/min values, folded via umax.
1092 if (getLhs() == getRhs())
1098 if (intValue.isMaxValue())
1101 if (intValue.isMinValue())
1105 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1106 [](
const APInt &a,
const APInt &b) {
1107 return llvm::APIntOps::umax(a, b);
// MinimumFOp: min(x, x) = x; constants fold via llvm::minimum (NaN-
// propagating).
1115 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1117 if (getLhs() == getRhs())
1124 return constFoldBinaryOp<FloatAttr>(
1125 adaptor.getOperands(),
1126 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
// MinNumFOp: min(x, x) = x; constants fold via llvm::minnum (prefers the
// non-NaN operand).
1133 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1135 if (getLhs() == getRhs())
1142 return constFoldBinaryOp<FloatAttr>(
1143 adaptor.getOperands(),
1144 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
// MinSIOp: signed extremes, folded via smin.
1153 if (getLhs() == getRhs())
1159 if (intValue.isMinSignedValue())
1162 if (intValue.isMaxSignedValue())
1166 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1167 [](
const APInt &a,
const APInt &b) {
1168 return llvm::APIntOps::smin(a, b);
// MinUIOp: unsigned extremes, folded via umin.
1178 if (getLhs() == getRhs())
1184 if (intValue.isMinValue())
1187 if (intValue.isMaxValue())
1191 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1192 [](
const APInt &a,
const APInt &b) {
1193 return llvm::APIntOps::umin(a, b);
// MulFOp::fold: APFloat multiplication on constant operands.
1201 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1206 return constFoldBinaryOp<FloatAttr>(
1207 adaptor.getOperands(),
1208 [](
const APFloat &a,
const APFloat &b) { return a * b; });
// DivFOp::fold: APFloat division on constant operands.
1220 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1225 return constFoldBinaryOp<FloatAttr>(
1226 adaptor.getOperands(),
1227 [](
const APFloat &a,
const APFloat &b) { return a / b; });
// RemFOp::fold fragment: floating remainder via APFloat::mod on a working
// copy (the copy-construction line is missing here); the mod status is
// intentionally discarded.
1239 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1240 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1241 [](
const APFloat &a,
const APFloat &b) {
1246 (void)result.mod(b);
// Cast-verification type helpers (fragments).
//
// Variadic template head — body missing from this extraction.
1255 template <
typename... Types>
// getUnderlyingType-style helper: for shaped types, requires the container
// to be one of ShapedTypes... and its element to be one of ElementTypes...,
// returning the element (or scalar) type on success.
1261 template <
typename... ShapedTypes,
typename... ElementTypes>
1264 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1268 if (!llvm::isa<ElementTypes...>(underlyingType))
1271 return underlyingType;
// Convenience wrappers (bodies missing): presumably getTypeIfLike /
// getTypeIfLikeOrMemRef specializations — TODO confirm in full source.
1275 template <
typename... ElementTypes>
1282 template <
typename... ElementTypes>
// Encoding check: two ranked tensors are compatible only if their encoding
// attributes match; non-ranked-tensor pairs take the missing earlier path.
1291 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1292 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1293 if (!rankedTensorA || !rankedTensorB)
1295 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
// Cast input/output arity check: exactly one input and one output.
1299 if (inputs.size() != 1 || outputs.size() != 1)
// Extension/truncation verification and width-comparison helpers.
//
// verifyExtOp-style template: an extension's result width must be strictly
// greater than the operand width; otherwise emit a diagnostic.
1311 template <
typename ValType,
typename Op>
1316 if (llvm::cast<ValType>(srcType).getWidth() >=
1317 llvm::cast<ValType>(dstType).getWidth())
1319 << dstType <<
" must be wider than operand type " << srcType;
// verifyTruncateOp-style template: a truncation's result width must be
// strictly smaller than the operand width.
1325 template <
typename ValType,
typename Op>
1330 if (llvm::cast<ValType>(srcType).getWidth() <=
1331 llvm::cast<ValType>(dstType).getWidth())
1333 << dstType <<
" must be shorter than operand type " << srcType;
// checkWidthChangeCast: cast compatibility parameterized on a comparator
// (std::greater for extensions, std::less for truncations) applied to the
// destination vs. source bit widths.
1339 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1344 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1345 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1346 if (!srcType || !dstType)
1349 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1350 srcType.getIntOrFloatBitWidth());
// convertFloatValue-style helper: converts an APFloat to the target
// semantics with the given rounding mode; any precision loss or non-opOK
// status makes the conversion fail (the failure return is missing here).
1356 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1357 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1358 bool losesInfo =
false;
1359 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1360 if (losesInfo || status != APFloat::opOK)
// ExtUIOp::fold: collapses chained zero-extensions by re-pointing the
// operand at the innermost source (in-place mutation of the operand),
// then constant-folds via APInt::zext to the result width.
1370 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1371 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1372 getInMutable().assign(lhs.getIn());
1377 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1378 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1379 adaptor.getOperands(),
getType(),
1380 [bitWidth](
const APInt &a,
bool &castStatus) {
1381 return a.zext(bitWidth);
// areCastCompatible: dst must be strictly wider (std::greater) int-like.
1386 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1390 return verifyExtOp<IntegerType>(*
this);
// ExtSIOp::fold: same chained-extension collapse, constant-folding via
// APInt::sext (sign extension) instead.
1397 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1398 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1399 getInMutable().assign(lhs.getIn());
1404 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1405 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1406 adaptor.getOperands(),
getType(),
1407 [bitWidth](
const APInt &a,
bool &castStatus) {
1408 return a.sext(bitWidth);
// areCastCompatible / canonicalization / verifier for extsi.
1413 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1418 patterns.add<ExtSIOfExtUI>(context);
1422 return verifyExtOp<IntegerType>(*
this);
// ExtFOp::fold: extf(truncf(x)) folds to x when x already has the result
// type AND both ops carry the `contract` fast-math flag (the round-trip
// changes the value in general, so it is only legal under contraction).
1431 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1432 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1433 if (truncFOp.getOperand().getType() ==
getType()) {
1434 arith::FastMathFlags truncFMF =
1435 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1436 bool isTruncContract =
1438 arith::FastMathFlags extFMF =
1439 getFastmath().value_or(arith::FastMathFlags::none);
1440 bool isExtContract =
1442 if (isTruncContract && isExtContract) {
1443 return truncFOp.getOperand();
// Constant folding: convert the APFloat into the wider result semantics;
// a failed conversion aborts the fold.
1449 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1450 return constFoldCastOp<FloatAttr, FloatAttr>(
1451 adaptor.getOperands(),
getType(),
1452 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1454 if (failed(result)) {
// areCastCompatible: dst must be a strictly wider float-like type.
1463 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
// ScalingExtFOp: same width check on the primary input; ext verifier.
1472 bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1474 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1478 return verifyExtOp<FloatType>(*
this);
// TruncIOp::fold: visible cases — trunc of an extension (extui/extsi) where
// the widths line up (the shown branch handles src wider than dst; the
// equal-width identity case is at 1501); chained truncations collapse by
// skipping to the inner operand; otherwise constant-fold via APInt::trunc.
1485 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1486 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1493 if (llvm::cast<IntegerType>(srcType).getWidth() >
1494 llvm::cast<IntegerType>(dstType).getWidth()) {
1501 if (srcType == dstType)
1506 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1507 setOperand(getOperand().getDefiningOp()->getOperand(0));
1512 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1513 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1514 adaptor.getOperands(),
getType(),
1515 [bitWidth](
const APInt &a,
bool &castStatus) {
1516 return a.trunc(bitWidth);
// areCastCompatible: dst must be strictly narrower (std::less) int-like.
1521 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
// Canonicalization registration and truncation verifier.
1527 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1532 return verifyTruncateOp<IntegerType>(*
this);
// TruncFOp::fold: for truncf(extf(x)), when the extension's source value is
// exactly representable in the intermediate type (APFloatBase::
// isRepresentableBy), the pair can fold — re-truncating from the original
// source when it is wider than the result, or returning it directly when
// types match. Otherwise constant-folds using the op's rounding-mode
// attribute (default to_nearest_even) translated to llvm::RoundingMode.
1541 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1543 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1544 Value src = extOp.getIn();
1546 auto intermediateType =
1549 if (llvm::APFloatBase::isRepresentableBy(
1550 srcType.getFloatSemantics(),
1551 intermediateType.getFloatSemantics())) {
1553 if (srcType.getWidth() > resElemType.getWidth()) {
1559 if (srcType == resElemType)
1564 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1565 return constFoldCastOp<FloatAttr, FloatAttr>(
1566 adaptor.getOperands(),
getType(),
1567 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1568 RoundingMode roundingMode =
1569 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1570 llvm::RoundingMode llvmRoundingMode =
1572 FailureOr<APFloat> result =
1574 if (failed(result)) {
// Canonicalization, cast-compat (dst strictly narrower float), verifier.
1584 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1588 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1592 return verifyTruncateOp<FloatType>(*
this);
// ScalingTruncFOp: width check on the primary input; truncation verifier.
1599 bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1601 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1605 return verifyTruncateOp<FloatType>(*
this);
// AndIOp / OrIOp canonicalization pattern registrations.
1614 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1623 patterns.add<OrOfExtUI, OrOfExtSI>(context);
// checkIntFloatCast: a From->To cast is compatible when the single input is
// From-like and the single output is To-like.
1630 template <
typename From,
typename To>
1635 auto srcType = getTypeIfLike<From>(inputs.front());
1636 auto dstType = getTypeIfLike<To>(outputs.back());
1638 return srcType && dstType;
// UIToFPOp: int -> float cast compatibility.
1646 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
// UIToFPOp::fold: builds an APFloat in the result semantics and converts
// the APInt as *unsigned* (isSigned=false) with round-to-nearest-even.
1649 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1651 return constFoldCastOp<IntegerAttr, FloatAttr>(
1652 adaptor.getOperands(),
getType(),
1653 [&resEleType](
const APInt &a,
bool &castStatus) {
1654 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1655 APFloat apf(floatTy.getFloatSemantics(),
1657 apf.convertFromAPInt(a,
false,
1658 APFloat::rmNearestTiesToEven);
// SIToFPOp: same, but the APInt is converted as *signed* (isSigned=true).
1668 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1671 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1673 return constFoldCastOp<IntegerAttr, FloatAttr>(
1674 adaptor.getOperands(),
getType(),
1675 [&resEleType](
const APInt &a,
bool &castStatus) {
1676 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1677 APFloat apf(floatTy.getFloatSemantics(),
1679 apf.convertFromAPInt(a,
true,
1680 APFloat::rmNearestTiesToEven);
// FPToUIOp: float -> int cast compatibility; fold converts toward zero into
// an unsigned APSInt, marking the cast failed on opInvalidOp (e.g. NaN or
// out-of-range input).
1690 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1693 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1695 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1696 return constFoldCastOp<FloatAttr, IntegerAttr>(
1697 adaptor.getOperands(),
getType(),
1698 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1700 APSInt api(bitWidth,
true);
1701 castStatus = APFloat::opInvalidOp !=
1702 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
// FPToSIOp: float -> int compatibility; fold converts toward zero into a
// *signed* APSInt (isUnsigned=false), failing the cast on opInvalidOp.
1712 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1715 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1717 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1718 return constFoldCastOp<FloatAttr, IntegerAttr>(
1719 adaptor.getOperands(),
getType(),
1720 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1722 APSInt api(bitWidth,
false);
1723 castStatus = APFloat::opInvalidOp !=
1724 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
// Index-cast compatibility helper: both sides must be int- or index-like
// (possibly wrapped in shaped/memref types).
1737 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1738 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1739 if (!srcType || !dstType)
// IndexCastOp::fold: an index result is assumed 64-bit unless the result is
// an explicit IntegerType; the value is sign-extended or truncated to that
// width (index_cast is the signed variant).
1746 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1751 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1753 unsigned resultBitwidth = 64;
1755 resultBitwidth = intTy.getWidth();
1757 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1758 adaptor.getOperands(),
getType(),
1759 [resultBitwidth](
const APInt &a,
bool & ) {
1760 return a.sextOrTrunc(resultBitwidth);
1764 void arith::IndexCastOp::getCanonicalizationPatterns(
1766 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
// IndexCastUIOp: identical structure but zero-extends (unsigned variant).
1773 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1778 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1780 unsigned resultBitwidth = 64;
1782 resultBitwidth = intTy.getWidth();
1784 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1785 adaptor.getOperands(),
getType(),
1786 [resultBitwidth](
const APInt &a,
bool & ) {
1787 return a.zextOrTrunc(resultBitwidth);
1791 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1793 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
// Bitcast compatibility: both sides int- or float-like (incl. memref forms).
1804 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1805 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1806 if (!srcType || !dstType)
// BitcastOp::fold: dense element attributes bitcast element-wise; other
// shaped results and poison take the (missing) early-exit paths. For
// scalars, the payload bits come from the float's bit pattern or the
// integer value directly, then are rewrapped as a float attr when the
// result is a FloatType (the integer-result path is missing here).
1812 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1814 auto operand = adaptor.getIn();
1819 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1820 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1822 if (llvm::isa<ShapedType>(resType))
1826 if (llvm::isa<ub::PoisonAttr>(operand))
1830 APInt bits = llvm::isa<FloatAttr>(operand)
1831 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1832 : llvm::cast<IntegerAttr>(operand).getValue();
1834 "trying to fold on broken IR: operands have incompatible types");
1836 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1838 APFloat(resFloatType.getFloatSemantics(), bits));
// Canonicalization registration.
1844 patterns.add<BitcastOfBitcast>(context);
// applyCmpPredicate (integer): evaluates a CmpIPredicate on two APInts using
// the matching signed/unsigned APInt comparison (eq/ne return lines missing
// from this extraction).
1854 const APInt &lhs,
const APInt &rhs) {
1855 switch (predicate) {
1856 case arith::CmpIPredicate::eq:
1858 case arith::CmpIPredicate::ne:
1860 case arith::CmpIPredicate::slt:
1861 return lhs.slt(rhs);
1862 case arith::CmpIPredicate::sle:
1863 return lhs.sle(rhs);
1864 case arith::CmpIPredicate::sgt:
1865 return lhs.sgt(rhs);
1866 case arith::CmpIPredicate::sge:
1867 return lhs.sge(rhs);
1868 case arith::CmpIPredicate::ult:
1869 return lhs.ult(rhs);
1870 case arith::CmpIPredicate::ule:
1871 return lhs.ule(rhs);
1872 case arith::CmpIPredicate::ugt:
1873 return lhs.ugt(rhs);
1874 case arith::CmpIPredicate::uge:
1875 return lhs.uge(rhs);
1877 llvm_unreachable(
"unknown cmpi predicate kind");
// Fragment: result of comparing equal operands — reflexive predicates
// (eq, sle, sge, ule, uge) group to one value, strict ones (ne, slt, sgt,
// ult, ugt) to the other; the return lines themselves are missing.
1882 switch (predicate) {
1883 case arith::CmpIPredicate::eq:
1884 case arith::CmpIPredicate::sle:
1885 case arith::CmpIPredicate::sge:
1886 case arith::CmpIPredicate::ule:
1887 case arith::CmpIPredicate::uge:
1889 case arith::CmpIPredicate::ne:
1890 case arith::CmpIPredicate::slt:
1891 case arith::CmpIPredicate::sgt:
1892 case arith::CmpIPredicate::ult:
1893 case arith::CmpIPredicate::ugt:
1896 llvm_unreachable(
"unknown cmpi predicate kind");
// getIntegerWidth: width of a scalar IntegerType or of a vector-of-integers'
// element type; std::nullopt otherwise.
1900 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1901 return intType.getWidth();
1903 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1904 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1906 return std::nullopt;
// CmpIOp::fold: visible cases —
//  * identical operands fold through the equal-operands table above;
//  * cmpi ne on a 1-bit value extended by extsi/extui folds to the original
//    i1 operand (x != 0 is x for booleans);
//  * constant-on-the-left is normalized to constant-on-the-right by swapping
//    operands and replacing the predicate with its mirrored counterpart —
//    note this mutates the op in place rather than producing a new value;
//  * finally, two constants fold by evaluating the predicate on APInts.
1909 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1911 if (getLhs() == getRhs()) {
1917 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1919 std::optional<int64_t> integerWidth =
1921 if (integerWidth && integerWidth.value() == 1 &&
1922 getPredicate() == arith::CmpIPredicate::ne)
1923 return extOp.getOperand();
1925 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1927 std::optional<int64_t> integerWidth =
1929 if (integerWidth && integerWidth.value() == 1 &&
1930 getPredicate() == arith::CmpIPredicate::ne)
1931 return extOp.getOperand();
// Fragments: further ne/eq special cases whose guards are missing here.
1936 getPredicate() == arith::CmpIPredicate::ne)
1943 getPredicate() == arith::CmpIPredicate::eq)
// Canonicalize "constant cmp value" to "value cmp' constant".
1948 if (adaptor.getLhs() && !adaptor.getRhs()) {
1950 using Pred = CmpIPredicate;
1951 const std::pair<Pred, Pred> invPreds[] = {
1952 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1953 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1954 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1955 {Pred::ne, Pred::ne},
1957 Pred origPred = getPredicate();
1958 for (
auto pred : invPreds) {
1959 if (origPred == pred.first) {
1960 setPredicate(pred.second);
1961 Value lhs = getLhs();
1962 Value rhs = getRhs();
1963 getLhsMutable().assign(rhs);
1964 getRhsMutable().assign(lhs);
// The table above enumerates every predicate, so the loop must match.
1968 llvm_unreachable(
"unknown cmpi predicate kind");
1973 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1974 return constFoldBinaryOp<IntegerAttr>(
1976 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
// Canonicalization registration.
1987 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
// applyCmpPredicate (float): evaluates a CmpFPredicate from one
// APFloat::compare result. Ordered predicates (O*) exclude cmpUnordered;
// unordered predicates (U*) include it. AlwaysFalse/AlwaysTrue return lines
// are missing from this extraction.
1997 const APFloat &lhs,
const APFloat &rhs) {
1998 auto cmpResult = lhs.compare(rhs);
1999 switch (predicate) {
2000 case arith::CmpFPredicate::AlwaysFalse:
2002 case arith::CmpFPredicate::OEQ:
2003 return cmpResult == APFloat::cmpEqual;
2004 case arith::CmpFPredicate::OGT:
2005 return cmpResult == APFloat::cmpGreaterThan;
2006 case arith::CmpFPredicate::OGE:
2007 return cmpResult == APFloat::cmpGreaterThan ||
2008 cmpResult == APFloat::cmpEqual;
2009 case arith::CmpFPredicate::OLT:
2010 return cmpResult == APFloat::cmpLessThan;
2011 case arith::CmpFPredicate::OLE:
2012 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2013 case arith::CmpFPredicate::ONE:
2014 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2015 case arith::CmpFPredicate::ORD:
2016 return cmpResult != APFloat::cmpUnordered;
2017 case arith::CmpFPredicate::UEQ:
2018 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2019 case arith::CmpFPredicate::UGT:
2020 return cmpResult == APFloat::cmpUnordered ||
2021 cmpResult == APFloat::cmpGreaterThan;
2022 case arith::CmpFPredicate::UGE:
2023 return cmpResult == APFloat::cmpUnordered ||
2024 cmpResult == APFloat::cmpGreaterThan ||
2025 cmpResult == APFloat::cmpEqual;
2026 case arith::CmpFPredicate::ULT:
2027 return cmpResult == APFloat::cmpUnordered ||
2028 cmpResult == APFloat::cmpLessThan;
2029 case arith::CmpFPredicate::ULE:
2030 return cmpResult == APFloat::cmpUnordered ||
2031 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2032 case arith::CmpFPredicate::UNE:
2033 return cmpResult != APFloat::cmpEqual;
2034 case arith::CmpFPredicate::UNO:
2035 return cmpResult == APFloat::cmpUnordered;
2036 case arith::CmpFPredicate::AlwaysTrue:
2039 llvm_unreachable(
"unknown cmpf predicate kind");
// CmpFOp::fold fragment: a NaN operand on either side short-circuits the
// fold (the returned constants are on missing lines — presumably false for
// ordered / true for unordered predicates, TODO confirm).
2042 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2043 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2044 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2047 if (lhs && lhs.getValue().isNaN())
2049 if (rhs && rhs.getValue().isNaN())
// convertToIntegerPredicate: maps a float predicate to the integer
// predicate with the same relation, using unsigned or signed integer
// comparisons per `isUnsigned`; ordered/unordered pairs map identically
// because integer comparisons cannot be unordered.
2065 using namespace arith;
2067 case CmpFPredicate::UEQ:
2068 case CmpFPredicate::OEQ:
2069 return CmpIPredicate::eq;
2070 case CmpFPredicate::UGT:
2071 case CmpFPredicate::OGT:
2072 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2073 case CmpFPredicate::UGE:
2074 case CmpFPredicate::OGE:
2075 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2076 case CmpFPredicate::ULT:
2077 case CmpFPredicate::OLT:
2078 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2079 case CmpFPredicate::ULE:
2080 case CmpFPredicate::OLE:
2081 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2082 case CmpFPredicate::UNE:
2083 case CmpFPredicate::ONE:
2084 return CmpIPredicate::ne;
2086 llvm_unreachable(
"Unexpected predicate!");
// Fragment of the cmpf(int-to-fp(x), const) -> cmpi rewrite: when an
// integer is converted to float and compared against a float constant, the
// comparison can often be done in the integer domain. The analysis mirrors
// LLVM's instcombine: it is exact when the integer fits the mantissa, and
// otherwise handled via explicit range checks against the integer type's
// signed/unsigned min/max. Many connecting lines are missing here.
2096 const APFloat &rhs = flt.getValue();
// A type with no IEEE-style mantissa cannot be reasoned about this way.
2104 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2105 int mantissaWidth = floatTy.getFPMantissaWidth();
2106 if (mantissaWidth <= 0)
// The lhs must come from sitofp (signed) or uitofp (unsigned).
2112 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2114 intVal = si.getIn();
2115 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2117 intVal = ui.getIn();
// valueBits: magnitude bits of the integer (sign bit excluded if signed).
2124 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2125 auto intWidth = intTy.getWidth();
2128 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
// If the integer is wider than the mantissa, the conversion may round;
// exponent analysis on the constant decides whether the compare is still
// exact (ilogb == IEK_Inf means rhs is infinity).
2133 if ((
int)intWidth > mantissaWidth) {
2135 int exponent = ilogb(rhs);
2136 if (exponent == APFloat::IEK_Inf) {
2137 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2138 if (maxExponent < (
int)valueBits) {
2145 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
// An integer-sourced float is never NaN, so ORD/UNO fold to constants
// (the folded values are on missing lines).
2154 switch (op.getPredicate()) {
2155 case CmpFPredicate::ORD:
2160 case CmpFPredicate::UNO:
// rhs above the signed max: only ne/slt/sle can be true.
2173 APFloat signedMax(rhs.getSemantics());
2174 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2175 APFloat::rmNearestTiesToEven);
2176 if (signedMax < rhs) {
2177 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2178 pred == CmpIPredicate::sle)
// rhs above the unsigned max: only ne/ult/ule can be true.
2189 APFloat unsignedMax(rhs.getSemantics());
2190 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2191 APFloat::rmNearestTiesToEven);
2192 if (unsignedMax < rhs) {
2193 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2194 pred == CmpIPredicate::ule)
// rhs below the signed min: only ne/sgt/sge can be true.
2206 APFloat signedMin(rhs.getSemantics());
2207 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2208 APFloat::rmNearestTiesToEven);
2209 if (signedMin > rhs) {
2210 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2211 pred == CmpIPredicate::sge)
// rhs below the unsigned min (zero): only ne/ugt/uge can be true.
2221 APFloat unsignedMin(rhs.getSemantics());
2222 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2223 APFloat::rmNearestTiesToEven);
2224 if (unsignedMin > rhs) {
2225 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2226 pred == CmpIPredicate::uge)
// rhs is in range: truncate it to an integer; an inexact-but-valid value
// is handled by the (missing) follow-on logic, invalid conversions bail.
2241 APSInt rhsInt(intWidth, isUnsigned);
2242 if (APFloat::opInvalidOp ==
2243 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2249 if (!rhs.isZero()) {
2250 APFloat apf(floatTy.getFloatSemantics(),
2252 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2254 bool equal = apf == rhs;
2260 case CmpIPredicate::ne:
2264 case CmpIPredicate::eq:
2268 case CmpIPredicate::ule:
2271 if (rhs.isNegative()) {
2277 case CmpIPredicate::sle:
2280 if (rhs.isNegative())
2281 pred = CmpIPredicate::slt;
2283 case CmpIPredicate::ult:
2286 if (rhs.isNegative()) {
2291 pred = CmpIPredicate::ule;
2293 case CmpIPredicate::slt:
2296 if (!rhs.isNegative())
2297 pred = CmpIPredicate::sle;
2299 case CmpIPredicate::ugt:
2302 if (rhs.isNegative()) {
2308 case CmpIPredicate::sgt:
2311 if (rhs.isNegative())
2312 pred = CmpIPredicate::sge;
2314 case CmpIPredicate::uge:
2317 if (rhs.isNegative()) {
2322 pred = CmpIPredicate::ugt;
2324 case CmpIPredicate::sge:
2327 if (!rhs.isNegative())
2328 pred = CmpIPredicate::sgt;
2338 rewriter.
create<ConstantOp>(
2339 op.getLoc(), intVal.
getType(),
2361 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2377 rewriter.
create<arith::XOrIOp>(
2378 op.getLoc(), op.getCondition(),
2379 rewriter.
create<arith::ConstantIntOp>(
2380 op.getLoc(), op.getCondition().getType(), 1)));
2390 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2394 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2395 Value trueVal = getTrueValue();
2396 Value falseVal = getFalseValue();
2397 if (trueVal == falseVal)
2400 Value condition = getCondition();
2411 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2414 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2418 if (
getType().isSignlessInteger(1) &&
2423 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2424 auto pred = cmp.getPredicate();
2425 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2426 auto cmpLhs = cmp.getLhs();
2427 auto cmpRhs = cmp.getRhs();
2435 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2436 (cmpRhs == trueVal && cmpLhs == falseVal))
2437 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2444 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2446 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2448 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2450 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2451 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2453 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2455 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2458 for (
auto [condVal, lhsVal, rhsVal] :
2459 llvm::zip_equal(condVals, lhsVals, rhsVals))
2460 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2471 Type conditionType, resultType;
2480 conditionType = resultType;
2489 {conditionType, resultType, resultType},
2494 p <<
" " << getOperands();
2497 if (ShapedType condType =
2498 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2499 p << condType <<
", ";
2504 Type conditionType = getCondition().getType();
2511 if (!llvm::isa<TensorType, VectorType>(resultType))
2512 return emitOpError() <<
"expected condition to be a signless i1, but got "
2515 if (conditionType != shapedConditionType) {
2516 return emitOpError() <<
"expected condition type to have the same shape "
2517 "as the result type, expected "
2518 << shapedConditionType <<
", but got "
2527 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2532 bool bounded =
false;
2533 auto result = constFoldBinaryOp<IntegerAttr>(
2534 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2535 bounded = b.ult(b.getBitWidth());
2545 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2550 bool bounded =
false;
2551 auto result = constFoldBinaryOp<IntegerAttr>(
2552 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2553 bounded = b.ult(b.getBitWidth());
2563 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2568 bool bounded =
false;
2569 auto result = constFoldBinaryOp<IntegerAttr>(
2570 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2571 bounded = b.ult(b.getBitWidth());
2584 bool useOnlyFiniteValue) {
2586 case AtomicRMWKind::maximumf: {
2587 const llvm::fltSemantics &semantic =
2588 llvm::cast<FloatType>(resultType).getFloatSemantics();
2589 APFloat identity = useOnlyFiniteValue
2590 ? APFloat::getLargest(semantic,
true)
2591 : APFloat::getInf(semantic,
true);
2594 case AtomicRMWKind::maxnumf: {
2595 const llvm::fltSemantics &semantic =
2596 llvm::cast<FloatType>(resultType).getFloatSemantics();
2597 APFloat identity = APFloat::getNaN(semantic,
true);
2600 case AtomicRMWKind::addf:
2601 case AtomicRMWKind::addi:
2602 case AtomicRMWKind::maxu:
2603 case AtomicRMWKind::ori:
2605 case AtomicRMWKind::andi:
2608 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2609 case AtomicRMWKind::maxs:
2611 resultType, APInt::getSignedMinValue(
2612 llvm::cast<IntegerType>(resultType).getWidth()));
2613 case AtomicRMWKind::minimumf: {
2614 const llvm::fltSemantics &semantic =
2615 llvm::cast<FloatType>(resultType).getFloatSemantics();
2616 APFloat identity = useOnlyFiniteValue
2617 ? APFloat::getLargest(semantic,
false)
2618 : APFloat::getInf(semantic,
false);
2622 case AtomicRMWKind::minnumf: {
2623 const llvm::fltSemantics &semantic =
2624 llvm::cast<FloatType>(resultType).getFloatSemantics();
2625 APFloat identity = APFloat::getNaN(semantic,
false);
2628 case AtomicRMWKind::mins:
2630 resultType, APInt::getSignedMaxValue(
2631 llvm::cast<IntegerType>(resultType).getWidth()));
2632 case AtomicRMWKind::minu:
2635 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2636 case AtomicRMWKind::muli:
2638 case AtomicRMWKind::mulf:
2650 std::optional<AtomicRMWKind> maybeKind =
2653 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2654 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2655 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2656 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2657 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2658 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2660 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2661 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2662 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2663 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2664 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2665 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2666 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2667 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2668 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2669 .Default([](
Operation *op) {
return std::nullopt; });
2671 return std::nullopt;
2674 bool useOnlyFiniteValue =
false;
2675 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2676 if (fmfOpInterface) {
2677 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2678 useOnlyFiniteValue =
2679 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2687 useOnlyFiniteValue);
2693 bool useOnlyFiniteValue) {
2696 return builder.
create<arith::ConstantOp>(loc, attr);
2704 case AtomicRMWKind::addf:
2705 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2706 case AtomicRMWKind::addi:
2707 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2708 case AtomicRMWKind::mulf:
2709 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2710 case AtomicRMWKind::muli:
2711 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2712 case AtomicRMWKind::maximumf:
2713 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2714 case AtomicRMWKind::minimumf:
2715 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2716 case AtomicRMWKind::maxnumf:
2717 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2718 case AtomicRMWKind::minnumf:
2719 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2720 case AtomicRMWKind::maxs:
2721 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2722 case AtomicRMWKind::mins:
2723 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2724 case AtomicRMWKind::maxu:
2725 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2726 case AtomicRMWKind::minu:
2727 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2728 case AtomicRMWKind::ori:
2729 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2730 case AtomicRMWKind::andi:
2731 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2744 #define GET_OP_CLASSES
2745 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2751 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed shaped types of the allowed element types.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode, without any information loss.
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1216::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)