26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
45 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
68 static IntegerOverflowFlagsAttr
70 IntegerOverflowFlagsAttr val2) {
72 val1.getValue() & val2.getValue());
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
99 llvm_unreachable(
"unknown cmpi predicate kind");
108 static llvm::RoundingMode
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode::upward:
120 return llvm::RoundingMode::TowardPositive;
122 llvm_unreachable(
"Unhandled rounding mode");
152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
163 #include "ArithCanonicalization.inc"
173 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa<UnrankedTensorType>(type))
184 void arith::ConstantOp::getAsmResultNames(
187 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188 auto intType = llvm::dyn_cast<IntegerType>(type);
191 if (intType && intType.getWidth() == 1)
192 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
196 llvm::raw_svector_ostream specialName(specialNameBuffer);
197 specialName <<
'c' << intCst.getValue();
199 specialName <<
'_' << type;
200 setNameFn(getResult(), specialName.str());
202 setNameFn(getResult(),
"cst");
211 if (getValue().
getType() != type) {
212 return emitOpError() <<
"value type " << getValue().getType()
213 <<
" must match return type: " << type;
216 if (llvm::isa<IntegerType>(type) &&
217 !llvm::cast<IntegerType>(type).isSignless())
218 return emitOpError(
"integer return type must be signless");
220 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
222 "value must be an integer, float, or elements attribute");
228 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
230 "intializing scalable vectors with elements attribute is not supported"
231 " unless it's a vector splat");
235 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
237 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238 if (!typedAttr || typedAttr.getType() != type)
241 if (llvm::isa<IntegerType>(type) &&
242 !llvm::cast<IntegerType>(type).isSignless())
245 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250 if (isBuildableWith(value, type))
251 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
258 int64_t value,
unsigned width) {
260 arith::ConstantOp::build(builder, result, type,
265 int64_t value,
Type type) {
267 "ConstantIntOp can only have signless integer type values");
268 arith::ConstantOp::build(builder, result, type,
273 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274 return constOp.getType().isSignlessInteger();
279 const APFloat &value, FloatType type) {
280 arith::ConstantOp::build(builder, result, type,
285 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286 return llvm::isa<FloatType>(constOp.getType());
292 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
297 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298 return constOp.getType().isIndex();
312 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
313 if (getRhs() == sub.getRhs())
317 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
318 if (getLhs() == sub.getRhs())
321 return constFoldBinaryOp<IntegerAttr>(
322 adaptor.getOperands(),
323 [](APInt a,
const APInt &b) { return std::move(a) + b; });
328 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
339 return llvm::to_vector<4>(vt.getShape());
346 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352 Type overflowTy = getOverflow().getType();
358 results.push_back(getLhs());
359 results.push_back(falseValue);
367 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368 adaptor.getOperands(),
369 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
370 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371 ArrayRef({sumAttr, adaptor.getLhs()}),
377 results.push_back(sumAttr);
378 results.push_back(overflowAttr);
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
387 patterns.add<AddUIExtendedToAddI>(context);
396 if (getOperand(0) == getOperand(1)) {
397 auto shapedType = dyn_cast<ShapedType>(
getType());
399 if (!shapedType || shapedType.hasStaticShape())
406 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
408 if (getRhs() == add.getRhs())
411 if (getRhs() == add.getLhs())
415 return constFoldBinaryOp<IntegerAttr>(
416 adaptor.getOperands(),
417 [](APInt a,
const APInt &b) { return std::move(a) - b; });
422 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
423 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
424 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
441 return constFoldBinaryOp<IntegerAttr>(
442 adaptor.getOperands(),
443 [](
const APInt &a,
const APInt &b) { return a * b; });
446 void arith::MulIOp::getAsmResultNames(
448 if (!isa<IndexType>(
getType()))
454 return op && op->getName().getStringRef() ==
"vector.vscale";
457 IntegerAttr baseValue;
460 isVscale(b.getDefiningOp());
463 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
468 llvm::raw_svector_ostream specialName(specialNameBuffer);
469 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
470 setNameFn(getResult(), specialName.str());
475 patterns.add<MulIMulIConstant>(context);
482 std::optional<SmallVector<int64_t, 4>>
483 arith::MulSIExtendedOp::getShapeForUnroll() {
484 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
485 return llvm::to_vector<4>(vt.getShape());
490 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
495 results.push_back(zero);
496 results.push_back(zero);
501 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
502 adaptor.getOperands(),
503 [](
const APInt &a,
const APInt &b) { return a * b; })) {
505 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
506 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
507 return llvm::APIntOps::mulhs(a, b);
509 assert(highAttr &&
"Unexpected constant-folding failure");
511 results.push_back(lowAttr);
512 results.push_back(highAttr);
519 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
521 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
528 std::optional<SmallVector<int64_t, 4>>
529 arith::MulUIExtendedOp::getShapeForUnroll() {
530 if (
auto vt = llvm::dyn_cast<VectorType>(
getType(0)))
531 return llvm::to_vector<4>(vt.getShape());
536 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
541 results.push_back(zero);
542 results.push_back(zero);
550 results.push_back(getLhs());
551 results.push_back(zero);
556 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
557 adaptor.getOperands(),
558 [](
const APInt &a,
const APInt &b) { return a * b; })) {
560 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
561 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
562 return llvm::APIntOps::mulhu(a, b);
564 assert(highAttr &&
"Unexpected constant-folding failure");
566 results.push_back(lowAttr);
567 results.push_back(highAttr);
574 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
576 patterns.add<MulUIExtendedToMulI>(context);
585 arith::IntegerOverflowFlags ovfFlags) {
587 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
590 if (mul.getLhs() == rhs)
593 if (mul.getRhs() == rhs)
599 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
605 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
610 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
611 [&](APInt a,
const APInt &b) {
639 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
645 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
649 bool overflowOrDiv0 =
false;
650 auto result = constFoldBinaryOp<IntegerAttr>(
651 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
652 if (overflowOrDiv0 || !b) {
653 overflowOrDiv0 = true;
656 return a.sdiv_ov(b, overflowOrDiv0);
659 return overflowOrDiv0 ?
Attribute() : result;
686 APInt one(a.getBitWidth(), 1,
true);
687 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
688 return val.sadd_ov(one, overflow);
695 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
700 bool overflowOrDiv0 =
false;
701 auto result = constFoldBinaryOp<IntegerAttr>(
702 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
703 if (overflowOrDiv0 || !b) {
704 overflowOrDiv0 = true;
707 APInt quotient = a.udiv(b);
710 APInt one(a.getBitWidth(), 1,
true);
711 return quotient.uadd_ov(one, overflowOrDiv0);
714 return overflowOrDiv0 ?
Attribute() : result;
725 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
733 bool overflowOrDiv0 =
false;
734 auto result = constFoldBinaryOp<IntegerAttr>(
735 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
736 if (overflowOrDiv0 || !b) {
737 overflowOrDiv0 = true;
743 unsigned bits = a.getBitWidth();
745 bool aGtZero = a.sgt(zero);
746 bool bGtZero = b.sgt(zero);
747 if (aGtZero && bGtZero) {
754 bool overflowNegA =
false;
755 bool overflowNegB =
false;
756 bool overflowDiv =
false;
757 bool overflowNegRes =
false;
758 if (!aGtZero && !bGtZero) {
760 APInt posA = zero.ssub_ov(a, overflowNegA);
761 APInt posB = zero.ssub_ov(b, overflowNegB);
763 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
766 if (!aGtZero && bGtZero) {
768 APInt posA = zero.ssub_ov(a, overflowNegA);
769 APInt div = posA.sdiv_ov(b, overflowDiv);
770 APInt res = zero.ssub_ov(div, overflowNegRes);
771 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
775 APInt posB = zero.ssub_ov(b, overflowNegB);
776 APInt div = a.sdiv_ov(posB, overflowDiv);
777 APInt res = zero.ssub_ov(div, overflowNegRes);
779 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
783 return overflowOrDiv0 ?
Attribute() : result;
794 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
800 bool overflowOrDiv =
false;
801 auto result = constFoldBinaryOp<IntegerAttr>(
802 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
804 overflowOrDiv = true;
807 return a.sfloordiv_ov(b, overflowOrDiv);
810 return overflowOrDiv ?
Attribute() : result;
817 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
824 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
825 [&](APInt a,
const APInt &b) {
826 if (div0 || b.isZero()) {
840 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
847 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
848 [&](APInt a,
const APInt &b) {
849 if (div0 || b.isZero()) {
865 for (
bool reversePrev : {
false,
true}) {
866 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
867 .getDefiningOp<arith::AndIOp>();
871 Value other = (reversePrev ? op.getLhs() : op.getRhs());
872 if (other != prev.getLhs() && other != prev.getRhs())
875 return prev.getResult();
887 intValue.isAllOnes())
892 intValue.isAllOnes())
897 intValue.isAllOnes())
904 return constFoldBinaryOp<IntegerAttr>(
905 adaptor.getOperands(),
906 [](APInt a,
const APInt &b) { return std::move(a) & b; });
919 if (rhsVal.isAllOnes())
920 return adaptor.getRhs();
927 intValue.isAllOnes())
928 return getRhs().getDefiningOp<XOrIOp>().getRhs();
932 intValue.isAllOnes())
933 return getLhs().getDefiningOp<XOrIOp>().getRhs();
935 return constFoldBinaryOp<IntegerAttr>(
936 adaptor.getOperands(),
937 [](APInt a,
const APInt &b) { return std::move(a) | b; });
949 if (getLhs() == getRhs())
953 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
954 if (prev.getRhs() == getRhs())
955 return prev.getLhs();
956 if (prev.getLhs() == getRhs())
957 return prev.getRhs();
961 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
962 if (prev.getRhs() == getLhs())
963 return prev.getLhs();
964 if (prev.getLhs() == getLhs())
965 return prev.getRhs();
968 return constFoldBinaryOp<IntegerAttr>(
969 adaptor.getOperands(),
970 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
975 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
984 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
985 return op.getOperand();
986 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
987 [](
const APFloat &a) { return -a; });
999 return constFoldBinaryOp<FloatAttr>(
1000 adaptor.getOperands(),
1001 [](
const APFloat &a,
const APFloat &b) { return a + b; });
1008 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1013 return constFoldBinaryOp<FloatAttr>(
1014 adaptor.getOperands(),
1015 [](
const APFloat &a,
const APFloat &b) { return a - b; });
1022 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1024 if (getLhs() == getRhs())
1031 return constFoldBinaryOp<FloatAttr>(
1032 adaptor.getOperands(),
1033 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
1040 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1042 if (getLhs() == getRhs())
1049 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1058 if (getLhs() == getRhs())
1064 if (intValue.isMaxSignedValue())
1067 if (intValue.isMinSignedValue())
1071 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1072 [](
const APInt &a,
const APInt &b) {
1073 return llvm::APIntOps::smax(a, b);
1083 if (getLhs() == getRhs())
1089 if (intValue.isMaxValue())
1092 if (intValue.isMinValue())
1096 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1097 [](
const APInt &a,
const APInt &b) {
1098 return llvm::APIntOps::umax(a, b);
1106 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1108 if (getLhs() == getRhs())
1115 return constFoldBinaryOp<FloatAttr>(
1116 adaptor.getOperands(),
1117 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1124 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1126 if (getLhs() == getRhs())
1133 return constFoldBinaryOp<FloatAttr>(
1134 adaptor.getOperands(),
1135 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1144 if (getLhs() == getRhs())
1150 if (intValue.isMinSignedValue())
1153 if (intValue.isMaxSignedValue())
1157 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1158 [](
const APInt &a,
const APInt &b) {
1159 return llvm::APIntOps::smin(a, b);
1169 if (getLhs() == getRhs())
1175 if (intValue.isMinValue())
1178 if (intValue.isMaxValue())
1182 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1183 [](
const APInt &a,
const APInt &b) {
1184 return llvm::APIntOps::umin(a, b);
1192 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1197 return constFoldBinaryOp<FloatAttr>(
1198 adaptor.getOperands(),
1199 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1211 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1216 return constFoldBinaryOp<FloatAttr>(
1217 adaptor.getOperands(),
1218 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1230 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1231 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1232 [](
const APFloat &a,
const APFloat &b) {
1237 (void)result.mod(b);
1246 template <
typename... Types>
1252 template <
typename... ShapedTypes,
typename... ElementTypes>
1255 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1259 if (!llvm::isa<ElementTypes...>(underlyingType))
1262 return underlyingType;
1266 template <
typename... ElementTypes>
1273 template <
typename... ElementTypes>
1282 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1283 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1284 if (!rankedTensorA || !rankedTensorB)
1286 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1290 if (inputs.size() != 1 || outputs.size() != 1)
1302 template <
typename ValType,
typename Op>
1307 if (llvm::cast<ValType>(srcType).getWidth() >=
1308 llvm::cast<ValType>(dstType).getWidth())
1310 << dstType <<
" must be wider than operand type " << srcType;
1316 template <
typename ValType,
typename Op>
1321 if (llvm::cast<ValType>(srcType).getWidth() <=
1322 llvm::cast<ValType>(dstType).getWidth())
1324 << dstType <<
" must be shorter than operand type " << srcType;
1330 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1335 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1336 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1337 if (!srcType || !dstType)
1340 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1341 srcType.getIntOrFloatBitWidth());
1347 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1348 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1349 bool losesInfo =
false;
1350 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1351 if (losesInfo || status != APFloat::opOK)
1361 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1362 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1363 getInMutable().assign(lhs.getIn());
1368 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1369 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1370 adaptor.getOperands(),
getType(),
1371 [bitWidth](
const APInt &a,
bool &castStatus) {
1372 return a.zext(bitWidth);
1377 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1381 return verifyExtOp<IntegerType>(*
this);
1388 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1389 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1390 getInMutable().assign(lhs.getIn());
1395 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1396 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1397 adaptor.getOperands(),
getType(),
1398 [bitWidth](
const APInt &a,
bool &castStatus) {
1399 return a.sext(bitWidth);
1404 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1409 patterns.add<ExtSIOfExtUI>(context);
1413 return verifyExtOp<IntegerType>(*
this);
1422 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1423 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1424 if (truncFOp.getOperand().getType() ==
getType()) {
1425 arith::FastMathFlags truncFMF =
1426 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1427 bool isTruncContract =
1429 arith::FastMathFlags extFMF =
1430 getFastmath().value_or(arith::FastMathFlags::none);
1431 bool isExtContract =
1433 if (isTruncContract && isExtContract) {
1434 return truncFOp.getOperand();
1440 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1441 return constFoldCastOp<FloatAttr, FloatAttr>(
1442 adaptor.getOperands(),
getType(),
1443 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1445 if (failed(result)) {
1454 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1463 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1464 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1471 if (llvm::cast<IntegerType>(srcType).getWidth() >
1472 llvm::cast<IntegerType>(dstType).getWidth()) {
1479 if (srcType == dstType)
1484 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1485 setOperand(getOperand().getDefiningOp()->getOperand(0));
1490 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1491 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1492 adaptor.getOperands(),
getType(),
1493 [bitWidth](
const APInt &a,
bool &castStatus) {
1494 return a.trunc(bitWidth);
1499 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1504 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1505 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1510 return verifyTruncateOp<IntegerType>(*
this);
1519 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1521 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1522 Value src = extOp.getIn();
1524 auto intermediateType =
1527 if (llvm::APFloatBase::isRepresentableBy(
1528 srcType.getFloatSemantics(),
1529 intermediateType.getFloatSemantics())) {
1531 if (srcType.getWidth() > resElemType.getWidth()) {
1537 if (srcType == resElemType)
1542 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1543 return constFoldCastOp<FloatAttr, FloatAttr>(
1544 adaptor.getOperands(),
getType(),
1545 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1546 RoundingMode roundingMode =
1547 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1548 llvm::RoundingMode llvmRoundingMode =
1550 FailureOr<APFloat> result =
1552 if (failed(result)) {
1561 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1565 return verifyTruncateOp<FloatType>(*
this);
1574 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1583 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1590 template <
typename From,
typename To>
1595 auto srcType = getTypeIfLike<From>(inputs.front());
1596 auto dstType = getTypeIfLike<To>(outputs.back());
1598 return srcType && dstType;
1606 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1609 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1611 return constFoldCastOp<IntegerAttr, FloatAttr>(
1612 adaptor.getOperands(),
getType(),
1613 [&resEleType](
const APInt &a,
bool &castStatus) {
1614 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1615 APFloat apf(floatTy.getFloatSemantics(),
1617 apf.convertFromAPInt(a,
false,
1618 APFloat::rmNearestTiesToEven);
1628 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1631 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1633 return constFoldCastOp<IntegerAttr, FloatAttr>(
1634 adaptor.getOperands(),
getType(),
1635 [&resEleType](
const APInt &a,
bool &castStatus) {
1636 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1637 APFloat apf(floatTy.getFloatSemantics(),
1639 apf.convertFromAPInt(a,
true,
1640 APFloat::rmNearestTiesToEven);
1650 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1653 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1655 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1656 return constFoldCastOp<FloatAttr, IntegerAttr>(
1657 adaptor.getOperands(),
getType(),
1658 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1660 APSInt api(bitWidth,
true);
1661 castStatus = APFloat::opInvalidOp !=
1662 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1672 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1675 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1677 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1678 return constFoldCastOp<FloatAttr, IntegerAttr>(
1679 adaptor.getOperands(),
getType(),
1680 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1682 APSInt api(bitWidth,
false);
1683 castStatus = APFloat::opInvalidOp !=
1684 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1697 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1698 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1699 if (!srcType || !dstType)
1706 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1711 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1713 unsigned resultBitwidth = 64;
1715 resultBitwidth = intTy.getWidth();
1717 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1718 adaptor.getOperands(),
getType(),
1719 [resultBitwidth](
const APInt &a,
bool & ) {
1720 return a.sextOrTrunc(resultBitwidth);
1724 void arith::IndexCastOp::getCanonicalizationPatterns(
1726 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1733 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1738 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1740 unsigned resultBitwidth = 64;
1742 resultBitwidth = intTy.getWidth();
1744 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1745 adaptor.getOperands(),
getType(),
1746 [resultBitwidth](
const APInt &a,
bool & ) {
1747 return a.zextOrTrunc(resultBitwidth);
1751 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1753 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1764 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1765 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1766 if (!srcType || !dstType)
1772 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1774 auto operand = adaptor.getIn();
1779 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1780 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1782 if (llvm::isa<ShapedType>(resType))
1786 if (llvm::isa<ub::PoisonAttr>(operand))
1790 APInt bits = llvm::isa<FloatAttr>(operand)
1791 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1792 : llvm::cast<IntegerAttr>(operand).getValue();
1794 "trying to fold on broken IR: operands have incompatible types");
1796 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1798 APFloat(resFloatType.getFloatSemantics(), bits));
1804 patterns.add<BitcastOfBitcast>(context);
1814 const APInt &lhs,
const APInt &rhs) {
1815 switch (predicate) {
1816 case arith::CmpIPredicate::eq:
1818 case arith::CmpIPredicate::ne:
1820 case arith::CmpIPredicate::slt:
1821 return lhs.slt(rhs);
1822 case arith::CmpIPredicate::sle:
1823 return lhs.sle(rhs);
1824 case arith::CmpIPredicate::sgt:
1825 return lhs.sgt(rhs);
1826 case arith::CmpIPredicate::sge:
1827 return lhs.sge(rhs);
1828 case arith::CmpIPredicate::ult:
1829 return lhs.ult(rhs);
1830 case arith::CmpIPredicate::ule:
1831 return lhs.ule(rhs);
1832 case arith::CmpIPredicate::ugt:
1833 return lhs.ugt(rhs);
1834 case arith::CmpIPredicate::uge:
1835 return lhs.uge(rhs);
1837 llvm_unreachable(
"unknown cmpi predicate kind");
1842 switch (predicate) {
1843 case arith::CmpIPredicate::eq:
1844 case arith::CmpIPredicate::sle:
1845 case arith::CmpIPredicate::sge:
1846 case arith::CmpIPredicate::ule:
1847 case arith::CmpIPredicate::uge:
1849 case arith::CmpIPredicate::ne:
1850 case arith::CmpIPredicate::slt:
1851 case arith::CmpIPredicate::sgt:
1852 case arith::CmpIPredicate::ult:
1853 case arith::CmpIPredicate::ugt:
1856 llvm_unreachable(
"unknown cmpi predicate kind");
1860 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1861 return intType.getWidth();
1863 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1864 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1866 return std::nullopt;
1869 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1871 if (getLhs() == getRhs()) {
1877 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1879 std::optional<int64_t> integerWidth =
1881 if (integerWidth && integerWidth.value() == 1 &&
1882 getPredicate() == arith::CmpIPredicate::ne)
1883 return extOp.getOperand();
1885 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1887 std::optional<int64_t> integerWidth =
1889 if (integerWidth && integerWidth.value() == 1 &&
1890 getPredicate() == arith::CmpIPredicate::ne)
1891 return extOp.getOperand();
1896 getPredicate() == arith::CmpIPredicate::ne)
1903 getPredicate() == arith::CmpIPredicate::eq)
1908 if (adaptor.getLhs() && !adaptor.getRhs()) {
1910 using Pred = CmpIPredicate;
1911 const std::pair<Pred, Pred> invPreds[] = {
1912 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1913 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1914 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1915 {Pred::ne, Pred::ne},
1917 Pred origPred = getPredicate();
1918 for (
auto pred : invPreds) {
1919 if (origPred == pred.first) {
1920 setPredicate(pred.second);
1921 Value lhs = getLhs();
1922 Value rhs = getRhs();
1923 getLhsMutable().assign(rhs);
1924 getRhsMutable().assign(lhs);
1928 llvm_unreachable(
"unknown cmpi predicate kind");
1933 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1934 return constFoldBinaryOp<IntegerAttr>(
1936 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1947 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1957 const APFloat &lhs,
const APFloat &rhs) {
1958 auto cmpResult = lhs.compare(rhs);
1959 switch (predicate) {
1960 case arith::CmpFPredicate::AlwaysFalse:
1962 case arith::CmpFPredicate::OEQ:
1963 return cmpResult == APFloat::cmpEqual;
1964 case arith::CmpFPredicate::OGT:
1965 return cmpResult == APFloat::cmpGreaterThan;
1966 case arith::CmpFPredicate::OGE:
1967 return cmpResult == APFloat::cmpGreaterThan ||
1968 cmpResult == APFloat::cmpEqual;
1969 case arith::CmpFPredicate::OLT:
1970 return cmpResult == APFloat::cmpLessThan;
1971 case arith::CmpFPredicate::OLE:
1972 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1973 case arith::CmpFPredicate::ONE:
1974 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1975 case arith::CmpFPredicate::ORD:
1976 return cmpResult != APFloat::cmpUnordered;
1977 case arith::CmpFPredicate::UEQ:
1978 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1979 case arith::CmpFPredicate::UGT:
1980 return cmpResult == APFloat::cmpUnordered ||
1981 cmpResult == APFloat::cmpGreaterThan;
1982 case arith::CmpFPredicate::UGE:
1983 return cmpResult == APFloat::cmpUnordered ||
1984 cmpResult == APFloat::cmpGreaterThan ||
1985 cmpResult == APFloat::cmpEqual;
1986 case arith::CmpFPredicate::ULT:
1987 return cmpResult == APFloat::cmpUnordered ||
1988 cmpResult == APFloat::cmpLessThan;
1989 case arith::CmpFPredicate::ULE:
1990 return cmpResult == APFloat::cmpUnordered ||
1991 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1992 case arith::CmpFPredicate::UNE:
1993 return cmpResult != APFloat::cmpEqual;
1994 case arith::CmpFPredicate::UNO:
1995 return cmpResult == APFloat::cmpUnordered;
1996 case arith::CmpFPredicate::AlwaysTrue:
1999 llvm_unreachable(
"unknown cmpf predicate kind");
2002 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2003 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2004 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2007 if (lhs && lhs.getValue().isNaN())
2009 if (rhs && rhs.getValue().isNaN())
2025 using namespace arith;
2027 case CmpFPredicate::UEQ:
2028 case CmpFPredicate::OEQ:
2029 return CmpIPredicate::eq;
2030 case CmpFPredicate::UGT:
2031 case CmpFPredicate::OGT:
2032 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2033 case CmpFPredicate::UGE:
2034 case CmpFPredicate::OGE:
2035 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2036 case CmpFPredicate::ULT:
2037 case CmpFPredicate::OLT:
2038 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2039 case CmpFPredicate::ULE:
2040 case CmpFPredicate::OLE:
2041 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2042 case CmpFPredicate::UNE:
2043 case CmpFPredicate::ONE:
2044 return CmpIPredicate::ne;
2046 llvm_unreachable(
"Unexpected predicate!");
2056 const APFloat &rhs = flt.getValue();
2064 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2065 int mantissaWidth = floatTy.getFPMantissaWidth();
2066 if (mantissaWidth <= 0)
2072 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2074 intVal = si.getIn();
2075 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2077 intVal = ui.getIn();
2084 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2085 auto intWidth = intTy.getWidth();
2088 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2093 if ((
int)intWidth > mantissaWidth) {
2095 int exponent = ilogb(rhs);
2096 if (exponent == APFloat::IEK_Inf) {
2097 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2098 if (maxExponent < (
int)valueBits) {
2105 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2114 switch (op.getPredicate()) {
2115 case CmpFPredicate::ORD:
2120 case CmpFPredicate::UNO:
2133 APFloat signedMax(rhs.getSemantics());
2134 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2135 APFloat::rmNearestTiesToEven);
2136 if (signedMax < rhs) {
2137 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2138 pred == CmpIPredicate::sle)
2149 APFloat unsignedMax(rhs.getSemantics());
2150 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2151 APFloat::rmNearestTiesToEven);
2152 if (unsignedMax < rhs) {
2153 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2154 pred == CmpIPredicate::ule)
2166 APFloat signedMin(rhs.getSemantics());
2167 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2168 APFloat::rmNearestTiesToEven);
2169 if (signedMin > rhs) {
2170 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2171 pred == CmpIPredicate::sge)
2181 APFloat unsignedMin(rhs.getSemantics());
2182 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2183 APFloat::rmNearestTiesToEven);
2184 if (unsignedMin > rhs) {
2185 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2186 pred == CmpIPredicate::uge)
2201 APSInt rhsInt(intWidth, isUnsigned);
2202 if (APFloat::opInvalidOp ==
2203 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2209 if (!rhs.isZero()) {
2210 APFloat apf(floatTy.getFloatSemantics(),
2212 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2214 bool equal = apf == rhs;
2220 case CmpIPredicate::ne:
2224 case CmpIPredicate::eq:
2228 case CmpIPredicate::ule:
2231 if (rhs.isNegative()) {
2237 case CmpIPredicate::sle:
2240 if (rhs.isNegative())
2241 pred = CmpIPredicate::slt;
2243 case CmpIPredicate::ult:
2246 if (rhs.isNegative()) {
2251 pred = CmpIPredicate::ule;
2253 case CmpIPredicate::slt:
2256 if (!rhs.isNegative())
2257 pred = CmpIPredicate::sle;
2259 case CmpIPredicate::ugt:
2262 if (rhs.isNegative()) {
2268 case CmpIPredicate::sgt:
2271 if (rhs.isNegative())
2272 pred = CmpIPredicate::sge;
2274 case CmpIPredicate::uge:
2277 if (rhs.isNegative()) {
2282 pred = CmpIPredicate::ugt;
2284 case CmpIPredicate::sge:
2287 if (!rhs.isNegative())
2288 pred = CmpIPredicate::sgt;
2298 rewriter.
create<ConstantOp>(
2299 op.getLoc(), intVal.
getType(),
2321 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2337 rewriter.
create<arith::XOrIOp>(
2338 op.getLoc(), op.getCondition(),
2339 rewriter.
create<arith::ConstantIntOp>(
2340 op.getLoc(), 1, op.getCondition().getType())));
2350 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2354 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2355 Value trueVal = getTrueValue();
2356 Value falseVal = getFalseValue();
2357 if (trueVal == falseVal)
2360 Value condition = getCondition();
2371 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2374 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2378 if (
getType().isSignlessInteger(1) &&
2383 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2384 auto pred = cmp.getPredicate();
2385 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2386 auto cmpLhs = cmp.getLhs();
2387 auto cmpRhs = cmp.getRhs();
2395 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2396 (cmpRhs == trueVal && cmpLhs == falseVal))
2397 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2404 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2406 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2408 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2410 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2411 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2413 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2415 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2418 for (
auto [condVal, lhsVal, rhsVal] :
2419 llvm::zip_equal(condVals, lhsVals, rhsVals))
2420 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2431 Type conditionType, resultType;
2440 conditionType = resultType;
2449 {conditionType, resultType, resultType},
2454 p <<
" " << getOperands();
2457 if (ShapedType condType =
2458 llvm::dyn_cast<ShapedType>(getCondition().
getType()))
2459 p << condType <<
", ";
2464 Type conditionType = getCondition().getType();
2471 if (!llvm::isa<TensorType, VectorType>(resultType))
2472 return emitOpError() <<
"expected condition to be a signless i1, but got "
2475 if (conditionType != shapedConditionType) {
2476 return emitOpError() <<
"expected condition type to have the same shape "
2477 "as the result type, expected "
2478 << shapedConditionType <<
", but got "
2487 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2492 bool bounded =
false;
2493 auto result = constFoldBinaryOp<IntegerAttr>(
2494 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2495 bounded = b.ult(b.getBitWidth());
2505 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2510 bool bounded =
false;
2511 auto result = constFoldBinaryOp<IntegerAttr>(
2512 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2513 bounded = b.ult(b.getBitWidth());
2523 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2528 bool bounded =
false;
2529 auto result = constFoldBinaryOp<IntegerAttr>(
2530 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2531 bounded = b.ult(b.getBitWidth());
2544 bool useOnlyFiniteValue) {
2546 case AtomicRMWKind::maximumf: {
2547 const llvm::fltSemantics &semantic =
2548 llvm::cast<FloatType>(resultType).getFloatSemantics();
2549 APFloat identity = useOnlyFiniteValue
2550 ? APFloat::getLargest(semantic,
true)
2551 : APFloat::getInf(semantic,
true);
2554 case AtomicRMWKind::maxnumf: {
2555 const llvm::fltSemantics &semantic =
2556 llvm::cast<FloatType>(resultType).getFloatSemantics();
2557 APFloat identity = APFloat::getNaN(semantic,
true);
2560 case AtomicRMWKind::addf:
2561 case AtomicRMWKind::addi:
2562 case AtomicRMWKind::maxu:
2563 case AtomicRMWKind::ori:
2565 case AtomicRMWKind::andi:
2568 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2569 case AtomicRMWKind::maxs:
2571 resultType, APInt::getSignedMinValue(
2572 llvm::cast<IntegerType>(resultType).getWidth()));
2573 case AtomicRMWKind::minimumf: {
2574 const llvm::fltSemantics &semantic =
2575 llvm::cast<FloatType>(resultType).getFloatSemantics();
2576 APFloat identity = useOnlyFiniteValue
2577 ? APFloat::getLargest(semantic,
false)
2578 : APFloat::getInf(semantic,
false);
2582 case AtomicRMWKind::minnumf: {
2583 const llvm::fltSemantics &semantic =
2584 llvm::cast<FloatType>(resultType).getFloatSemantics();
2585 APFloat identity = APFloat::getNaN(semantic,
false);
2588 case AtomicRMWKind::mins:
2590 resultType, APInt::getSignedMaxValue(
2591 llvm::cast<IntegerType>(resultType).getWidth()));
2592 case AtomicRMWKind::minu:
2595 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2596 case AtomicRMWKind::muli:
2598 case AtomicRMWKind::mulf:
2610 std::optional<AtomicRMWKind> maybeKind =
2613 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2614 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2615 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2616 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2617 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2618 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2620 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2621 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2622 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2623 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2624 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2625 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2626 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2627 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2628 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2629 .Default([](
Operation *op) {
return std::nullopt; });
2631 return std::nullopt;
2634 bool useOnlyFiniteValue =
false;
2635 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2636 if (fmfOpInterface) {
2637 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2638 useOnlyFiniteValue =
2639 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2647 useOnlyFiniteValue);
2653 bool useOnlyFiniteValue) {
2656 return builder.
create<arith::ConstantOp>(loc, attr);
2664 case AtomicRMWKind::addf:
2665 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2666 case AtomicRMWKind::addi:
2667 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2668 case AtomicRMWKind::mulf:
2669 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2670 case AtomicRMWKind::muli:
2671 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2672 case AtomicRMWKind::maximumf:
2673 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2674 case AtomicRMWKind::minimumf:
2675 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2676 case AtomicRMWKind::maxnumf:
2677 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2678 case AtomicRMWKind::minnumf:
2679 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2680 case AtomicRMWKind::maxs:
2681 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2682 case AtomicRMWKind::mins:
2683 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2684 case AtomicRMWKind::maxu:
2685 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2686 case AtomicRMWKind::minu:
2687 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2688 case AtomicRMWKind::ori:
2689 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2690 case AtomicRMWKind::andi:
2691 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2704 #define GET_OP_CLASSES
2705 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2711 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1176::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)