#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res,
                                       Attribute lhs, Attribute rhs,
                                       function_ref<APInt(const APInt &,
                                                          const APInt &)>
                                           binFn) {
  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
  APInt value = binFn(lhsVal, rhsVal);
  return IntegerAttr::get(res.getType(), value);
}
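// For illustration only (not part of this excerpt): the attribute helpers
// referenced by the declarative canonicalization patterns, such as
// addIntegerAttrs/subIntegerAttrs/mulIntegerAttrs, are presumably thin
// wrappers over applyToIntegerAttrs along these lines:
//
//   static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
//                                      Attribute lhs, Attribute rhs) {
//     return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
//   }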
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {
  return IntegerOverflowFlagsAttr::get(val1.getContext(),
                                       val1.getValue() & val2.getValue());
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
/// Equivalent to
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static llvm::RoundingMode
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
  switch (roundingMode) {
  case RoundingMode::downward:
    return llvm::RoundingMode::TowardNegative;
  case RoundingMode::to_nearest_away:
    return llvm::RoundingMode::NearestTiesToAway;
  case RoundingMode::to_nearest_even:
    return llvm::RoundingMode::NearestTiesToEven;
  case RoundingMode::toward_zero:
    return llvm::RoundingMode::TowardZero;
  case RoundingMode::upward:
    return llvm::RoundingMode::TowardPositive;
  }
  llvm_unreachable("Unhandled rounding mode");
}

// (Fragment of getBoolAttribute: handle shaped condition/result types.)
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);

namespace {
#include "ArithCanonicalization.inc"
} // namespace

/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
    return shapedType.cloneWith(std::nullopt, i1Type);
  if (llvm::isa<UnrankedTensorType>(type))
    return UnrankedTensorType::get(i1Type);
  return i1Type;
}
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
    auto intType = llvm::dyn_cast<IntegerType>(type);

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a complex name with the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getValue();
    if (intType)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    setNameFn(getResult(), "cst");
  }
}
LogicalResult arith::ConstantOp::verify() {
  auto type = getType();
  // The value's type must match the return type.
  if (getValue().getType() != type) {
    return emitOpError() << "value type " << getValue().getType()
                         << " must match return type: " << type;
  }
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return emitOpError("integer return type must be signless");
  // The value must be an integer, float, or elements attribute.
  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
    return emitOpError(
        "value must be an integer, float, or elements attribute");
  }
  // Scalable vectors can only be initialized through splats.
  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not "
        "supported unless it's a vector splat");
  return success();
}

bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must be the same as the provided type.
  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
  if (!typedAttr || typedAttr.getType() != type)
    return false;
  // Integer values must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return false;
  // The value must be an integer, float, or elements attribute.
  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}

ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  if (isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
  return {};
}

OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto type = builder.getIntegerType(width);
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

bool arith::ConstantIntOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isSignlessInteger();
  return false;
}

void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getFloatAttr(type, value));
}

bool arith::ConstantFloatOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return llvm::isa<FloatType>(constOp.getType());
  return false;
}

void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  arith::ConstantOp::build(builder, result, builder.getIndexType(),
                           builder.getIndexAttr(value));
}

bool arith::ConstantIndexOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isIndex();
  return false;
}
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
  // addi(subi(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubIOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // addi(b, subi(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubIOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
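// Illustrative IR for the first fold above:
//
//   %d = arith.subi %a, %b : i32
//   %s = arith.addi %d, %b : i32   // folds to %a
//
// The constant case simply folds addi of two constants to their APInt sum.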
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

/// Returns the overflow bit, assuming `sum` is the result of an unsigned
/// addition of `operand` and another value.
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}

LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    Builder builder(getContext());
    Attribute falseValue = builder.getZeroAttr(overflowTy);

    results.push_back(getLhs());
    results.push_back(falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        ArrayRef({sumAttr, adaptor.getLhs()}),
        getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
        calculateUnsignedOverflow);
    if (!overflowAttr)
      return failure();

    results.push_back(sumAttr);
    results.push_back(overflowAttr);
    return success();
  }

  return failure();
}

void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}

OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x, x) -> 0
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    // We can't generate a constant with a dynamic shaped tensor.
    if (!shapedType || shapedType.hasStaticShape())
      return Builder(getContext()).getZeroAttr(getType());
  }

  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    // subi(addi(a, b), b) -> a
    if (getRhs() == add.getRhs())
      return add.getLhs();
    // subi(addi(a, b), a) -> b
    if (getRhs() == add.getLhs())
      return add.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}

void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  // ... (muli(x, 0) -> 0 and muli(x, 1) -> x folds elided in this excerpt)
  // default folder
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!isa<IndexType>(getType()))
    return;

  // Match the vscale operation by name to avoid a dependency on the vector
  // dialect.
  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";
  };

  IntegerAttr baseValue;
  auto isVscaleExpr = [&](Value a, Value b) {
    return matchPattern(a, m_Constant(&baseValue)) &&
           isVscale(b.getDefiningOp());
  };

  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
    return;

  // Name `base * vscale` (or `vscale * base`) results `c<base>_vscale`.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());
}

void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
}

std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(adaptor.getRhs(), m_One())) {
    Builder builder(getContext());
    Attribute zero = builder.getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
        });
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}

void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || !b) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.udiv(b);
                                               });
  return div0 ? Attribute() : result;
}

OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        return a.sdiv_ov(b, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}

/// Returns ceil(a / b) for two nonnegative signed inputs, using the identity
/// ceil(a / b) == (a - 1) / b + 1 and tracking intermediate overflow.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  // Returns (a - 1) / b + 1.
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
}
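// Worked example of the identity used here (for positive a and b):
//   ceil(a / b) == (a - 1) / b + 1
// e.g. a = 7, b = 3: (7 - 1) / 3 + 1 = 2 + 1 = 3 == ceil(7 / 3).
// The a == 0 case is guarded separately by the caller before reaching this
// helper, since (0 - 1) / b + 1 would give 1 instead of 0.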
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        APInt quotient = a.udiv(b);
        if (!a.urem(b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}

OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        if (!a)
          return a;
        // After this point we know that neither a nor b is zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // No folding happens if any of the intermediate arithmetic generates
        // an overflow.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return -(-a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative, return -(a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}

OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would overflow or require a division by zero.
  bool overflowOrDiv = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (b.isZero()) {
          overflowOrDiv = true;
          return a;
        }
        return a.sfloordiv_ov(b, overflowOrDiv);
      });

  return overflowOrDiv ? Attribute() : result;
}

OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.urem(b);
                                               });
  return div0 ? Attribute() : result;
}

OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.srem(b);
                                               });
  return div0 ? Attribute() : result;
}
/// Fold `and(a, and(a, b))` to `and(a, b)`.
static Value foldAndIofAndI(arith::AndIOp op) {
  for (bool reversePrev : {false, true}) {
    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
                    .getDefiningOp<arith::AndIOp>();
    if (!prev)
      continue;

    Value other = (reversePrev ? op.getLhs() : op.getRhs());
    if (other != prev.getLhs() && other != prev.getRhs())
      continue;

    return prev.getResult();
  }
  return {};
}
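// Illustrative IR for this fold:
//
//   %ab = arith.andi %a, %b  : i32
//   %r  = arith.andi %a, %ab : i32   // folds to %ab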
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(*this))
    return result;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}

OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, 1)) -> 1
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, 1), x) -> 1
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}

OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());
  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}

void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  /// negf(negf(x)) -> x
  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
    return op.getOperand();
  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                     [](const APFloat &a) { return -a; });
}

OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  bool isPositiveZeroMode =
      getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
  // ... (folds that depend on the denormal mode are elided in this excerpt)

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}

OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maxnumf(x, NaN) -> x
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
}

OpFoldResult arith::MaxSIOp::fold(FoldAdaptor adaptor) {
  // maxsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxsi(x, MAX_INT) -> MAX_INT
    if (intValue.isMaxSignedValue())
      return getRhs();
    // maxsi(x, MIN_INT) -> x
    if (intValue.isMinSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smax(a, b);
                                        });
}

OpFoldResult arith::MaxUIOp::fold(FoldAdaptor adaptor) {
  // maxui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxui(x, MAX_INT) -> MAX_INT
    if (intValue.isMaxValue())
      return getRhs();
    // maxui(x, MIN_INT) -> x
    if (intValue.isMinValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umax(a, b);
                                        });
}

OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  // minimumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minimumf(x, +inf) -> x
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}

OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  // minnumf(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minnumf(x, NaN) -> x
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
}

OpFoldResult arith::MinSIOp::fold(FoldAdaptor adaptor) {
  // minsi(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minsi(x, MIN_INT) -> MIN_INT
    if (intValue.isMinSignedValue())
      return getRhs();
    // minsi(x, MAX_INT) -> x
    if (intValue.isMaxSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smin(a, b);
                                        });
}

OpFoldResult arith::MinUIOp::fold(FoldAdaptor adaptor) {
  // minui(x, x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // minui(x, MIN_INT) -> MIN_INT
    if (intValue.isMinValue())
      return getRhs();
    // minui(x, MAX_INT) -> x
    if (intValue.isMaxValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::umin(a, b);
                                        });
}
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}

void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}

OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}

void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}

OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // APFloat::mod() offers the remainder
                                        // behavior we want, i.e. the result
                                        // has the sign of the LHS operand.
                                        (void)result.mod(b);
                                        return result;
                                      });
}
/// Utility alias for verifying cast ops.
template <typename... Types>
using type_list = std::tuple<Types...> *;

/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}

/// Get allowed underlying types for vectors and tensors.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}

/// Get allowed underlying types for vectors, tensors, and memrefs.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}

/// Return false if both types are ranked tensors with mismatching encoding.
static bool hasSameEncoding(Type typeA, Type typeB) {
  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
  if (!rankedTensorA || !rankedTensorB)
    return true;
  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
}

static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // ... (shape/encoding compatibility checks elided in this excerpt)
  return hasSameEncoding(inputs.front(), outputs.front());
}

template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}

/// Validate a cast that changes the width of a type.
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());
}

/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
  if (losesInfo || status != APFloat::opOK)
    return failure();

  return sourceValue;
}
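// Illustrative use (hypothetical values, not part of this file): narrowing a
// value to f32 semantics and rejecting lossy conversions, which is how the
// extf/truncf constant folders below use this helper:
//
//   APFloat v(3.14159265358979);                      // double semantics
//   FailureOr<APFloat> narrowed =
//       convertFloatValue(v, APFloat::IEEEsingle());  // fails: precision loss
//   FailureOr<APFloat> same =
//       convertFloatValue(v, APFloat::IEEEdouble());  // succeeds, unchanged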
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
      });
}

bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}

OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);
      });
}

bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}

void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}

LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}

bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
      context);
}

LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}

OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}

bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}

LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}

void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<From>(inputs.front());
  auto dstType = getTypeIfLike<To>(outputs.back());

  return srcType && dstType;
}

bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}

bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}

OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}

bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}

OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
         (srcType.isSignlessInteger() && dstType.isIndex());
}

bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.sextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}

bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.zextOrTrunc(resultBitwidth);
      });
}

void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
  auto dstType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}

OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast integer or float to integer or float.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}

void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
    return intType.getWidth();
  }
  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
  }
  return std::nullopt;
}
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0  ->  %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
  }

  // Move constants to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // We are moving constants to the right side; so if lhs is constant, rhs is
  // guaranteed to be a constant as well.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1, applyCmpPredicate(pred, lhs, rhs));
        });
  }

  return {};
}

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}

OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  // If one operand is NaN, making them both NaN does not change the result.
  if (lhs && lhs.getValue().isNaN())
    rhs = lhs;
  if (rhs && rhs.getValue().isNaN())
    lhs = rhs;

  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}
// (Helper inside the CmpFIntToFPConst rewrite pattern.)
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                               bool isUnsigned) {
  using namespace arith;
  switch (pred) {
  case CmpFPredicate::UEQ:
  case CmpFPredicate::OEQ:
    return CmpIPredicate::eq;
  case CmpFPredicate::UGT:
  case CmpFPredicate::OGT:
    return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
  case CmpFPredicate::UGE:
  case CmpFPredicate::OGE:
    return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
  case CmpFPredicate::ULT:
  case CmpFPredicate::OLT:
    return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
  case CmpFPredicate::ULE:
  case CmpFPredicate::OLE:
    return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
  case CmpFPredicate::UNE:
  case CmpFPredicate::ONE:
    return CmpIPredicate::ne;
  default:
    break;
  }
  llvm_unreachable("Unexpected predicate!");
}
  // (Body of CmpFIntToFPConst::matchAndRewrite: rewrite a cmpf whose lhs is an
  // integer converted with sitofp/uitofp and whose rhs is a float constant
  // into an equivalent cmpi. Some steps are elided in this excerpt.)
  const APFloat &rhs = flt.getValue();
  // ...
  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
  int mantissaWidth = floatTy.getFPMantissaWidth();
  if (mantissaWidth <= 0)
    return failure();

  bool isUnsigned;
  Value intVal;
  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
    isUnsigned = false;
    intVal = si.getIn();
  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
    isUnsigned = true;
    intVal = ui.getIn();
  } else {
    return failure();
  }

  // The converted integer's width, and its number of value bits (all bits for
  // unsigned, all but the sign bit for signed).
  auto intTy = llvm::cast<IntegerType>(intVal.getType());
  auto intWidth = intTy.getWidth();
  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

  // If the integer is wider than the float's mantissa, the int-to-fp
  // conversion may lose precision; bail out unless the constant's exponent
  // shows that the comparison is unaffected.
  if ((int)intWidth > mantissaWidth) {
    int exponent = ilogb(rhs);
    if (exponent == APFloat::IEK_Inf) {
      int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
      if (maxExponent < (int)valueBits) {
        // Conversion could create infinity.
        return failure();
      }
    } else {
      if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
        // Conversion could affect the comparison.
        return failure();
      }
    }
  }
  // Ordered/unordered comparisons fold to constants (int-to-fp values are
  // never NaN); constants outside the representable integer range also make
  // the comparison a constant.
  switch (op.getPredicate()) {
  case CmpFPredicate::ORD:
    // ... (rewrite to a constant true)
    break;
  case CmpFPredicate::UNO:
    // ... (rewrite to a constant false)
    break;
  default:
    break;
  }

  CmpIPredicate pred =
      convertToIntegerPredicate(op.getPredicate(), isUnsigned);

  if (!isUnsigned) {
    // Is the constant larger than the largest signed value?
    APFloat signedMax(rhs.getSemantics());
    signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                               APFloat::rmNearestTiesToEven);
    if (signedMax < rhs) { // smax < 13123.0
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
          pred == CmpIPredicate::sle) {
        // ... (rewrite to a constant true)
      }
      // ... (otherwise rewrite to a constant false)
    }
  } else {
    // Is the constant larger than the largest unsigned value?
    APFloat unsignedMax(rhs.getSemantics());
    unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                 APFloat::rmNearestTiesToEven);
    if (unsignedMax < rhs) { // umax < 13123.0
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
          pred == CmpIPredicate::ule) {
        // ... (rewrite to a constant true)
      }
      // ... (otherwise rewrite to a constant false)
    }
  }

  if (!isUnsigned) {
    // Is the constant smaller than the smallest signed value?
    APFloat signedMin(rhs.getSemantics());
    signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                               APFloat::rmNearestTiesToEven);
    if (signedMin > rhs) { // smin > 12312.0
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
          pred == CmpIPredicate::sge) {
        // ... (rewrite to a constant true)
      }
      // ... (otherwise rewrite to a constant false)
    }
  } else {
    // Is the constant smaller than zero, the smallest unsigned value?
    APFloat unsignedMin(rhs.getSemantics());
    unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                 APFloat::rmNearestTiesToEven);
    if (unsignedMin > rhs) { // umin > 12312.0
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
          pred == CmpIPredicate::uge) {
        // ... (rewrite to a constant true)
      }
      // ... (otherwise rewrite to a constant false)
    }
  }
  // The constant fits the integer range; check whether it is fractional.
  bool ignored;
  APSInt rhsInt(intWidth, isUnsigned);
  if (APFloat::opInvalidOp ==
      rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
    // The destination type cannot represent the input constant.
    return failure();
  }

  // If the constant was fractional, adjust the predicate (rhsInt was rounded
  // towards zero above). Skip zero, because -0.0 is not fractional.
  if (!rhs.isZero()) {
    APFloat apf(rhs.getSemantics());
    apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

    bool equal = apf == rhs;
    if (!equal) {
      switch (pred) {
      case CmpIPredicate::ne: // (float)int != 4.4 --> true
        // ... (rewrite to a constant true)
        break;
      case CmpIPredicate::eq: // (float)int == 4.4 --> false
        // ... (rewrite to a constant false)
        break;
      case CmpIPredicate::ule:
        // (float)int <= 4.4 --> int <= 4; (float)int <= -4.4 --> false
        if (rhs.isNegative()) {
          // ... (rewrite to a constant false)
        }
        break;
      case CmpIPredicate::sle:
        // (float)int <= -4.4 --> int < -4
        if (rhs.isNegative())
          pred = CmpIPredicate::slt;
        break;
      case CmpIPredicate::ult:
        // (float)int < -4.4 --> false; (float)int < 4.4 --> int <= 4
        if (rhs.isNegative()) {
          // ... (rewrite to a constant false)
        }
        pred = CmpIPredicate::ule;
        break;
      case CmpIPredicate::slt:
        // (float)int < 4.4 --> int <= 4
        if (!rhs.isNegative())
          pred = CmpIPredicate::sle;
        break;
      case CmpIPredicate::ugt:
        // (float)int > 4.4 --> int > 4; (float)int > -4.4 --> true
        if (rhs.isNegative()) {
          // ... (rewrite to a constant true)
        }
        break;
      case CmpIPredicate::sgt:
        // (float)int > -4.4 --> int >= -4
        if (rhs.isNegative())
          pred = CmpIPredicate::sge;
        break;
      case CmpIPredicate::uge:
        // (float)int >= -4.4 --> true; (float)int >= 4.4 --> int > 4
        if (rhs.isNegative()) {
          // ... (rewrite to a constant true)
        }
        pred = CmpIPredicate::ugt;
        break;
      case CmpIPredicate::sge:
        // (float)int >= 4.4 --> int > 4
        if (!rhs.isNegative())
          pred = CmpIPredicate::sgt;
        break;
      }
    }
  }

  // Lower the FP comparison to an integer comparison against the converted
  // constant.
  rewriter.replaceOpWithNewOp<CmpIOp>(
      op, pred, intVal,
      rewriter.create<ConstantOp>(
          op.getLoc(), intVal.getType(),
          rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
  return success();
    // (From the SelectToExtUI pattern.)
    // Cannot extui i1 to i1, or i1 to f32.
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();
    // ...
    // select %x, c0, %c1 => extui (xor %arg, true)
    rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
        op, op.getType(),
        rewriter.create<arith::XOrIOp>(
            op.getLoc(), op.getCondition(),
            rewriter.create<arith::ConstantIntOp>(
                op.getLoc(), 1, op.getCondition().getType())));
    // ...

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1
      //
      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over a non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    if (auto lhs =
            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
              adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or
  // vector.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType =
          llvm::dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the condition can be a mask of
  // the same shape.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  // shli(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by more than or equal to the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.shl(b);
      });
  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  // shrui(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by more than or equal to the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.lshr(b);
      });
  return bounded ? result : Attribute();
}

OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  // shrsi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  // Don't fold if shifting by more than or equal to the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());
        return a.ashr(b);
      });
  return bounded ? result : Attribute();
}
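// Illustrative constant cases for the bounded check above (i32 operands):
//
//   arith.shrui %c8, %c2   // folds to 2
//   arith.shrui %c8, %c32  // not folded: the shift amount equals the bit
//                          // width, so the result is left to the op's own
//                          // semantics rather than constant-folded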
/// Returns the identity value attribute associated with an AtomicRMWKind op.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {
  switch (kind) {
  case AtomicRMWKind::maximumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::maxnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::minnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
/// Return the identity numeric value associated to the given op.
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
  std::optional<AtomicRMWKind> maybeKind =
      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
          // Floating-point operations.
          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
          // Integer operations.
          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
          .Default([](Operation *op) { return std::nullopt; });
  if (!maybeKind)
    return std::nullopt;

  bool useOnlyFiniteValue = false;
  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
  if (fmfOpInterface) {
    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
    useOnlyFiniteValue =
        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
  }

  // The builder is only used as a helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();

  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
                              useOnlyFiniteValue);
}
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc,
                                    bool useOnlyFiniteValue) {
  auto attr =
      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
  return builder.create<arith::ConstantOp>(loc, attr);
}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {
  switch (op) {
  case AtomicRMWKind::addf:
    return builder.create<arith::AddFOp>(loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return builder.create<arith::AddIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return builder.create<arith::MulFOp>(loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maximumf:
    return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minimumf:
    return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxnumf:
    return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minnumf:
    return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return builder.create<arith::AndIOp>(loc, lhs, rhs);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
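// Illustrative use (hypothetical builder/loc/values): combining two partial
// reduction values with the arithmetic op that matches the reduction kind.
//
//   Value acc = arith::getReductionOp(AtomicRMWKind::addf, builder, loc,
//                                     partialA, partialB);
//   // equivalent to: builder.create<arith::AddFOp>(loc, partialA, partialB)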
#define GET_OP_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"