25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/APSInt.h"
28 #include "llvm/ADT/FloatingPointMode.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
67 static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
71 val1.getValue() & val2.getValue());
// Maps each arith.cmpi predicate to its logical negation, so that
// !(a P b) holds exactly when (a P' b) holds: eq<->ne, slt<->sge,
// sle<->sgt, ult<->uge, ule<->ugt.
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
// Every enumerator is handled above; getting here means an out-of-range
// predicate value.
98 llvm_unreachable(
"unknown cmpi predicate kind");
// Translates the arith dialect's RoundingMode enum one-to-one into the
// corresponding llvm::RoundingMode (downward -> TowardNegative,
// upward -> TowardPositive, toward_zero -> TowardZero, and the two
// round-to-nearest tie-breaking variants).
107 static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
// All enumerators are handled above; any other value is a programming error.
121 llvm_unreachable(
"Unhandled rounding mode");
151 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
162 #include "ArithCanonicalization.inc"
172 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
173 return shapedType.cloneWith(std::nullopt, i1Type);
174 if (llvm::isa<UnrankedTensorType>(type))
183 void arith::ConstantOp::getAsmResultNames(
186 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
187 auto intType = llvm::dyn_cast<IntegerType>(type);
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName <<
'c' << intCst.getValue();
198 specialName <<
'_' << type;
199 setNameFn(getResult(), specialName.str());
201 setNameFn(getResult(),
"cst");
210 if (getValue().
getType() != type) {
211 return emitOpError() <<
"value type " << getValue().getType()
212 <<
" must match return type: " << type;
215 if (llvm::isa<IntegerType>(type) &&
216 !llvm::cast<IntegerType>(type).isSignless())
217 return emitOpError(
"integer return type must be signless");
219 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221 "value must be an integer, float, or elements attribute");
227 auto vecType = dyn_cast<VectorType>(type);
228 if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
230 "intializing scalable vectors with elements attribute is not supported"
231 " unless it's a vector splat");
235 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
237 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238 if (!typedAttr || typedAttr.getType() != type)
241 if (llvm::isa<IntegerType>(type) &&
242 !llvm::cast<IntegerType>(type).isSignless())
245 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250 if (isBuildableWith(value, type))
251 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
// A constant op always folds to the attribute it carries; the fold adaptor's
// operand attributes are irrelevant because the op has no operands to fold.
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
  auto folded = getValue();
  return folded;
}
258 int64_t value,
unsigned width) {
260 arith::ConstantOp::build(builder, result, type,
265 int64_t value,
Type type) {
267 "ConstantIntOp can only have signless integer type values");
268 arith::ConstantOp::build(builder, result, type,
273 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274 return constOp.getType().isSignlessInteger();
280 arith::ConstantOp::build(builder, result, type,
285 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286 return llvm::isa<FloatType>(constOp.getType());
292 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
297 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298 return constOp.getType().isIndex();
312 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
313 if (getRhs() == sub.getRhs())
317 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
318 if (getLhs() == sub.getRhs())
321 return constFoldBinaryOp<IntegerAttr>(
322 adaptor.getOperands(),
323 [](APInt a,
const APInt &b) { return std::move(a) + b; });
328 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
339 return llvm::to_vector<4>(vt.getShape());
346 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352 Type overflowTy = getOverflow().getType();
358 results.push_back(getLhs());
359 results.push_back(falseValue);
367 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368 adaptor.getOperands(),
369 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
370 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371 ArrayRef({sumAttr, adaptor.getLhs()}),
377 results.push_back(sumAttr);
378 results.push_back(overflowAttr);
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
387 patterns.
add<AddUIExtendedToAddI>(context);
396 if (getOperand(0) == getOperand(1))
402 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
404 if (getRhs() == add.getRhs())
407 if (getRhs() == add.getLhs())
411 return constFoldBinaryOp<IntegerAttr>(
412 adaptor.getOperands(),
413 [](APInt a,
const APInt &b) { return std::move(a) - b; });
418 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
419 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
420 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
437 return constFoldBinaryOp<IntegerAttr>(
438 adaptor.getOperands(),
439 [](
const APInt &a,
const APInt &b) { return a * b; });
442 void arith::MulIOp::getAsmResultNames(
444 if (!isa<IndexType>(
getType()))
453 IntegerAttr baseValue;
456 isVscale(b.getDefiningOp());
459 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
464 llvm::raw_svector_ostream specialName(specialNameBuffer);
465 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
466 setNameFn(getResult(), specialName.str());
471 patterns.
add<MulIMulIConstant>(context);
478 std::optional<SmallVector<int64_t, 4>>
479 arith::MulSIExtendedOp::getShapeForUnroll() {
480 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
481 return llvm::to_vector<4>(vt.getShape());
486 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
491 results.push_back(zero);
492 results.push_back(zero);
497 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
498 adaptor.getOperands(),
499 [](
const APInt &a,
const APInt &b) { return a * b; })) {
501 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
502 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
503 return llvm::APIntOps::mulhs(a, b);
505 assert(highAttr &&
"Unexpected constant-folding failure");
507 results.push_back(lowAttr);
508 results.push_back(highAttr);
515 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
517 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
524 std::optional<SmallVector<int64_t, 4>>
525 arith::MulUIExtendedOp::getShapeForUnroll() {
526 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
527 return llvm::to_vector<4>(vt.getShape());
532 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
537 results.push_back(zero);
538 results.push_back(zero);
546 results.push_back(getLhs());
547 results.push_back(zero);
552 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
553 adaptor.getOperands(),
554 [](
const APInt &a,
const APInt &b) { return a * b; })) {
556 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
557 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
558 return llvm::APIntOps::mulhu(a, b);
560 assert(highAttr &&
"Unexpected constant-folding failure");
562 results.push_back(lowAttr);
563 results.push_back(highAttr);
570 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
572 patterns.
add<MulUIExtendedToMulI>(context);
579 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
586 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
587 [&](APInt a,
const APInt &b) {
608 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
614 bool overflowOrDiv0 =
false;
615 auto result = constFoldBinaryOp<IntegerAttr>(
616 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
617 if (overflowOrDiv0 || !b) {
618 overflowOrDiv0 = true;
621 return a.sdiv_ov(b, overflowOrDiv0);
624 return overflowOrDiv0 ?
Attribute() : result;
628 bool mayHaveUB =
true;
634 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
646 APInt one(a.getBitWidth(), 1,
true);
647 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
648 return val.sadd_ov(one, overflow);
655 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
660 bool overflowOrDiv0 =
false;
661 auto result = constFoldBinaryOp<IntegerAttr>(
662 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
663 if (overflowOrDiv0 || !b) {
664 overflowOrDiv0 = true;
667 APInt quotient = a.udiv(b);
670 APInt one(a.getBitWidth(), 1,
true);
671 return quotient.uadd_ov(one, overflowOrDiv0);
674 return overflowOrDiv0 ?
Attribute() : result;
687 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
695 bool overflowOrDiv0 =
false;
696 auto result = constFoldBinaryOp<IntegerAttr>(
697 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
698 if (overflowOrDiv0 || !b) {
699 overflowOrDiv0 = true;
705 unsigned bits = a.getBitWidth();
707 bool aGtZero = a.sgt(zero);
708 bool bGtZero = b.sgt(zero);
709 if (aGtZero && bGtZero) {
716 bool overflowNegA =
false;
717 bool overflowNegB =
false;
718 bool overflowDiv =
false;
719 bool overflowNegRes =
false;
720 if (!aGtZero && !bGtZero) {
722 APInt posA = zero.ssub_ov(a, overflowNegA);
723 APInt posB = zero.ssub_ov(b, overflowNegB);
725 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
728 if (!aGtZero && bGtZero) {
730 APInt posA = zero.ssub_ov(a, overflowNegA);
731 APInt div = posA.sdiv_ov(b, overflowDiv);
732 APInt res = zero.ssub_ov(div, overflowNegRes);
733 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
737 APInt posB = zero.ssub_ov(b, overflowNegB);
738 APInt div = a.sdiv_ov(posB, overflowDiv);
739 APInt res = zero.ssub_ov(div, overflowNegRes);
741 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
745 return overflowOrDiv0 ?
Attribute() : result;
749 bool mayHaveUB =
true;
755 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
764 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
770 bool overflowOrDiv =
false;
771 auto result = constFoldBinaryOp<IntegerAttr>(
772 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
774 overflowOrDiv = true;
777 return a.sfloordiv_ov(b, overflowOrDiv);
780 return overflowOrDiv ?
Attribute() : result;
787 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
794 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
795 [&](APInt a,
const APInt &b) {
796 if (div0 || b.isZero()) {
810 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
817 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
818 [&](APInt a,
const APInt &b) {
819 if (div0 || b.isZero()) {
835 for (
bool reversePrev : {
false,
true}) {
836 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
837 .getDefiningOp<arith::AndIOp>();
841 Value other = (reversePrev ? op.getLhs() : op.getRhs());
842 if (other != prev.getLhs() && other != prev.getRhs())
845 return prev.getResult();
857 intValue.isAllOnes())
862 intValue.isAllOnes())
867 intValue.isAllOnes())
874 return constFoldBinaryOp<IntegerAttr>(
875 adaptor.getOperands(),
876 [](APInt a,
const APInt &b) { return std::move(a) & b; });
889 if (rhsVal.isAllOnes())
890 return adaptor.getRhs();
897 intValue.isAllOnes())
898 return getRhs().getDefiningOp<XOrIOp>().getRhs();
902 intValue.isAllOnes())
903 return getLhs().getDefiningOp<XOrIOp>().getRhs();
905 return constFoldBinaryOp<IntegerAttr>(
906 adaptor.getOperands(),
907 [](APInt a,
const APInt &b) { return std::move(a) | b; });
919 if (getLhs() == getRhs())
923 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
924 if (prev.getRhs() == getRhs())
925 return prev.getLhs();
926 if (prev.getLhs() == getRhs())
927 return prev.getRhs();
931 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
932 if (prev.getRhs() == getLhs())
933 return prev.getLhs();
934 if (prev.getLhs() == getLhs())
935 return prev.getRhs();
938 return constFoldBinaryOp<IntegerAttr>(
939 adaptor.getOperands(),
940 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
945 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
954 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
956 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
957 [](
const APFloat &a) { return -a; });
969 return constFoldBinaryOp<FloatAttr>(
970 adaptor.getOperands(),
971 [](
const APFloat &a,
const APFloat &b) { return a + b; });
983 return constFoldBinaryOp<FloatAttr>(
984 adaptor.getOperands(),
985 [](
const APFloat &a,
const APFloat &b) { return a - b; });
992 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
994 if (getLhs() == getRhs())
1001 return constFoldBinaryOp<FloatAttr>(
1002 adaptor.getOperands(),
1003 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
// Folds arith.maxnumf. Constant operands are combined with IEEE-754 maxNum
// semantics via llvm::maxnum (a single NaN operand yields the other operand).
// llvm::maximum is NOT correct here: it propagates NaN, which is the
// semantics of arith.maximumf. This mirrors MinNumFOp's use of llvm::minnum.
1010 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1012 if (getLhs() == getRhs())
1019 return constFoldBinaryOp<FloatAttr>(
1020 adaptor.getOperands(),
1021 [](
const APFloat &a,
const APFloat &b) { return llvm::maxnum(a, b); });
1030 if (getLhs() == getRhs())
1036 if (intValue.isMaxSignedValue())
1039 if (intValue.isMinSignedValue())
1043 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1044 [](
const APInt &a,
const APInt &b) {
1045 return llvm::APIntOps::smax(a, b);
1055 if (getLhs() == getRhs())
1061 if (intValue.isMaxValue())
1064 if (intValue.isMinValue())
1068 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1069 [](
const APInt &a,
const APInt &b) {
1070 return llvm::APIntOps::umax(a, b);
1078 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1080 if (getLhs() == getRhs())
1087 return constFoldBinaryOp<FloatAttr>(
1088 adaptor.getOperands(),
1089 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1096 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1098 if (getLhs() == getRhs())
1105 return constFoldBinaryOp<FloatAttr>(
1106 adaptor.getOperands(),
1107 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1116 if (getLhs() == getRhs())
1122 if (intValue.isMinSignedValue())
1125 if (intValue.isMaxSignedValue())
1129 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1130 [](
const APInt &a,
const APInt &b) {
1131 return llvm::APIntOps::smin(a, b);
1141 if (getLhs() == getRhs())
1147 if (intValue.isMinValue())
1150 if (intValue.isMaxValue())
1154 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1155 [](
const APInt &a,
const APInt &b) {
1156 return llvm::APIntOps::umin(a, b);
1164 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1169 return constFoldBinaryOp<FloatAttr>(
1170 adaptor.getOperands(),
1171 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1176 patterns.
add<MulFOfNegF>(context);
1183 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1188 return constFoldBinaryOp<FloatAttr>(
1189 adaptor.getOperands(),
1190 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1195 patterns.
add<DivFOfNegF>(context);
1202 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1203 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1204 [](
const APFloat &a,
const APFloat &b) {
1209 (void)result.mod(b);
1218 template <
typename... Types>
1224 template <
typename... ShapedTypes,
typename... ElementTypes>
1227 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1231 if (!llvm::isa<ElementTypes...>(underlyingType))
1234 return underlyingType;
1238 template <
typename... ElementTypes>
1245 template <
typename... ElementTypes>
1254 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1255 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1256 if (!rankedTensorA || !rankedTensorB)
1258 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1262 if (inputs.size() != 1 || outputs.size() != 1)
1274 template <
typename ValType,
typename Op>
1279 if (llvm::cast<ValType>(srcType).getWidth() >=
1280 llvm::cast<ValType>(dstType).getWidth())
1282 << dstType <<
" must be wider than operand type " << srcType;
1288 template <
typename ValType,
typename Op>
1293 if (llvm::cast<ValType>(srcType).getWidth() <=
1294 llvm::cast<ValType>(dstType).getWidth())
1296 << dstType <<
" must be shorter than operand type " << srcType;
1302 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1307 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1308 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1309 if (!srcType || !dstType)
1312 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1313 srcType.getIntOrFloatBitWidth());
1319 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1320 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1321 bool losesInfo =
false;
1322 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1323 if (losesInfo || status != APFloat::opOK)
1333 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1334 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1335 getInMutable().assign(lhs.getIn());
1340 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1341 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1342 adaptor.getOperands(),
getType(),
1343 [bitWidth](
const APInt &a,
bool &castStatus) {
1344 return a.zext(bitWidth);
1349 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1353 return verifyExtOp<IntegerType>(*
this);
1360 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1361 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1362 getInMutable().assign(lhs.getIn());
1367 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1368 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1369 adaptor.getOperands(),
getType(),
1370 [bitWidth](
const APInt &a,
bool &castStatus) {
1371 return a.sext(bitWidth);
1376 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1381 patterns.
add<ExtSIOfExtUI>(context);
1385 return verifyExtOp<IntegerType>(*
this);
1394 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1395 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1396 if (truncFOp.getOperand().getType() ==
getType()) {
1397 arith::FastMathFlags truncFMF =
1398 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1399 bool isTruncContract =
1401 arith::FastMathFlags extFMF =
1402 getFastmath().value_or(arith::FastMathFlags::none);
1403 bool isExtContract =
1405 if (isTruncContract && isExtContract) {
1406 return truncFOp.getOperand();
1412 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1413 return constFoldCastOp<FloatAttr, FloatAttr>(
1414 adaptor.getOperands(),
getType(),
1415 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1417 if (failed(result)) {
1426 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1435 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1436 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1443 if (llvm::cast<IntegerType>(srcType).getWidth() >
1444 llvm::cast<IntegerType>(dstType).getWidth()) {
1451 if (srcType == dstType)
1456 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1457 setOperand(getOperand().getDefiningOp()->getOperand(0));
1462 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1463 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1464 adaptor.getOperands(),
getType(),
1465 [bitWidth](
const APInt &a,
bool &castStatus) {
1466 return a.trunc(bitWidth);
1471 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1476 patterns.
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1477 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1482 return verifyTruncateOp<IntegerType>(*
this);
1491 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1493 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1494 return constFoldCastOp<FloatAttr, FloatAttr>(
1495 adaptor.getOperands(),
getType(),
1496 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1497 RoundingMode roundingMode =
1498 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1499 llvm::RoundingMode llvmRoundingMode =
1501 FailureOr<APFloat> result =
1503 if (failed(result)) {
1512 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1516 return verifyTruncateOp<FloatType>(*
this);
1525 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1534 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1541 template <
typename From,
typename To>
1546 auto srcType = getTypeIfLike<From>(inputs.front());
1547 auto dstType = getTypeIfLike<To>(outputs.back());
1549 return srcType && dstType;
1557 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1560 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1562 return constFoldCastOp<IntegerAttr, FloatAttr>(
1563 adaptor.getOperands(),
getType(),
1564 [&resEleType](
const APInt &a,
bool &castStatus) {
1565 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1568 apf.convertFromAPInt(a,
false,
1569 APFloat::rmNearestTiesToEven);
1579 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1582 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1584 return constFoldCastOp<IntegerAttr, FloatAttr>(
1585 adaptor.getOperands(),
getType(),
1586 [&resEleType](
const APInt &a,
bool &castStatus) {
1587 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1590 apf.convertFromAPInt(a,
true,
1591 APFloat::rmNearestTiesToEven);
1601 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1604 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1606 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1607 return constFoldCastOp<FloatAttr, IntegerAttr>(
1608 adaptor.getOperands(),
getType(),
1609 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1611 APSInt api(bitWidth,
true);
1612 castStatus = APFloat::opInvalidOp !=
1613 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1623 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1626 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1628 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1629 return constFoldCastOp<FloatAttr, IntegerAttr>(
1630 adaptor.getOperands(),
getType(),
1631 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1633 APSInt api(bitWidth,
false);
1634 castStatus = APFloat::opInvalidOp !=
1635 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1648 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1649 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1650 if (!srcType || !dstType)
1657 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1662 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1664 unsigned resultBitwidth = 64;
1666 resultBitwidth = intTy.getWidth();
1668 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1669 adaptor.getOperands(),
getType(),
1670 [resultBitwidth](
const APInt &a,
bool & ) {
1671 return a.sextOrTrunc(resultBitwidth);
1675 void arith::IndexCastOp::getCanonicalizationPatterns(
1677 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1684 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1689 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1691 unsigned resultBitwidth = 64;
1693 resultBitwidth = intTy.getWidth();
1695 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1696 adaptor.getOperands(),
getType(),
1697 [resultBitwidth](
const APInt &a,
bool & ) {
1698 return a.zextOrTrunc(resultBitwidth);
1702 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1704 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1716 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1718 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1719 if (!srcType || !dstType)
1725 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1727 auto operand = adaptor.getIn();
1732 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1733 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1735 if (llvm::isa<ShapedType>(resType))
1739 APInt bits = llvm::isa<FloatAttr>(operand)
1740 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1741 : llvm::cast<IntegerAttr>(operand).getValue();
1743 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1745 APFloat(resFloatType.getFloatSemantics(), bits));
1751 patterns.
add<BitcastOfBitcast>(context);
1761 const APInt &lhs,
const APInt &rhs) {
1762 switch (predicate) {
1763 case arith::CmpIPredicate::eq:
1765 case arith::CmpIPredicate::ne:
1767 case arith::CmpIPredicate::slt:
1768 return lhs.slt(rhs);
1769 case arith::CmpIPredicate::sle:
1770 return lhs.sle(rhs);
1771 case arith::CmpIPredicate::sgt:
1772 return lhs.sgt(rhs);
1773 case arith::CmpIPredicate::sge:
1774 return lhs.sge(rhs);
1775 case arith::CmpIPredicate::ult:
1776 return lhs.ult(rhs);
1777 case arith::CmpIPredicate::ule:
1778 return lhs.ule(rhs);
1779 case arith::CmpIPredicate::ugt:
1780 return lhs.ugt(rhs);
1781 case arith::CmpIPredicate::uge:
1782 return lhs.uge(rhs);
1784 llvm_unreachable(
"unknown cmpi predicate kind");
// Splits the cmpi predicates into two groups by how they behave when both
// operands are the same value. The return statements for each group are not
// visible in this excerpt.
1789 switch (predicate) {
// Reflexive predicates: these hold whenever lhs == rhs.
1790 case arith::CmpIPredicate::eq:
1791 case arith::CmpIPredicate::sle:
1792 case arith::CmpIPredicate::sge:
1793 case arith::CmpIPredicate::ule:
1794 case arith::CmpIPredicate::uge:
// Irreflexive predicates: these never hold when lhs == rhs.
1796 case arith::CmpIPredicate::ne:
1797 case arith::CmpIPredicate::slt:
1798 case arith::CmpIPredicate::sgt:
1799 case arith::CmpIPredicate::ult:
1800 case arith::CmpIPredicate::ugt:
1803 llvm_unreachable(
"unknown cmpi predicate kind");
// Returns the bit width of `t` if it is an IntegerType, or the element bit
// width if `t` is a VectorType. Any other type yields std::nullopt.
// NOTE: the vector branch uses llvm::cast, which asserts if the vector's
// element type is not an IntegerType.
1807 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1808 return intType.getWidth();
1810 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1811 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1813 return std::nullopt;
1816 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1818 if (getLhs() == getRhs()) {
1824 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1826 std::optional<int64_t> integerWidth =
1828 if (integerWidth && integerWidth.value() == 1 &&
1829 getPredicate() == arith::CmpIPredicate::ne)
1830 return extOp.getOperand();
1832 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1834 std::optional<int64_t> integerWidth =
1836 if (integerWidth && integerWidth.value() == 1 &&
1837 getPredicate() == arith::CmpIPredicate::ne)
1838 return extOp.getOperand();
1843 if (adaptor.getLhs() && !adaptor.getRhs()) {
1845 using Pred = CmpIPredicate;
1846 const std::pair<Pred, Pred> invPreds[] = {
1847 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1848 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1849 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1850 {Pred::ne, Pred::ne},
1852 Pred origPred = getPredicate();
1853 for (
auto pred : invPreds) {
1854 if (origPred == pred.first) {
1855 setPredicate(pred.second);
1856 Value lhs = getLhs();
1857 Value rhs = getRhs();
1858 getLhsMutable().assign(rhs);
1859 getRhsMutable().assign(lhs);
1863 llvm_unreachable(
"unknown cmpi predicate kind");
1868 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1869 return constFoldBinaryOp<IntegerAttr>(
1871 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1882 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1892 const APFloat &lhs,
const APFloat &rhs) {
// Evaluates a cmpf predicate using a single APFloat::compare result.
// Ordered predicates (O*) are false when the comparison is cmpUnordered
// (either operand is NaN); unordered predicates (U*) are true in that case.
1893 auto cmpResult = lhs.compare(rhs);
1894 switch (predicate) {
1895 case arith::CmpFPredicate::AlwaysFalse:
1897 case arith::CmpFPredicate::OEQ:
1898 return cmpResult == APFloat::cmpEqual;
1899 case arith::CmpFPredicate::OGT:
1900 return cmpResult == APFloat::cmpGreaterThan;
1901 case arith::CmpFPredicate::OGE:
1902 return cmpResult == APFloat::cmpGreaterThan ||
1903 cmpResult == APFloat::cmpEqual;
1904 case arith::CmpFPredicate::OLT:
1905 return cmpResult == APFloat::cmpLessThan;
1906 case arith::CmpFPredicate::OLE:
1907 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1908 case arith::CmpFPredicate::ONE:
// Ordered not-equal: excludes both equality and the unordered (NaN) case.
1909 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1910 case arith::CmpFPredicate::ORD:
1911 return cmpResult != APFloat::cmpUnordered;
1912 case arith::CmpFPredicate::UEQ:
1913 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1914 case arith::CmpFPredicate::UGT:
1915 return cmpResult == APFloat::cmpUnordered ||
1916 cmpResult == APFloat::cmpGreaterThan;
1917 case arith::CmpFPredicate::UGE:
1918 return cmpResult == APFloat::cmpUnordered ||
1919 cmpResult == APFloat::cmpGreaterThan ||
1920 cmpResult == APFloat::cmpEqual;
1921 case arith::CmpFPredicate::ULT:
1922 return cmpResult == APFloat::cmpUnordered ||
1923 cmpResult == APFloat::cmpLessThan;
1924 case arith::CmpFPredicate::ULE:
1925 return cmpResult == APFloat::cmpUnordered ||
1926 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1927 case arith::CmpFPredicate::UNE:
// Unordered not-equal: anything other than cmpEqual, including NaN operands.
1928 return cmpResult != APFloat::cmpEqual;
1929 case arith::CmpFPredicate::UNO:
1930 return cmpResult == APFloat::cmpUnordered;
1931 case arith::CmpFPredicate::AlwaysTrue:
1934 llvm_unreachable(
"unknown cmpf predicate kind");
1937 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1938 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1939 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1942 if (lhs && lhs.getValue().isNaN())
1944 if (rhs && rhs.getValue().isNaN())
// Maps a cmpf predicate to the cmpi predicate with the same ordering
// relation. Ordered (O*) and unordered (U*) forms map to the same integer
// predicate. `isUnsigned` selects the u* form over the s* form for the
// relational predicates.
// NOTE(review): collapsing O*/U* presumably relies on the compared value
// coming from an int-to-float conversion (never NaN) — confirm at call site.
// Predicates not listed (e.g. ORD, UNO, AlwaysTrue/AlwaysFalse) presumably
// reach the llvm_unreachable below.
1960 using namespace arith;
1962 case CmpFPredicate::UEQ:
1963 case CmpFPredicate::OEQ:
1964 return CmpIPredicate::eq;
1965 case CmpFPredicate::UGT:
1966 case CmpFPredicate::OGT:
1967 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1968 case CmpFPredicate::UGE:
1969 case CmpFPredicate::OGE:
1970 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1971 case CmpFPredicate::ULT:
1972 case CmpFPredicate::OLT:
1973 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1974 case CmpFPredicate::ULE:
1975 case CmpFPredicate::OLE:
1976 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1977 case CmpFPredicate::UNE:
1978 case CmpFPredicate::ONE:
1979 return CmpIPredicate::ne;
1981 llvm_unreachable(
"Unexpected predicate!");
1991 const APFloat &rhs = flt.getValue();
1999 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2001 if (mantissaWidth <= 0)
2007 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2009 intVal = si.getIn();
2010 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2012 intVal = ui.getIn();
2019 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2020 auto intWidth = intTy.getWidth();
2023 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2028 if ((
int)intWidth > mantissaWidth) {
2030 int exponent = ilogb(rhs);
2031 if (exponent == APFloat::IEK_Inf) {
2032 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2033 if (maxExponent < (
int)valueBits) {
2040 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2049 switch (op.getPredicate()) {
2050 case CmpFPredicate::ORD:
2055 case CmpFPredicate::UNO:
2068 APFloat signedMax(rhs.getSemantics());
2069 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2070 APFloat::rmNearestTiesToEven);
2071 if (signedMax < rhs) {
2072 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2073 pred == CmpIPredicate::sle)
2084 APFloat unsignedMax(rhs.getSemantics());
2085 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2086 APFloat::rmNearestTiesToEven);
2087 if (unsignedMax < rhs) {
2088 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2089 pred == CmpIPredicate::ule)
2101 APFloat signedMin(rhs.getSemantics());
2102 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2103 APFloat::rmNearestTiesToEven);
2104 if (signedMin > rhs) {
2105 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2106 pred == CmpIPredicate::sge)
2116 APFloat unsignedMin(rhs.getSemantics());
2117 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2118 APFloat::rmNearestTiesToEven);
2119 if (unsignedMin > rhs) {
2120 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2121 pred == CmpIPredicate::uge)
2136 APSInt rhsInt(intWidth, isUnsigned);
2137 if (APFloat::opInvalidOp ==
2138 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2144 if (!rhs.isZero()) {
2147 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2149 bool equal = apf == rhs;
2155 case CmpIPredicate::ne:
2159 case CmpIPredicate::eq:
2163 case CmpIPredicate::ule:
2166 if (rhs.isNegative()) {
2172 case CmpIPredicate::sle:
2175 if (rhs.isNegative())
2176 pred = CmpIPredicate::slt;
2178 case CmpIPredicate::ult:
2181 if (rhs.isNegative()) {
2186 pred = CmpIPredicate::ule;
2188 case CmpIPredicate::slt:
2191 if (!rhs.isNegative())
2192 pred = CmpIPredicate::sle;
2194 case CmpIPredicate::ugt:
2197 if (rhs.isNegative()) {
2203 case CmpIPredicate::sgt:
2206 if (rhs.isNegative())
2207 pred = CmpIPredicate::sge;
2209 case CmpIPredicate::uge:
2212 if (rhs.isNegative()) {
2217 pred = CmpIPredicate::ugt;
2219 case CmpIPredicate::sge:
2222 if (!rhs.isNegative())
2223 pred = CmpIPredicate::sgt;
2233 rewriter.
create<ConstantOp>(
2256 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2272 rewriter.
create<arith::XOrIOp>(
2273 op.
getLoc(), op.getCondition(),
2274 rewriter.
create<arith::ConstantIntOp>(
2275 op.
getLoc(), 1, op.getCondition().getType())));
2285 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2289 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2290 Value trueVal = getTrueValue();
2291 Value falseVal = getFalseValue();
2292 if (trueVal == falseVal)
2295 Value condition = getCondition();
2306 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2309 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2317 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2318 auto pred = cmp.getPredicate();
2319 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2320 auto cmpLhs = cmp.getLhs();
2321 auto cmpRhs = cmp.getRhs();
2329 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2330 (cmpRhs == trueVal && cmpLhs == falseVal))
2331 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2338 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2340 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2342 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2344 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2345 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2347 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2349 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2352 for (
auto [condVal, lhsVal, rhsVal] :
2353 llvm::zip_equal(condVals, lhsVals, rhsVals))
2354 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2365 Type conditionType, resultType;
2374 conditionType = resultType;
2383 {conditionType, resultType, resultType},
2388 p <<
" " << getOperands();
2391 if (ShapedType condType =
2392 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2393 p << condType <<
", ";
2398 Type conditionType = getCondition().getType();
2405 if (!llvm::isa<TensorType, VectorType>(resultType))
2406 return emitOpError() <<
"expected condition to be a signless i1, but got "
2409 if (conditionType != shapedConditionType) {
2410 return emitOpError() <<
"expected condition type to have the same shape "
2411 "as the result type, expected "
2412 << shapedConditionType <<
", but got "
2421 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2426 bool bounded =
false;
2427 auto result = constFoldBinaryOp<IntegerAttr>(
2428 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2429 bounded = b.ult(b.getBitWidth());
2439 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2444 bool bounded =
false;
2445 auto result = constFoldBinaryOp<IntegerAttr>(
2446 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2447 bounded = b.ult(b.getBitWidth());
2457 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2462 bool bounded =
false;
2463 auto result = constFoldBinaryOp<IntegerAttr>(
2464 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2465 bounded = b.ult(b.getBitWidth());
2478 bool useOnlyFiniteValue) {
2480 case AtomicRMWKind::maximumf: {
2481 const llvm::fltSemantics &semantic =
2482 llvm::cast<FloatType>(resultType).getFloatSemantics();
2483 APFloat identity = useOnlyFiniteValue
2484 ? APFloat::getLargest(semantic,
true)
2485 : APFloat::getInf(semantic,
true);
2488 case AtomicRMWKind::maxnumf: {
2489 const llvm::fltSemantics &semantic =
2490 llvm::cast<FloatType>(resultType).getFloatSemantics();
2491 APFloat identity = APFloat::getNaN(semantic,
true);
2494 case AtomicRMWKind::addf:
2495 case AtomicRMWKind::addi:
2496 case AtomicRMWKind::maxu:
2497 case AtomicRMWKind::ori:
2499 case AtomicRMWKind::andi:
2502 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2503 case AtomicRMWKind::maxs:
2505 resultType, APInt::getSignedMinValue(
2506 llvm::cast<IntegerType>(resultType).getWidth()));
2507 case AtomicRMWKind::minimumf: {
2508 const llvm::fltSemantics &semantic =
2509 llvm::cast<FloatType>(resultType).getFloatSemantics();
2510 APFloat identity = useOnlyFiniteValue
2511 ? APFloat::getLargest(semantic,
false)
2512 : APFloat::getInf(semantic,
false);
2516 case AtomicRMWKind::minnumf: {
2517 const llvm::fltSemantics &semantic =
2518 llvm::cast<FloatType>(resultType).getFloatSemantics();
2519 APFloat identity = APFloat::getNaN(semantic,
false);
2522 case AtomicRMWKind::mins:
2524 resultType, APInt::getSignedMaxValue(
2525 llvm::cast<IntegerType>(resultType).getWidth()));
2526 case AtomicRMWKind::minu:
2529 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2530 case AtomicRMWKind::muli:
2532 case AtomicRMWKind::mulf:
2544 std::optional<AtomicRMWKind> maybeKind =
2547 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2548 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2549 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2550 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2551 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2552 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2554 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2555 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2556 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2557 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2558 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2559 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2560 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2561 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2562 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2563 .Default([](
Operation *op) {
return std::nullopt; });
2565 return std::nullopt;
2568 bool useOnlyFiniteValue =
false;
2569 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2570 if (fmfOpInterface) {
2571 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2572 useOnlyFiniteValue =
2573 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2581 useOnlyFiniteValue);
2587 bool useOnlyFiniteValue) {
2590 return builder.
create<arith::ConstantOp>(loc, attr);
2598 case AtomicRMWKind::addf:
2599 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2600 case AtomicRMWKind::addi:
2601 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2602 case AtomicRMWKind::mulf:
2603 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2604 case AtomicRMWKind::muli:
2605 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2606 case AtomicRMWKind::maximumf:
2607 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2608 case AtomicRMWKind::minimumf:
2609 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2610 case AtomicRMWKind::maxnumf:
2611 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2612 case AtomicRMWKind::minnumf:
2613 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2614 case AtomicRMWKind::maxs:
2615 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2616 case AtomicRMWKind::mins:
2617 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2618 case AtomicRMWKind::maxu:
2619 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2620 case AtomicRMWKind::minu:
2621 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2622 case AtomicRMWKind::ori:
2623 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2624 case AtomicRMWKind::andi:
2625 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2638 #define GET_OP_CLASSES
2639 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2645 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)