25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/APSInt.h"
28 #include "llvm/ADT/FloatingPointMode.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
67 static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
71 val1.getValue() & val2.getValue());
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
98 llvm_unreachable(
"unknown cmpi predicate kind");
107 static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
121 llvm_unreachable(
"Unhandled rounding mode");
151 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
162 #include "ArithCanonicalization.inc"
172 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
173 return shapedType.cloneWith(std::nullopt, i1Type);
174 if (llvm::isa<UnrankedTensorType>(type))
183 void arith::ConstantOp::getAsmResultNames(
186 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
187 auto intType = llvm::dyn_cast<IntegerType>(type);
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName <<
'c' << intCst.getValue();
198 specialName <<
'_' << type;
199 setNameFn(getResult(), specialName.str());
201 setNameFn(getResult(),
"cst");
210 if (getValue().
getType() != type) {
211 return emitOpError() <<
"value type " << getValue().getType()
212 <<
" must match return type: " << type;
215 if (llvm::isa<IntegerType>(type) &&
216 !llvm::cast<IntegerType>(type).isSignless())
217 return emitOpError(
"integer return type must be signless");
219 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221 "value must be an integer, float, or elements attribute");
227 auto vecType = dyn_cast<VectorType>(type);
228 if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
230 "intializing scalable vectors with elements attribute is not supported"
231 " unless it's a vector splat");
235 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
237 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238 if (!typedAttr || typedAttr.getType() != type)
241 if (llvm::isa<IntegerType>(type) &&
242 !llvm::cast<IntegerType>(type).isSignless())
245 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250 if (isBuildableWith(value, type))
251 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
// Folding hook for arith.constant: the op's value attribute already is the
// folded result, so it is handed back unchanged.
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
  Attribute foldedValue = getValue();
  return foldedValue;
}
258 int64_t value,
unsigned width) {
260 arith::ConstantOp::build(builder, result, type,
265 int64_t value,
Type type) {
267 "ConstantIntOp can only have signless integer type values");
268 arith::ConstantOp::build(builder, result, type,
273 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274 return constOp.getType().isSignlessInteger();
280 arith::ConstantOp::build(builder, result, type,
285 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286 return llvm::isa<FloatType>(constOp.getType());
292 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
297 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298 return constOp.getType().isIndex();
312 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
313 if (getRhs() == sub.getRhs())
317 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
318 if (getLhs() == sub.getRhs())
321 return constFoldBinaryOp<IntegerAttr>(
322 adaptor.getOperands(),
323 [](APInt a,
const APInt &b) { return std::move(a) + b; });
328 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
339 return llvm::to_vector<4>(vt.getShape());
346 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352 Type overflowTy = getOverflow().getType();
358 results.push_back(getLhs());
359 results.push_back(falseValue);
367 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368 adaptor.getOperands(),
369 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
370 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371 ArrayRef({sumAttr, adaptor.getLhs()}),
377 results.push_back(sumAttr);
378 results.push_back(overflowAttr);
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
387 patterns.
add<AddUIExtendedToAddI>(context);
396 if (getOperand(0) == getOperand(1))
402 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
404 if (getRhs() == add.getRhs())
407 if (getRhs() == add.getLhs())
411 return constFoldBinaryOp<IntegerAttr>(
412 adaptor.getOperands(),
413 [](APInt a,
const APInt &b) { return std::move(a) - b; });
418 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
419 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
420 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
437 return constFoldBinaryOp<IntegerAttr>(
438 adaptor.getOperands(),
439 [](
const APInt &a,
const APInt &b) { return a * b; });
442 void arith::MulIOp::getAsmResultNames(
444 if (!isa<IndexType>(
getType()))
453 IntegerAttr baseValue;
456 isVscale(b.getDefiningOp());
459 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
464 llvm::raw_svector_ostream specialName(specialNameBuffer);
465 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
466 setNameFn(getResult(), specialName.str());
471 patterns.
add<MulIMulIConstant>(context);
478 std::optional<SmallVector<int64_t, 4>>
479 arith::MulSIExtendedOp::getShapeForUnroll() {
480 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
481 return llvm::to_vector<4>(vt.getShape());
486 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
491 results.push_back(zero);
492 results.push_back(zero);
497 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
498 adaptor.getOperands(),
499 [](
const APInt &a,
const APInt &b) { return a * b; })) {
501 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
502 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
503 return llvm::APIntOps::mulhs(a, b);
505 assert(highAttr &&
"Unexpected constant-folding failure");
507 results.push_back(lowAttr);
508 results.push_back(highAttr);
515 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
517 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
524 std::optional<SmallVector<int64_t, 4>>
525 arith::MulUIExtendedOp::getShapeForUnroll() {
526 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
527 return llvm::to_vector<4>(vt.getShape());
532 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
537 results.push_back(zero);
538 results.push_back(zero);
546 results.push_back(getLhs());
547 results.push_back(zero);
552 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
553 adaptor.getOperands(),
554 [](
const APInt &a,
const APInt &b) { return a * b; })) {
556 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
557 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
558 return llvm::APIntOps::mulhu(a, b);
560 assert(highAttr &&
"Unexpected constant-folding failure");
562 results.push_back(lowAttr);
563 results.push_back(highAttr);
570 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
572 patterns.
add<MulUIExtendedToMulI>(context);
579 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
586 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
587 [&](APInt a,
const APInt &b) {
615 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
621 bool overflowOrDiv0 =
false;
622 auto result = constFoldBinaryOp<IntegerAttr>(
623 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
624 if (overflowOrDiv0 || !b) {
625 overflowOrDiv0 = true;
628 return a.sdiv_ov(b, overflowOrDiv0);
631 return overflowOrDiv0 ?
Attribute() : result;
658 APInt one(a.getBitWidth(), 1,
true);
659 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
660 return val.sadd_ov(one, overflow);
667 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
672 bool overflowOrDiv0 =
false;
673 auto result = constFoldBinaryOp<IntegerAttr>(
674 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
675 if (overflowOrDiv0 || !b) {
676 overflowOrDiv0 = true;
679 APInt quotient = a.udiv(b);
682 APInt one(a.getBitWidth(), 1,
true);
683 return quotient.uadd_ov(one, overflowOrDiv0);
686 return overflowOrDiv0 ?
Attribute() : result;
697 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
705 bool overflowOrDiv0 =
false;
706 auto result = constFoldBinaryOp<IntegerAttr>(
707 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
708 if (overflowOrDiv0 || !b) {
709 overflowOrDiv0 = true;
715 unsigned bits = a.getBitWidth();
717 bool aGtZero = a.sgt(zero);
718 bool bGtZero = b.sgt(zero);
719 if (aGtZero && bGtZero) {
726 bool overflowNegA =
false;
727 bool overflowNegB =
false;
728 bool overflowDiv =
false;
729 bool overflowNegRes =
false;
730 if (!aGtZero && !bGtZero) {
732 APInt posA = zero.ssub_ov(a, overflowNegA);
733 APInt posB = zero.ssub_ov(b, overflowNegB);
735 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
738 if (!aGtZero && bGtZero) {
740 APInt posA = zero.ssub_ov(a, overflowNegA);
741 APInt div = posA.sdiv_ov(b, overflowDiv);
742 APInt res = zero.ssub_ov(div, overflowNegRes);
743 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
747 APInt posB = zero.ssub_ov(b, overflowNegB);
748 APInt div = a.sdiv_ov(posB, overflowDiv);
749 APInt res = zero.ssub_ov(div, overflowNegRes);
751 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
755 return overflowOrDiv0 ?
Attribute() : result;
766 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
772 bool overflowOrDiv =
false;
773 auto result = constFoldBinaryOp<IntegerAttr>(
774 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
776 overflowOrDiv = true;
779 return a.sfloordiv_ov(b, overflowOrDiv);
782 return overflowOrDiv ?
Attribute() : result;
789 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
796 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
797 [&](APInt a,
const APInt &b) {
798 if (div0 || b.isZero()) {
812 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
819 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820 [&](APInt a,
const APInt &b) {
821 if (div0 || b.isZero()) {
837 for (
bool reversePrev : {
false,
true}) {
838 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
839 .getDefiningOp<arith::AndIOp>();
843 Value other = (reversePrev ? op.getLhs() : op.getRhs());
844 if (other != prev.getLhs() && other != prev.getRhs())
847 return prev.getResult();
859 intValue.isAllOnes())
864 intValue.isAllOnes())
869 intValue.isAllOnes())
876 return constFoldBinaryOp<IntegerAttr>(
877 adaptor.getOperands(),
878 [](APInt a,
const APInt &b) { return std::move(a) & b; });
891 if (rhsVal.isAllOnes())
892 return adaptor.getRhs();
899 intValue.isAllOnes())
900 return getRhs().getDefiningOp<XOrIOp>().getRhs();
904 intValue.isAllOnes())
905 return getLhs().getDefiningOp<XOrIOp>().getRhs();
907 return constFoldBinaryOp<IntegerAttr>(
908 adaptor.getOperands(),
909 [](APInt a,
const APInt &b) { return std::move(a) | b; });
921 if (getLhs() == getRhs())
925 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
926 if (prev.getRhs() == getRhs())
927 return prev.getLhs();
928 if (prev.getLhs() == getRhs())
929 return prev.getRhs();
933 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
934 if (prev.getRhs() == getLhs())
935 return prev.getLhs();
936 if (prev.getLhs() == getLhs())
937 return prev.getRhs();
940 return constFoldBinaryOp<IntegerAttr>(
941 adaptor.getOperands(),
942 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
947 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
956 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
958 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
959 [](
const APFloat &a) { return -a; });
971 return constFoldBinaryOp<FloatAttr>(
972 adaptor.getOperands(),
973 [](
const APFloat &a,
const APFloat &b) { return a + b; });
985 return constFoldBinaryOp<FloatAttr>(
986 adaptor.getOperands(),
987 [](
const APFloat &a,
const APFloat &b) { return a - b; });
994 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
996 if (getLhs() == getRhs())
1003 return constFoldBinaryOp<FloatAttr>(
1004 adaptor.getOperands(),
1005 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
// Folder for arith.maxnumf. Constant operands are folded with llvm::maxnum
// (IEEE-754 maxNum semantics: if exactly one operand is NaN, the other operand
// is the result). llvm::maximum must NOT be used here: it propagates NaN,
// which is the semantics of arith.maximumf, and would fold e.g.
// maxnumf(NaN, 1.0) to NaN instead of 1.0. This mirrors MinNumFOp, which
// folds with llvm::minnum.
1012 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1014 if (getLhs() == getRhs())
1021 return constFoldBinaryOp<FloatAttr>(
1022 adaptor.getOperands(),
1023 [](
const APFloat &a,
const APFloat &b) { return llvm::maxnum(a, b); });
1032 if (getLhs() == getRhs())
1038 if (intValue.isMaxSignedValue())
1041 if (intValue.isMinSignedValue())
1045 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046 [](
const APInt &a,
const APInt &b) {
1047 return llvm::APIntOps::smax(a, b);
1057 if (getLhs() == getRhs())
1063 if (intValue.isMaxValue())
1066 if (intValue.isMinValue())
1070 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1071 [](
const APInt &a,
const APInt &b) {
1072 return llvm::APIntOps::umax(a, b);
1080 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1082 if (getLhs() == getRhs())
1089 return constFoldBinaryOp<FloatAttr>(
1090 adaptor.getOperands(),
1091 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1098 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1100 if (getLhs() == getRhs())
1107 return constFoldBinaryOp<FloatAttr>(
1108 adaptor.getOperands(),
1109 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1118 if (getLhs() == getRhs())
1124 if (intValue.isMinSignedValue())
1127 if (intValue.isMaxSignedValue())
1131 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1132 [](
const APInt &a,
const APInt &b) {
1133 return llvm::APIntOps::smin(a, b);
1143 if (getLhs() == getRhs())
1149 if (intValue.isMinValue())
1152 if (intValue.isMaxValue())
1156 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1157 [](
const APInt &a,
const APInt &b) {
1158 return llvm::APIntOps::umin(a, b);
1166 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1171 return constFoldBinaryOp<FloatAttr>(
1172 adaptor.getOperands(),
1173 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1178 patterns.
add<MulFOfNegF>(context);
1185 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1190 return constFoldBinaryOp<FloatAttr>(
1191 adaptor.getOperands(),
1192 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1197 patterns.
add<DivFOfNegF>(context);
1204 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1205 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1206 [](
const APFloat &a,
const APFloat &b) {
1211 (void)result.mod(b);
1220 template <
typename... Types>
1226 template <
typename... ShapedTypes,
typename... ElementTypes>
1229 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1233 if (!llvm::isa<ElementTypes...>(underlyingType))
1236 return underlyingType;
1240 template <
typename... ElementTypes>
1247 template <
typename... ElementTypes>
1256 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1257 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1258 if (!rankedTensorA || !rankedTensorB)
1260 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1264 if (inputs.size() != 1 || outputs.size() != 1)
1276 template <
typename ValType,
typename Op>
1281 if (llvm::cast<ValType>(srcType).getWidth() >=
1282 llvm::cast<ValType>(dstType).getWidth())
1284 << dstType <<
" must be wider than operand type " << srcType;
1290 template <
typename ValType,
typename Op>
1295 if (llvm::cast<ValType>(srcType).getWidth() <=
1296 llvm::cast<ValType>(dstType).getWidth())
1298 << dstType <<
" must be shorter than operand type " << srcType;
1304 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1309 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1310 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1311 if (!srcType || !dstType)
1314 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1315 srcType.getIntOrFloatBitWidth());
1321 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1322 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1323 bool losesInfo =
false;
1324 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1325 if (losesInfo || status != APFloat::opOK)
1335 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1336 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1337 getInMutable().assign(lhs.getIn());
1342 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1343 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1344 adaptor.getOperands(),
getType(),
1345 [bitWidth](
const APInt &a,
bool &castStatus) {
1346 return a.zext(bitWidth);
1351 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1355 return verifyExtOp<IntegerType>(*
this);
1362 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1363 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1364 getInMutable().assign(lhs.getIn());
1369 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1370 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1371 adaptor.getOperands(),
getType(),
1372 [bitWidth](
const APInt &a,
bool &castStatus) {
1373 return a.sext(bitWidth);
1378 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1383 patterns.
add<ExtSIOfExtUI>(context);
1387 return verifyExtOp<IntegerType>(*
this);
1396 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1397 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1398 if (truncFOp.getOperand().getType() ==
getType()) {
1399 arith::FastMathFlags truncFMF =
1400 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1401 bool isTruncContract =
1403 arith::FastMathFlags extFMF =
1404 getFastmath().value_or(arith::FastMathFlags::none);
1405 bool isExtContract =
1407 if (isTruncContract && isExtContract) {
1408 return truncFOp.getOperand();
1414 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1415 return constFoldCastOp<FloatAttr, FloatAttr>(
1416 adaptor.getOperands(),
getType(),
1417 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1419 if (failed(result)) {
1428 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1437 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1438 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1445 if (llvm::cast<IntegerType>(srcType).getWidth() >
1446 llvm::cast<IntegerType>(dstType).getWidth()) {
1453 if (srcType == dstType)
1458 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1459 setOperand(getOperand().getDefiningOp()->getOperand(0));
1464 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1465 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1466 adaptor.getOperands(),
getType(),
1467 [bitWidth](
const APInt &a,
bool &castStatus) {
1468 return a.trunc(bitWidth);
1473 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1478 patterns.
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1479 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1484 return verifyTruncateOp<IntegerType>(*
this);
1493 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1495 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1496 return constFoldCastOp<FloatAttr, FloatAttr>(
1497 adaptor.getOperands(),
getType(),
1498 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1499 RoundingMode roundingMode =
1500 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1501 llvm::RoundingMode llvmRoundingMode =
1503 FailureOr<APFloat> result =
1505 if (failed(result)) {
1514 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1518 return verifyTruncateOp<FloatType>(*
this);
1527 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1536 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1543 template <
typename From,
typename To>
1548 auto srcType = getTypeIfLike<From>(inputs.front());
1549 auto dstType = getTypeIfLike<To>(outputs.back());
1551 return srcType && dstType;
1559 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1562 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1564 return constFoldCastOp<IntegerAttr, FloatAttr>(
1565 adaptor.getOperands(),
getType(),
1566 [&resEleType](
const APInt &a,
bool &castStatus) {
1567 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1570 apf.convertFromAPInt(a,
false,
1571 APFloat::rmNearestTiesToEven);
1581 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1584 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1586 return constFoldCastOp<IntegerAttr, FloatAttr>(
1587 adaptor.getOperands(),
getType(),
1588 [&resEleType](
const APInt &a,
bool &castStatus) {
1589 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1592 apf.convertFromAPInt(a,
true,
1593 APFloat::rmNearestTiesToEven);
1603 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1606 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1608 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1609 return constFoldCastOp<FloatAttr, IntegerAttr>(
1610 adaptor.getOperands(),
getType(),
1611 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1613 APSInt api(bitWidth,
true);
1614 castStatus = APFloat::opInvalidOp !=
1615 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1625 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1628 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1630 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1631 return constFoldCastOp<FloatAttr, IntegerAttr>(
1632 adaptor.getOperands(),
getType(),
1633 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1635 APSInt api(bitWidth,
false);
1636 castStatus = APFloat::opInvalidOp !=
1637 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1650 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1651 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1652 if (!srcType || !dstType)
1659 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1664 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1666 unsigned resultBitwidth = 64;
1668 resultBitwidth = intTy.getWidth();
1670 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1671 adaptor.getOperands(),
getType(),
1672 [resultBitwidth](
const APInt &a,
bool & ) {
1673 return a.sextOrTrunc(resultBitwidth);
1677 void arith::IndexCastOp::getCanonicalizationPatterns(
1679 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1686 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1691 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1693 unsigned resultBitwidth = 64;
1695 resultBitwidth = intTy.getWidth();
1697 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1698 adaptor.getOperands(),
getType(),
1699 [resultBitwidth](
const APInt &a,
bool & ) {
1700 return a.zextOrTrunc(resultBitwidth);
1704 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1706 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1718 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1720 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1721 if (!srcType || !dstType)
1727 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1729 auto operand = adaptor.getIn();
1734 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1735 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1737 if (llvm::isa<ShapedType>(resType))
1741 APInt bits = llvm::isa<FloatAttr>(operand)
1742 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1743 : llvm::cast<IntegerAttr>(operand).getValue();
1745 "trying to fold on broken IR: operands have incompatible types");
1747 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1749 APFloat(resFloatType.getFloatSemantics(), bits));
1755 patterns.
add<BitcastOfBitcast>(context);
1765 const APInt &lhs,
const APInt &rhs) {
1766 switch (predicate) {
1767 case arith::CmpIPredicate::eq:
1769 case arith::CmpIPredicate::ne:
1771 case arith::CmpIPredicate::slt:
1772 return lhs.slt(rhs);
1773 case arith::CmpIPredicate::sle:
1774 return lhs.sle(rhs);
1775 case arith::CmpIPredicate::sgt:
1776 return lhs.sgt(rhs);
1777 case arith::CmpIPredicate::sge:
1778 return lhs.sge(rhs);
1779 case arith::CmpIPredicate::ult:
1780 return lhs.ult(rhs);
1781 case arith::CmpIPredicate::ule:
1782 return lhs.ule(rhs);
1783 case arith::CmpIPredicate::ugt:
1784 return lhs.ugt(rhs);
1785 case arith::CmpIPredicate::uge:
1786 return lhs.uge(rhs);
1788 llvm_unreachable(
"unknown cmpi predicate kind");
1793 switch (predicate) {
1794 case arith::CmpIPredicate::eq:
1795 case arith::CmpIPredicate::sle:
1796 case arith::CmpIPredicate::sge:
1797 case arith::CmpIPredicate::ule:
1798 case arith::CmpIPredicate::uge:
1800 case arith::CmpIPredicate::ne:
1801 case arith::CmpIPredicate::slt:
1802 case arith::CmpIPredicate::sgt:
1803 case arith::CmpIPredicate::ult:
1804 case arith::CmpIPredicate::ugt:
1807 llvm_unreachable(
"unknown cmpi predicate kind");
1811 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1812 return intType.getWidth();
1814 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1815 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1817 return std::nullopt;
1820 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1822 if (getLhs() == getRhs()) {
1828 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1830 std::optional<int64_t> integerWidth =
1832 if (integerWidth && integerWidth.value() == 1 &&
1833 getPredicate() == arith::CmpIPredicate::ne)
1834 return extOp.getOperand();
1836 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1838 std::optional<int64_t> integerWidth =
1840 if (integerWidth && integerWidth.value() == 1 &&
1841 getPredicate() == arith::CmpIPredicate::ne)
1842 return extOp.getOperand();
1847 if (adaptor.getLhs() && !adaptor.getRhs()) {
1849 using Pred = CmpIPredicate;
1850 const std::pair<Pred, Pred> invPreds[] = {
1851 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1852 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1853 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1854 {Pred::ne, Pred::ne},
1856 Pred origPred = getPredicate();
1857 for (
auto pred : invPreds) {
1858 if (origPred == pred.first) {
1859 setPredicate(pred.second);
1860 Value lhs = getLhs();
1861 Value rhs = getRhs();
1862 getLhsMutable().assign(rhs);
1863 getRhsMutable().assign(lhs);
1867 llvm_unreachable(
"unknown cmpi predicate kind");
1872 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1873 return constFoldBinaryOp<IntegerAttr>(
1875 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1886 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1896 const APFloat &lhs,
const APFloat &rhs) {
1897 auto cmpResult = lhs.compare(rhs);
1898 switch (predicate) {
1899 case arith::CmpFPredicate::AlwaysFalse:
1901 case arith::CmpFPredicate::OEQ:
1902 return cmpResult == APFloat::cmpEqual;
1903 case arith::CmpFPredicate::OGT:
1904 return cmpResult == APFloat::cmpGreaterThan;
1905 case arith::CmpFPredicate::OGE:
1906 return cmpResult == APFloat::cmpGreaterThan ||
1907 cmpResult == APFloat::cmpEqual;
1908 case arith::CmpFPredicate::OLT:
1909 return cmpResult == APFloat::cmpLessThan;
1910 case arith::CmpFPredicate::OLE:
1911 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1912 case arith::CmpFPredicate::ONE:
1913 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1914 case arith::CmpFPredicate::ORD:
1915 return cmpResult != APFloat::cmpUnordered;
1916 case arith::CmpFPredicate::UEQ:
1917 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1918 case arith::CmpFPredicate::UGT:
1919 return cmpResult == APFloat::cmpUnordered ||
1920 cmpResult == APFloat::cmpGreaterThan;
1921 case arith::CmpFPredicate::UGE:
1922 return cmpResult == APFloat::cmpUnordered ||
1923 cmpResult == APFloat::cmpGreaterThan ||
1924 cmpResult == APFloat::cmpEqual;
1925 case arith::CmpFPredicate::ULT:
1926 return cmpResult == APFloat::cmpUnordered ||
1927 cmpResult == APFloat::cmpLessThan;
1928 case arith::CmpFPredicate::ULE:
1929 return cmpResult == APFloat::cmpUnordered ||
1930 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1931 case arith::CmpFPredicate::UNE:
1932 return cmpResult != APFloat::cmpEqual;
1933 case arith::CmpFPredicate::UNO:
1934 return cmpResult == APFloat::cmpUnordered;
1935 case arith::CmpFPredicate::AlwaysTrue:
1938 llvm_unreachable(
"unknown cmpf predicate kind");
1941 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1942 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1943 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1946 if (lhs && lhs.getValue().isNaN())
1948 if (rhs && rhs.getValue().isNaN())
1964 using namespace arith;
1966 case CmpFPredicate::UEQ:
1967 case CmpFPredicate::OEQ:
1968 return CmpIPredicate::eq;
1969 case CmpFPredicate::UGT:
1970 case CmpFPredicate::OGT:
1971 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1972 case CmpFPredicate::UGE:
1973 case CmpFPredicate::OGE:
1974 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1975 case CmpFPredicate::ULT:
1976 case CmpFPredicate::OLT:
1977 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1978 case CmpFPredicate::ULE:
1979 case CmpFPredicate::OLE:
1980 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1981 case CmpFPredicate::UNE:
1982 case CmpFPredicate::ONE:
1983 return CmpIPredicate::ne;
1985 llvm_unreachable(
"Unexpected predicate!");
1995 const APFloat &rhs = flt.getValue();
2003 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2005 if (mantissaWidth <= 0)
2011 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2013 intVal = si.getIn();
2014 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2016 intVal = ui.getIn();
2023 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2024 auto intWidth = intTy.getWidth();
2027 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2032 if ((
int)intWidth > mantissaWidth) {
2034 int exponent = ilogb(rhs);
2035 if (exponent == APFloat::IEK_Inf) {
2036 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2037 if (maxExponent < (
int)valueBits) {
2044 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2053 switch (op.getPredicate()) {
2054 case CmpFPredicate::ORD:
2059 case CmpFPredicate::UNO:
2072 APFloat signedMax(rhs.getSemantics());
2073 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2074 APFloat::rmNearestTiesToEven);
2075 if (signedMax < rhs) {
2076 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2077 pred == CmpIPredicate::sle)
2088 APFloat unsignedMax(rhs.getSemantics());
2089 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2090 APFloat::rmNearestTiesToEven);
2091 if (unsignedMax < rhs) {
2092 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2093 pred == CmpIPredicate::ule)
2105 APFloat signedMin(rhs.getSemantics());
2106 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2107 APFloat::rmNearestTiesToEven);
2108 if (signedMin > rhs) {
2109 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2110 pred == CmpIPredicate::sge)
2120 APFloat unsignedMin(rhs.getSemantics());
2121 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2122 APFloat::rmNearestTiesToEven);
2123 if (unsignedMin > rhs) {
2124 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2125 pred == CmpIPredicate::uge)
2140 APSInt rhsInt(intWidth, isUnsigned);
2141 if (APFloat::opInvalidOp ==
2142 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2148 if (!rhs.isZero()) {
2151 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2153 bool equal = apf == rhs;
2159 case CmpIPredicate::ne:
2163 case CmpIPredicate::eq:
2167 case CmpIPredicate::ule:
2170 if (rhs.isNegative()) {
2176 case CmpIPredicate::sle:
2179 if (rhs.isNegative())
2180 pred = CmpIPredicate::slt;
2182 case CmpIPredicate::ult:
2185 if (rhs.isNegative()) {
2190 pred = CmpIPredicate::ule;
2192 case CmpIPredicate::slt:
2195 if (!rhs.isNegative())
2196 pred = CmpIPredicate::sle;
2198 case CmpIPredicate::ugt:
2201 if (rhs.isNegative()) {
2207 case CmpIPredicate::sgt:
2210 if (rhs.isNegative())
2211 pred = CmpIPredicate::sge;
2213 case CmpIPredicate::uge:
2216 if (rhs.isNegative()) {
2221 pred = CmpIPredicate::ugt;
2223 case CmpIPredicate::sge:
2226 if (!rhs.isNegative())
2227 pred = CmpIPredicate::sgt;
2237 rewriter.
create<ConstantOp>(
2260 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2276 rewriter.
create<arith::XOrIOp>(
2277 op.
getLoc(), op.getCondition(),
2278 rewriter.
create<arith::ConstantIntOp>(
2279 op.
getLoc(), 1, op.getCondition().getType())));
2289 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2293 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2294 Value trueVal = getTrueValue();
2295 Value falseVal = getFalseValue();
2296 if (trueVal == falseVal)
2299 Value condition = getCondition();
2310 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2313 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2321 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2322 auto pred = cmp.getPredicate();
2323 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2324 auto cmpLhs = cmp.getLhs();
2325 auto cmpRhs = cmp.getRhs();
2333 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2334 (cmpRhs == trueVal && cmpLhs == falseVal))
2335 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2342 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2344 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2346 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2348 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2349 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2351 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2353 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2356 for (
auto [condVal, lhsVal, rhsVal] :
2357 llvm::zip_equal(condVals, lhsVals, rhsVals))
2358 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2369 Type conditionType, resultType;
2378 conditionType = resultType;
2387 {conditionType, resultType, resultType},
2392 p <<
" " << getOperands();
2395 if (ShapedType condType =
2396 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2397 p << condType <<
", ";
2402 Type conditionType = getCondition().getType();
2409 if (!llvm::isa<TensorType, VectorType>(resultType))
2410 return emitOpError() <<
"expected condition to be a signless i1, but got "
2413 if (conditionType != shapedConditionType) {
2414 return emitOpError() <<
"expected condition type to have the same shape "
2415 "as the result type, expected "
2416 << shapedConditionType <<
", but got "
2425 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2430 bool bounded =
false;
2431 auto result = constFoldBinaryOp<IntegerAttr>(
2432 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2433 bounded = b.ult(b.getBitWidth());
2443 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2448 bool bounded =
false;
2449 auto result = constFoldBinaryOp<IntegerAttr>(
2450 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2451 bounded = b.ult(b.getBitWidth());
2461 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2466 bool bounded =
false;
2467 auto result = constFoldBinaryOp<IntegerAttr>(
2468 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2469 bounded = b.ult(b.getBitWidth());
2482 bool useOnlyFiniteValue) {
2484 case AtomicRMWKind::maximumf: {
2485 const llvm::fltSemantics &semantic =
2486 llvm::cast<FloatType>(resultType).getFloatSemantics();
2487 APFloat identity = useOnlyFiniteValue
2488 ? APFloat::getLargest(semantic,
true)
2489 : APFloat::getInf(semantic,
true);
2492 case AtomicRMWKind::maxnumf: {
2493 const llvm::fltSemantics &semantic =
2494 llvm::cast<FloatType>(resultType).getFloatSemantics();
2495 APFloat identity = APFloat::getNaN(semantic,
true);
2498 case AtomicRMWKind::addf:
2499 case AtomicRMWKind::addi:
2500 case AtomicRMWKind::maxu:
2501 case AtomicRMWKind::ori:
2503 case AtomicRMWKind::andi:
2506 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2507 case AtomicRMWKind::maxs:
2509 resultType, APInt::getSignedMinValue(
2510 llvm::cast<IntegerType>(resultType).getWidth()));
2511 case AtomicRMWKind::minimumf: {
2512 const llvm::fltSemantics &semantic =
2513 llvm::cast<FloatType>(resultType).getFloatSemantics();
2514 APFloat identity = useOnlyFiniteValue
2515 ? APFloat::getLargest(semantic,
false)
2516 : APFloat::getInf(semantic,
false);
2520 case AtomicRMWKind::minnumf: {
2521 const llvm::fltSemantics &semantic =
2522 llvm::cast<FloatType>(resultType).getFloatSemantics();
2523 APFloat identity = APFloat::getNaN(semantic,
false);
2526 case AtomicRMWKind::mins:
2528 resultType, APInt::getSignedMaxValue(
2529 llvm::cast<IntegerType>(resultType).getWidth()));
2530 case AtomicRMWKind::minu:
2533 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2534 case AtomicRMWKind::muli:
2536 case AtomicRMWKind::mulf:
2548 std::optional<AtomicRMWKind> maybeKind =
2551 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2552 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2553 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2554 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2555 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2556 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2558 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2559 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2560 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2561 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2562 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2563 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2564 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2565 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2566 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2567 .Default([](
Operation *op) {
return std::nullopt; });
2569 return std::nullopt;
2572 bool useOnlyFiniteValue =
false;
2573 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2574 if (fmfOpInterface) {
2575 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2576 useOnlyFiniteValue =
2577 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2585 useOnlyFiniteValue);
2591 bool useOnlyFiniteValue) {
2594 return builder.
create<arith::ConstantOp>(loc, attr);
2602 case AtomicRMWKind::addf:
2603 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2604 case AtomicRMWKind::addi:
2605 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2606 case AtomicRMWKind::mulf:
2607 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2608 case AtomicRMWKind::muli:
2609 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2610 case AtomicRMWKind::maximumf:
2611 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2612 case AtomicRMWKind::minimumf:
2613 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2614 case AtomicRMWKind::maxnumf:
2615 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2616 case AtomicRMWKind::minnumf:
2617 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2618 case AtomicRMWKind::maxs:
2619 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2620 case AtomicRMWKind::mins:
2621 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2622 case AtomicRMWKind::maxu:
2623 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2624 case AtomicRMWKind::minu:
2625 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2626 case AtomicRMWKind::ori:
2627 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2628 case AtomicRMWKind::andi:
2629 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2642 #define GET_OP_CLASSES
2643 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2649 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)