26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
45 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
70 case arith::CmpIPredicate::eq:
71 return arith::CmpIPredicate::ne;
72 case arith::CmpIPredicate::ne:
73 return arith::CmpIPredicate::eq;
74 case arith::CmpIPredicate::slt:
75 return arith::CmpIPredicate::sge;
76 case arith::CmpIPredicate::sle:
77 return arith::CmpIPredicate::sgt;
78 case arith::CmpIPredicate::sgt:
79 return arith::CmpIPredicate::sle;
80 case arith::CmpIPredicate::sge:
81 return arith::CmpIPredicate::slt;
82 case arith::CmpIPredicate::ult:
83 return arith::CmpIPredicate::uge;
84 case arith::CmpIPredicate::ule:
85 return arith::CmpIPredicate::ugt;
86 case arith::CmpIPredicate::ugt:
87 return arith::CmpIPredicate::ule;
88 case arith::CmpIPredicate::uge:
89 return arith::CmpIPredicate::ult;
91 llvm_unreachable(
"unknown cmpi predicate kind");
100 static llvm::RoundingMode
102 switch (roundingMode) {
103 case RoundingMode::downward:
104 return llvm::RoundingMode::TowardNegative;
105 case RoundingMode::to_nearest_away:
106 return llvm::RoundingMode::NearestTiesToAway;
107 case RoundingMode::to_nearest_even:
108 return llvm::RoundingMode::NearestTiesToEven;
109 case RoundingMode::toward_zero:
110 return llvm::RoundingMode::TowardZero;
111 case RoundingMode::upward:
112 return llvm::RoundingMode::TowardPositive;
114 llvm_unreachable(
"Unhandled rounding mode");
144 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
155 #include "ArithCanonicalization.inc"
165 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
166 return shapedType.cloneWith(std::nullopt, i1Type);
167 if (llvm::isa<UnrankedTensorType>(type))
176 void arith::ConstantOp::getAsmResultNames(
178 auto type = getType();
179 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
180 auto intType = llvm::dyn_cast<IntegerType>(type);
183 if (intType && intType.getWidth() == 1)
184 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
188 llvm::raw_svector_ostream specialName(specialNameBuffer);
189 specialName <<
'c' << intCst.getValue();
191 specialName <<
'_' << type;
192 setNameFn(getResult(), specialName.str());
194 setNameFn(getResult(),
"cst");
201 auto type = getType();
203 if (getValue().getType() != type) {
204 return emitOpError() <<
"value type " << getValue().getType()
205 <<
" must match return type: " << type;
208 if (llvm::isa<IntegerType>(type) &&
209 !llvm::cast<IntegerType>(type).isSignless())
210 return emitOpError(
"integer return type must be signless");
212 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
214 "value must be an integer, float, or elements attribute");
220 auto vecType = dyn_cast<VectorType>(type);
221 if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
223 "intializing scalable vectors with elements attribute is not supported"
224 " unless it's a vector splat");
228 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
230 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
231 if (!typedAttr || typedAttr.getType() != type)
234 if (llvm::isa<IntegerType>(type) &&
235 !llvm::cast<IntegerType>(type).isSignless())
238 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
243 if (isBuildableWith(value, type))
244 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
// Folding a constant op yields its own `value` attribute unchanged.
248 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
251 int64_t value,
unsigned width) {
253 arith::ConstantOp::build(builder, result, type,
258 int64_t value,
Type type) {
260 "ConstantIntOp can only have signless integer type values");
261 arith::ConstantOp::build(builder, result, type,
266 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
267 return constOp.getType().isSignlessInteger();
273 arith::ConstantOp::build(builder, result, type,
278 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
279 return llvm::isa<FloatType>(constOp.getType());
285 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
290 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
291 return constOp.getType().isIndex();
305 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
306 if (getRhs() == sub.getRhs())
310 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
311 if (getLhs() == sub.getRhs())
314 return constFoldBinaryOp<IntegerAttr>(
315 adaptor.getOperands(),
316 [](APInt a,
const APInt &b) { return std::move(a) + b; });
321 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
322 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
329 std::optional<SmallVector<int64_t, 4>>
330 arith::AddUIExtendedOp::getShapeForUnroll() {
331 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
332 return llvm::to_vector<4>(vt.getShape());
339 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
343 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
345 Type overflowTy = getOverflow().getType();
351 results.push_back(getLhs());
352 results.push_back(falseValue);
360 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
361 adaptor.getOperands(),
362 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
363 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
364 ArrayRef({sumAttr, adaptor.getLhs()}),
370 results.push_back(sumAttr);
371 results.push_back(overflowAttr);
378 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
380 patterns.
add<AddUIExtendedToAddI>(context);
389 if (getOperand(0) == getOperand(1))
395 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
397 if (getRhs() == add.getRhs())
400 if (getRhs() == add.getLhs())
404 return constFoldBinaryOp<IntegerAttr>(
405 adaptor.getOperands(),
406 [](APInt a,
const APInt &b) { return std::move(a) - b; });
411 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
412 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
413 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
430 return constFoldBinaryOp<IntegerAttr>(
431 adaptor.getOperands(),
432 [](
const APInt &a,
const APInt &b) { return a * b; });
435 void arith::MulIOp::getAsmResultNames(
437 if (!isa<IndexType>(getType()))
446 IntegerAttr baseValue;
449 isVscale(b.getDefiningOp());
452 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
457 llvm::raw_svector_ostream specialName(specialNameBuffer);
458 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
459 setNameFn(getResult(), specialName.str());
464 patterns.
add<MulIMulIConstant>(context);
471 std::optional<SmallVector<int64_t, 4>>
472 arith::MulSIExtendedOp::getShapeForUnroll() {
473 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
474 return llvm::to_vector<4>(vt.getShape());
479 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
484 results.push_back(zero);
485 results.push_back(zero);
490 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
491 adaptor.getOperands(),
492 [](
const APInt &a,
const APInt &b) { return a * b; })) {
494 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
495 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
496 return llvm::APIntOps::mulhs(a, b);
498 assert(highAttr &&
"Unexpected constant-folding failure");
500 results.push_back(lowAttr);
501 results.push_back(highAttr);
508 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
510 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
517 std::optional<SmallVector<int64_t, 4>>
518 arith::MulUIExtendedOp::getShapeForUnroll() {
519 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
520 return llvm::to_vector<4>(vt.getShape());
525 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
530 results.push_back(zero);
531 results.push_back(zero);
539 results.push_back(getLhs());
540 results.push_back(zero);
545 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
546 adaptor.getOperands(),
547 [](
const APInt &a,
const APInt &b) { return a * b; })) {
549 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
550 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
551 return llvm::APIntOps::mulhu(a, b);
553 assert(highAttr &&
"Unexpected constant-folding failure");
555 results.push_back(lowAttr);
556 results.push_back(highAttr);
563 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
565 patterns.
add<MulUIExtendedToMulI>(context);
572 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
579 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
580 [&](APInt a,
const APInt &b) {
601 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
607 bool overflowOrDiv0 =
false;
608 auto result = constFoldBinaryOp<IntegerAttr>(
609 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
610 if (overflowOrDiv0 || !b) {
611 overflowOrDiv0 = true;
614 return a.sdiv_ov(b, overflowOrDiv0);
617 return overflowOrDiv0 ?
Attribute() : result;
621 bool mayHaveUB =
true;
627 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
639 APInt one(a.getBitWidth(), 1,
true);
640 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
641 return val.sadd_ov(one, overflow);
648 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
653 bool overflowOrDiv0 =
false;
654 auto result = constFoldBinaryOp<IntegerAttr>(
655 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
656 if (overflowOrDiv0 || !b) {
657 overflowOrDiv0 = true;
660 APInt quotient = a.udiv(b);
663 APInt one(a.getBitWidth(), 1,
true);
664 return quotient.uadd_ov(one, overflowOrDiv0);
667 return overflowOrDiv0 ?
Attribute() : result;
680 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
686 bool overflowOrDiv0 =
false;
687 auto result = constFoldBinaryOp<IntegerAttr>(
688 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
689 if (overflowOrDiv0 || !b) {
690 overflowOrDiv0 = true;
696 unsigned bits = a.getBitWidth();
698 bool aGtZero = a.sgt(zero);
699 bool bGtZero = b.sgt(zero);
700 if (aGtZero && bGtZero) {
704 if (!aGtZero && !bGtZero) {
706 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
707 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
710 if (!aGtZero && bGtZero) {
712 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
714 return zero.ssub_ov(div, overflowOrDiv0);
717 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
718 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
719 return zero.ssub_ov(div, overflowOrDiv0);
722 return overflowOrDiv0 ?
Attribute() : result;
726 bool mayHaveUB =
true;
732 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
741 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
747 bool overflowOrDiv =
false;
748 auto result = constFoldBinaryOp<IntegerAttr>(
749 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
751 overflowOrDiv = true;
754 return a.sfloordiv_ov(b, overflowOrDiv);
757 return overflowOrDiv ?
Attribute() : result;
764 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
771 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
772 [&](APInt a,
const APInt &b) {
773 if (div0 || b.isZero()) {
787 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
794 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
795 [&](APInt a,
const APInt &b) {
796 if (div0 || b.isZero()) {
812 for (
bool reversePrev : {
false,
true}) {
813 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
814 .getDefiningOp<arith::AndIOp>();
818 Value other = (reversePrev ? op.getLhs() : op.getRhs());
819 if (other != prev.getLhs() && other != prev.getRhs())
822 return prev.getResult();
834 intValue.isAllOnes())
839 intValue.isAllOnes())
844 intValue.isAllOnes())
851 return constFoldBinaryOp<IntegerAttr>(
852 adaptor.getOperands(),
853 [](APInt a,
const APInt &b) { return std::move(a) & b; });
866 if (rhsVal.isAllOnes())
867 return adaptor.getRhs();
874 intValue.isAllOnes())
875 return getRhs().getDefiningOp<XOrIOp>().getRhs();
879 intValue.isAllOnes())
880 return getLhs().getDefiningOp<XOrIOp>().getRhs();
882 return constFoldBinaryOp<IntegerAttr>(
883 adaptor.getOperands(),
884 [](APInt a,
const APInt &b) { return std::move(a) | b; });
896 if (getLhs() == getRhs())
900 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
901 if (prev.getRhs() == getRhs())
902 return prev.getLhs();
903 if (prev.getLhs() == getRhs())
904 return prev.getRhs();
908 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
909 if (prev.getRhs() == getLhs())
910 return prev.getLhs();
911 if (prev.getLhs() == getLhs())
912 return prev.getRhs();
915 return constFoldBinaryOp<IntegerAttr>(
916 adaptor.getOperands(),
917 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
922 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
931 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
933 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
934 [](
const APFloat &a) { return -a; });
946 return constFoldBinaryOp<FloatAttr>(
947 adaptor.getOperands(),
948 [](
const APFloat &a,
const APFloat &b) { return a + b; });
960 return constFoldBinaryOp<FloatAttr>(
961 adaptor.getOperands(),
962 [](
const APFloat &a,
const APFloat &b) { return a - b; });
969 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
971 if (getLhs() == getRhs())
978 return constFoldBinaryOp<FloatAttr>(
979 adaptor.getOperands(),
980 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
987 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
989 if (getLhs() == getRhs())
996 return constFoldBinaryOp<FloatAttr>(
997 adaptor.getOperands(),
// arith.maxnumf has IEEE-754 maxNum semantics: if one operand is NaN the
// result is the other operand. llvm::maxnum implements exactly that, whereas
// llvm::maximum (used by maximumf) propagates NaN and would mis-fold.
998 [](
const APFloat &a,
const APFloat &b) { return llvm::maxnum(a, b); });
1007 if (getLhs() == getRhs())
1013 if (intValue.isMaxSignedValue())
1016 if (intValue.isMinSignedValue())
1020 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1021 [](
const APInt &a,
const APInt &b) {
1022 return llvm::APIntOps::smax(a, b);
1032 if (getLhs() == getRhs())
1038 if (intValue.isMaxValue())
1041 if (intValue.isMinValue())
1045 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046 [](
const APInt &a,
const APInt &b) {
1047 return llvm::APIntOps::umax(a, b);
1055 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1057 if (getLhs() == getRhs())
1064 return constFoldBinaryOp<FloatAttr>(
1065 adaptor.getOperands(),
1066 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1073 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1075 if (getLhs() == getRhs())
1082 return constFoldBinaryOp<FloatAttr>(
1083 adaptor.getOperands(),
1084 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1093 if (getLhs() == getRhs())
1099 if (intValue.isMinSignedValue())
1102 if (intValue.isMaxSignedValue())
1106 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1107 [](
const APInt &a,
const APInt &b) {
1108 return llvm::APIntOps::smin(a, b);
1118 if (getLhs() == getRhs())
1124 if (intValue.isMinValue())
1127 if (intValue.isMaxValue())
1131 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1132 [](
const APInt &a,
const APInt &b) {
1133 return llvm::APIntOps::umin(a, b);
1141 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1146 return constFoldBinaryOp<FloatAttr>(
1147 adaptor.getOperands(),
1148 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1153 patterns.
add<MulFOfNegF>(context);
1160 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1165 return constFoldBinaryOp<FloatAttr>(
1166 adaptor.getOperands(),
1167 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1172 patterns.
add<DivFOfNegF>(context);
1179 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1180 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1181 [](
const APFloat &a,
const APFloat &b) {
// arith.remf's result takes the sign of the LHS (fmod / LLVM frem semantics).
// APFloat::mod() implements that; APFloat::remainder() is the IEEE remainder,
// whose sign can differ from the dividend (e.g. 5 rem 3 == -1).
1183 (void)result.mod(b);
1192 template <
typename... Types>
1198 template <
typename... ShapedTypes,
typename... ElementTypes>
1201 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1205 if (!llvm::isa<ElementTypes...>(underlyingType))
1208 return underlyingType;
1212 template <
typename... ElementTypes>
1219 template <
typename... ElementTypes>
1228 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1229 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1230 if (!rankedTensorA || !rankedTensorB)
1232 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1236 if (inputs.size() != 1 || outputs.size() != 1)
1248 template <
typename ValType,
typename Op>
1253 if (llvm::cast<ValType>(srcType).getWidth() >=
1254 llvm::cast<ValType>(dstType).getWidth())
1256 << dstType <<
" must be wider than operand type " << srcType;
1262 template <
typename ValType,
typename Op>
1267 if (llvm::cast<ValType>(srcType).getWidth() <=
1268 llvm::cast<ValType>(dstType).getWidth())
1270 << dstType <<
" must be shorter than operand type " << srcType;
1276 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1281 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1282 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1283 if (!srcType || !dstType)
1286 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1287 srcType.getIntOrFloatBitWidth());
1293 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1294 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1295 bool losesInfo =
false;
1296 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1297 if (losesInfo || status != APFloat::opOK)
1307 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1308 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1309 getInMutable().assign(lhs.getIn());
1314 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1315 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1316 adaptor.getOperands(), getType(),
1317 [bitWidth](
const APInt &a,
bool &castStatus) {
1318 return a.zext(bitWidth);
1323 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1327 return verifyExtOp<IntegerType>(*
this);
1334 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1335 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1336 getInMutable().assign(lhs.getIn());
1341 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1342 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1343 adaptor.getOperands(), getType(),
1344 [bitWidth](
const APInt &a,
bool &castStatus) {
1345 return a.sext(bitWidth);
1350 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1355 patterns.
add<ExtSIOfExtUI>(context);
1359 return verifyExtOp<IntegerType>(*
this);
1368 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1370 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1371 return constFoldCastOp<FloatAttr, FloatAttr>(
1372 adaptor.getOperands(), getType(),
1373 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1384 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1393 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1394 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1401 if (llvm::cast<IntegerType>(srcType).getWidth() >
1402 llvm::cast<IntegerType>(dstType).getWidth()) {
1409 if (srcType == dstType)
1414 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1415 setOperand(getOperand().getDefiningOp()->getOperand(0));
1420 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1421 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1422 adaptor.getOperands(), getType(),
1423 [bitWidth](
const APInt &a,
bool &castStatus) {
1424 return a.trunc(bitWidth);
1429 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1434 patterns.
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1435 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1440 return verifyTruncateOp<IntegerType>(*
this);
1449 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1451 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1452 return constFoldCastOp<FloatAttr, FloatAttr>(
1453 adaptor.getOperands(), getType(),
1454 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1455 RoundingMode roundingMode =
1456 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1457 llvm::RoundingMode llvmRoundingMode =
1470 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1474 return verifyTruncateOp<FloatType>(*
this);
1483 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1492 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1499 template <
typename From,
typename To>
1504 auto srcType = getTypeIfLike<From>(inputs.front());
1505 auto dstType = getTypeIfLike<To>(outputs.back());
1507 return srcType && dstType;
1515 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1518 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1520 return constFoldCastOp<IntegerAttr, FloatAttr>(
1521 adaptor.getOperands(), getType(),
1522 [&resEleType](
const APInt &a,
bool &castStatus) {
1523 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1526 apf.convertFromAPInt(a,
false,
1527 APFloat::rmNearestTiesToEven);
1537 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1540 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1542 return constFoldCastOp<IntegerAttr, FloatAttr>(
1543 adaptor.getOperands(), getType(),
1544 [&resEleType](
const APInt &a,
bool &castStatus) {
1545 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1548 apf.convertFromAPInt(a,
true,
1549 APFloat::rmNearestTiesToEven);
1559 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1562 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1564 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1565 return constFoldCastOp<FloatAttr, IntegerAttr>(
1566 adaptor.getOperands(), getType(),
1567 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1569 APSInt api(bitWidth,
true);
1570 castStatus = APFloat::opInvalidOp !=
1571 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1581 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1584 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1586 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1587 return constFoldCastOp<FloatAttr, IntegerAttr>(
1588 adaptor.getOperands(), getType(),
1589 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1591 APSInt api(bitWidth,
false);
1592 castStatus = APFloat::opInvalidOp !=
1593 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1606 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1607 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1608 if (!srcType || !dstType)
1615 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1620 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1622 unsigned resultBitwidth = 64;
1624 resultBitwidth = intTy.getWidth();
1626 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1627 adaptor.getOperands(), getType(),
1628 [resultBitwidth](
const APInt &a,
bool & ) {
1629 return a.sextOrTrunc(resultBitwidth);
1633 void arith::IndexCastOp::getCanonicalizationPatterns(
1635 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1642 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1647 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1649 unsigned resultBitwidth = 64;
1651 resultBitwidth = intTy.getWidth();
1653 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1654 adaptor.getOperands(), getType(),
1655 [resultBitwidth](
const APInt &a,
bool & ) {
1656 return a.zextOrTrunc(resultBitwidth);
1660 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1662 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1674 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1676 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1677 if (!srcType || !dstType)
1683 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1684 auto resType = getType();
1685 auto operand = adaptor.getIn();
1690 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1691 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1693 if (llvm::isa<ShapedType>(resType))
1697 APInt bits = llvm::isa<FloatAttr>(operand)
1698 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1699 : llvm::cast<IntegerAttr>(operand).getValue();
1701 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1703 APFloat(resFloatType.getFloatSemantics(), bits));
1709 patterns.
add<BitcastOfBitcast>(context);
1719 const APInt &lhs,
const APInt &rhs) {
1720 switch (predicate) {
1721 case arith::CmpIPredicate::eq:
1723 case arith::CmpIPredicate::ne:
1725 case arith::CmpIPredicate::slt:
1726 return lhs.slt(rhs);
1727 case arith::CmpIPredicate::sle:
1728 return lhs.sle(rhs);
1729 case arith::CmpIPredicate::sgt:
1730 return lhs.sgt(rhs);
1731 case arith::CmpIPredicate::sge:
1732 return lhs.sge(rhs);
1733 case arith::CmpIPredicate::ult:
1734 return lhs.ult(rhs);
1735 case arith::CmpIPredicate::ule:
1736 return lhs.ule(rhs);
1737 case arith::CmpIPredicate::ugt:
1738 return lhs.ugt(rhs);
1739 case arith::CmpIPredicate::uge:
1740 return lhs.uge(rhs);
1742 llvm_unreachable(
"unknown cmpi predicate kind");
1747 switch (predicate) {
1748 case arith::CmpIPredicate::eq:
1749 case arith::CmpIPredicate::sle:
1750 case arith::CmpIPredicate::sge:
1751 case arith::CmpIPredicate::ule:
1752 case arith::CmpIPredicate::uge:
1754 case arith::CmpIPredicate::ne:
1755 case arith::CmpIPredicate::slt:
1756 case arith::CmpIPredicate::sgt:
1757 case arith::CmpIPredicate::ult:
1758 case arith::CmpIPredicate::ugt:
1761 llvm_unreachable(
"unknown cmpi predicate kind");
1765 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1766 return intType.getWidth();
1768 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1769 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1771 return std::nullopt;
1774 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1776 if (getLhs() == getRhs()) {
1782 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1784 std::optional<int64_t> integerWidth =
1786 if (integerWidth && integerWidth.value() == 1 &&
1787 getPredicate() == arith::CmpIPredicate::ne)
1788 return extOp.getOperand();
1790 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1792 std::optional<int64_t> integerWidth =
1794 if (integerWidth && integerWidth.value() == 1 &&
1795 getPredicate() == arith::CmpIPredicate::ne)
1796 return extOp.getOperand();
1801 if (adaptor.getLhs() && !adaptor.getRhs()) {
1803 using Pred = CmpIPredicate;
1804 const std::pair<Pred, Pred> invPreds[] = {
1805 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1806 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1807 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1808 {Pred::ne, Pred::ne},
1810 Pred origPred = getPredicate();
1811 for (
auto pred : invPreds) {
1812 if (origPred == pred.first) {
1813 setPredicate(pred.second);
1814 Value lhs = getLhs();
1815 Value rhs = getRhs();
1816 getLhsMutable().assign(rhs);
1817 getRhsMutable().assign(lhs);
1821 llvm_unreachable(
"unknown cmpi predicate kind");
1826 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1827 return constFoldBinaryOp<IntegerAttr>(
1829 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1840 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1850 const APFloat &lhs,
const APFloat &rhs) {
1851 auto cmpResult = lhs.compare(rhs);
1852 switch (predicate) {
1853 case arith::CmpFPredicate::AlwaysFalse:
1855 case arith::CmpFPredicate::OEQ:
1856 return cmpResult == APFloat::cmpEqual;
1857 case arith::CmpFPredicate::OGT:
1858 return cmpResult == APFloat::cmpGreaterThan;
1859 case arith::CmpFPredicate::OGE:
1860 return cmpResult == APFloat::cmpGreaterThan ||
1861 cmpResult == APFloat::cmpEqual;
1862 case arith::CmpFPredicate::OLT:
1863 return cmpResult == APFloat::cmpLessThan;
1864 case arith::CmpFPredicate::OLE:
1865 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1866 case arith::CmpFPredicate::ONE:
1867 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1868 case arith::CmpFPredicate::ORD:
1869 return cmpResult != APFloat::cmpUnordered;
1870 case arith::CmpFPredicate::UEQ:
1871 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1872 case arith::CmpFPredicate::UGT:
1873 return cmpResult == APFloat::cmpUnordered ||
1874 cmpResult == APFloat::cmpGreaterThan;
1875 case arith::CmpFPredicate::UGE:
1876 return cmpResult == APFloat::cmpUnordered ||
1877 cmpResult == APFloat::cmpGreaterThan ||
1878 cmpResult == APFloat::cmpEqual;
1879 case arith::CmpFPredicate::ULT:
1880 return cmpResult == APFloat::cmpUnordered ||
1881 cmpResult == APFloat::cmpLessThan;
1882 case arith::CmpFPredicate::ULE:
1883 return cmpResult == APFloat::cmpUnordered ||
1884 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1885 case arith::CmpFPredicate::UNE:
1886 return cmpResult != APFloat::cmpEqual;
1887 case arith::CmpFPredicate::UNO:
1888 return cmpResult == APFloat::cmpUnordered;
1889 case arith::CmpFPredicate::AlwaysTrue:
1892 llvm_unreachable(
"unknown cmpf predicate kind");
1895 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1896 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1897 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1900 if (lhs && lhs.getValue().isNaN())
1902 if (rhs && rhs.getValue().isNaN())
1918 using namespace arith;
1920 case CmpFPredicate::UEQ:
1921 case CmpFPredicate::OEQ:
1922 return CmpIPredicate::eq;
1923 case CmpFPredicate::UGT:
1924 case CmpFPredicate::OGT:
1925 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1926 case CmpFPredicate::UGE:
1927 case CmpFPredicate::OGE:
1928 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1929 case CmpFPredicate::ULT:
1930 case CmpFPredicate::OLT:
1931 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1932 case CmpFPredicate::ULE:
1933 case CmpFPredicate::OLE:
1934 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1935 case CmpFPredicate::UNE:
1936 case CmpFPredicate::ONE:
1937 return CmpIPredicate::ne;
1939 llvm_unreachable(
"Unexpected predicate!");
1949 const APFloat &rhs = flt.getValue();
1957 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1959 if (mantissaWidth <= 0)
1965 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1967 intVal = si.getIn();
1968 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1970 intVal = ui.getIn();
1977 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
1978 auto intWidth = intTy.getWidth();
1981 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1986 if ((
int)intWidth > mantissaWidth) {
1988 int exponent = ilogb(rhs);
1989 if (exponent == APFloat::IEK_Inf) {
1990 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1991 if (maxExponent < (
int)valueBits) {
1998 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2007 switch (op.getPredicate()) {
2008 case CmpFPredicate::ORD:
2013 case CmpFPredicate::UNO:
2026 APFloat signedMax(rhs.getSemantics());
2027 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2028 APFloat::rmNearestTiesToEven);
2029 if (signedMax < rhs) {
2030 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2031 pred == CmpIPredicate::sle)
2042 APFloat unsignedMax(rhs.getSemantics());
2043 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2044 APFloat::rmNearestTiesToEven);
2045 if (unsignedMax < rhs) {
2046 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2047 pred == CmpIPredicate::ule)
2059 APFloat signedMin(rhs.getSemantics());
2060 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2061 APFloat::rmNearestTiesToEven);
2062 if (signedMin > rhs) {
2063 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2064 pred == CmpIPredicate::sge)
2074 APFloat unsignedMin(rhs.getSemantics());
2075 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2076 APFloat::rmNearestTiesToEven);
2077 if (unsignedMin > rhs) {
2078 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2079 pred == CmpIPredicate::uge)
2094 APSInt rhsInt(intWidth, isUnsigned);
2095 if (APFloat::opInvalidOp ==
2096 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2102 if (!rhs.isZero()) {
2105 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2107 bool equal = apf == rhs;
2113 case CmpIPredicate::ne:
2117 case CmpIPredicate::eq:
2121 case CmpIPredicate::ule:
2124 if (rhs.isNegative()) {
2130 case CmpIPredicate::sle:
2133 if (rhs.isNegative())
2134 pred = CmpIPredicate::slt;
2136 case CmpIPredicate::ult:
2139 if (rhs.isNegative()) {
2144 pred = CmpIPredicate::ule;
2146 case CmpIPredicate::slt:
2149 if (!rhs.isNegative())
2150 pred = CmpIPredicate::sle;
2152 case CmpIPredicate::ugt:
2155 if (rhs.isNegative()) {
2161 case CmpIPredicate::sgt:
2164 if (rhs.isNegative())
2165 pred = CmpIPredicate::sge;
2167 case CmpIPredicate::uge:
2170 if (rhs.isNegative()) {
2175 pred = CmpIPredicate::ugt;
2177 case CmpIPredicate::sge:
2180 if (!rhs.isNegative())
2181 pred = CmpIPredicate::sgt;
2191 rewriter.
create<ConstantOp>(
2214 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2230 rewriter.
create<arith::XOrIOp>(
2231 op.
getLoc(), op.getCondition(),
2232 rewriter.
create<arith::ConstantIntOp>(
2233 op.
getLoc(), 1, op.getCondition().getType())));
2243 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2247 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2248 Value trueVal = getTrueValue();
2249 Value falseVal = getFalseValue();
2250 if (trueVal == falseVal)
2253 Value condition = getCondition();
2264 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2267 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2275 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2276 auto pred = cmp.getPredicate();
2277 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2278 auto cmpLhs = cmp.getLhs();
2279 auto cmpRhs = cmp.getRhs();
2287 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2288 (cmpRhs == trueVal && cmpLhs == falseVal))
2289 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2296 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2298 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2300 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2302 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2303 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2305 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2307 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2310 for (
auto [condVal, lhsVal, rhsVal] :
2311 llvm::zip_equal(condVals, lhsVals, rhsVals))
2312 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2323 Type conditionType, resultType;
2332 conditionType = resultType;
2341 {conditionType, resultType, resultType},
2346 p <<
" " << getOperands();
2349 if (ShapedType condType =
2350 llvm::dyn_cast<ShapedType>(getCondition().getType()))
2351 p << condType <<
", ";
2356 Type conditionType = getCondition().getType();
2362 Type resultType = getType();
2363 if (!llvm::isa<TensorType, VectorType>(resultType))
2364 return emitOpError() <<
"expected condition to be a signless i1, but got "
2367 if (conditionType != shapedConditionType) {
2368 return emitOpError() <<
"expected condition type to have the same shape "
2369 "as the result type, expected "
2370 << shapedConditionType <<
", but got "
2379 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2384 bool bounded =
false;
2385 auto result = constFoldBinaryOp<IntegerAttr>(
2386 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2387 bounded = b.ult(b.getBitWidth());
2397 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2402 bool bounded =
false;
2403 auto result = constFoldBinaryOp<IntegerAttr>(
2404 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2405 bounded = b.ult(b.getBitWidth());
2415 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2420 bool bounded =
false;
2421 auto result = constFoldBinaryOp<IntegerAttr>(
2422 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2423 bounded = b.ult(b.getBitWidth());
2436 bool useOnlyFiniteValue) {
2438 case AtomicRMWKind::maximumf: {
2439 const llvm::fltSemantics &semantic =
2440 llvm::cast<FloatType>(resultType).getFloatSemantics();
2441 APFloat identity = useOnlyFiniteValue
2442 ? APFloat::getLargest(semantic,
true)
2443 : APFloat::getInf(semantic,
true);
2446 case AtomicRMWKind::addf:
2447 case AtomicRMWKind::addi:
2448 case AtomicRMWKind::maxu:
2449 case AtomicRMWKind::ori:
2451 case AtomicRMWKind::andi:
2454 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2455 case AtomicRMWKind::maxs:
2457 resultType, APInt::getSignedMinValue(
2458 llvm::cast<IntegerType>(resultType).getWidth()));
2459 case AtomicRMWKind::minimumf: {
2460 const llvm::fltSemantics &semantic =
2461 llvm::cast<FloatType>(resultType).getFloatSemantics();
2462 APFloat identity = useOnlyFiniteValue
2463 ? APFloat::getLargest(semantic,
false)
2464 : APFloat::getInf(semantic,
false);
2468 case AtomicRMWKind::mins:
2470 resultType, APInt::getSignedMaxValue(
2471 llvm::cast<IntegerType>(resultType).getWidth()));
2472 case AtomicRMWKind::minu:
2475 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2476 case AtomicRMWKind::muli:
2478 case AtomicRMWKind::mulf:
2490 std::optional<AtomicRMWKind> maybeKind =
2493 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2494 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2495 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2496 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2498 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2499 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2500 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2501 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2502 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2503 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2504 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2505 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2506 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2507 .Default([](
Operation *op) {
return std::nullopt; });
2509 op->
emitError() <<
"Unknown neutral element for: " << *op;
2510 return std::nullopt;
2513 bool useOnlyFiniteValue =
false;
2514 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2515 if (fmfOpInterface) {
2516 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2517 useOnlyFiniteValue =
2518 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2526 useOnlyFiniteValue);
2532 bool useOnlyFiniteValue) {
2535 return builder.
create<arith::ConstantOp>(loc, attr);
2543 case AtomicRMWKind::addf:
2544 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2545 case AtomicRMWKind::addi:
2546 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2547 case AtomicRMWKind::mulf:
2548 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2549 case AtomicRMWKind::muli:
2550 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2551 case AtomicRMWKind::maximumf:
2552 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2553 case AtomicRMWKind::minimumf:
2554 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2555 case AtomicRMWKind::maxnumf:
2556 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2557 case AtomicRMWKind::minnumf:
2558 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2559 case AtomicRMWKind::maxs:
2560 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2561 case AtomicRMWKind::mins:
2562 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2563 case AtomicRMWKind::maxu:
2564 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2565 case AtomicRMWKind::minu:
2566 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2567 case AtomicRMWKind::ori:
2568 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2569 case AtomicRMWKind::andi:
2570 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2583 #define GET_OP_CLASSES
2584 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2590 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)