25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/APSInt.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
42 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
43 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
44 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
45 APInt value = binFn(lhsVal, rhsVal);
67 case arith::CmpIPredicate::eq:
68 return arith::CmpIPredicate::ne;
69 case arith::CmpIPredicate::ne:
70 return arith::CmpIPredicate::eq;
71 case arith::CmpIPredicate::slt:
72 return arith::CmpIPredicate::sge;
73 case arith::CmpIPredicate::sle:
74 return arith::CmpIPredicate::sgt;
75 case arith::CmpIPredicate::sgt:
76 return arith::CmpIPredicate::sle;
77 case arith::CmpIPredicate::sge:
78 return arith::CmpIPredicate::slt;
79 case arith::CmpIPredicate::ult:
80 return arith::CmpIPredicate::uge;
81 case arith::CmpIPredicate::ule:
82 return arith::CmpIPredicate::ugt;
83 case arith::CmpIPredicate::ugt:
84 return arith::CmpIPredicate::ule;
85 case arith::CmpIPredicate::uge:
86 return arith::CmpIPredicate::ult;
88 llvm_unreachable(
"unknown cmpi predicate kind");
118 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
129 #include "ArithCanonicalization.inc"
139 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
140 return shapedType.cloneWith(std::nullopt, i1Type);
141 if (llvm::isa<UnrankedTensorType>(type))
150 void arith::ConstantOp::getAsmResultNames(
152 auto type = getType();
153 if (
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
154 auto intType = llvm::dyn_cast<IntegerType>(type);
157 if (intType && intType.getWidth() == 1)
158 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
162 llvm::raw_svector_ostream specialName(specialNameBuffer);
163 specialName <<
'c' << intCst.getValue();
165 specialName <<
'_' << type;
166 setNameFn(getResult(), specialName.str());
168 setNameFn(getResult(),
"cst");
175 auto type = getType();
177 if (getValue().getType() != type) {
178 return emitOpError() <<
"value type " << getValue().getType()
179 <<
" must match return type: " << type;
182 if (llvm::isa<IntegerType>(type) &&
183 !llvm::cast<IntegerType>(type).isSignless())
184 return emitOpError(
"integer return type must be signless");
186 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
188 "value must be an integer, float, or elements attribute");
193 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
195 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
196 if (!typedAttr || typedAttr.getType() != type)
199 if (llvm::isa<IntegerType>(type) &&
200 !llvm::cast<IntegerType>(type).isSignless())
203 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
208 if (isBuildableWith(value, type))
209 return builder.
create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
213 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
216 int64_t value,
unsigned width) {
218 arith::ConstantOp::build(builder, result, type,
223 int64_t value,
Type type) {
225 "ConstantIntOp can only have signless integer type values");
226 arith::ConstantOp::build(builder, result, type,
231 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
232 return constOp.getType().isSignlessInteger();
238 arith::ConstantOp::build(builder, result, type,
243 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
244 return llvm::isa<FloatType>(constOp.getType());
250 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
255 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
256 return constOp.getType().isIndex();
270 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
271 if (getRhs() == sub.getRhs())
275 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
276 if (getLhs() == sub.getRhs())
279 return constFoldBinaryOp<IntegerAttr>(
280 adaptor.getOperands(),
281 [](APInt a,
const APInt &b) { return std::move(a) + b; });
286 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
287 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
294 std::optional<SmallVector<int64_t, 4>>
295 arith::AddUIExtendedOp::getShapeForUnroll() {
296 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
297 return llvm::to_vector<4>(vt.getShape());
304 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
308 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
310 Type overflowTy = getOverflow().getType();
316 results.push_back(getLhs());
317 results.push_back(falseValue);
325 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
326 adaptor.getOperands(),
327 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
328 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
329 ArrayRef({sumAttr, adaptor.getLhs()}),
335 results.push_back(sumAttr);
336 results.push_back(overflowAttr);
343 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
345 patterns.
add<AddUIExtendedToAddI>(context);
354 if (getOperand(0) == getOperand(1))
360 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
362 if (getRhs() == add.getRhs())
365 if (getRhs() == add.getLhs())
369 return constFoldBinaryOp<IntegerAttr>(
370 adaptor.getOperands(),
371 [](APInt a,
const APInt &b) { return std::move(a) - b; });
376 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
377 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
378 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
395 return constFoldBinaryOp<IntegerAttr>(
396 adaptor.getOperands(),
397 [](
const APInt &a,
const APInt &b) { return a * b; });
402 patterns.
add<MulIMulIConstant>(context);
409 std::optional<SmallVector<int64_t, 4>>
410 arith::MulSIExtendedOp::getShapeForUnroll() {
411 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
412 return llvm::to_vector<4>(vt.getShape());
417 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
422 results.push_back(zero);
423 results.push_back(zero);
428 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
429 adaptor.getOperands(),
430 [](
const APInt &a,
const APInt &b) { return a * b; })) {
432 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
433 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
434 unsigned bitWidth = a.getBitWidth();
435 APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
436 return fullProduct.extractBits(bitWidth, bitWidth);
438 assert(highAttr &&
"Unexpected constant-folding failure");
440 results.push_back(lowAttr);
441 results.push_back(highAttr);
448 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
450 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
457 std::optional<SmallVector<int64_t, 4>>
458 arith::MulUIExtendedOp::getShapeForUnroll() {
459 if (
auto vt = llvm::dyn_cast<VectorType>(getType(0)))
460 return llvm::to_vector<4>(vt.getShape());
465 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
470 results.push_back(zero);
471 results.push_back(zero);
479 results.push_back(getLhs());
480 results.push_back(zero);
485 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
486 adaptor.getOperands(),
487 [](
const APInt &a,
const APInt &b) { return a * b; })) {
489 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
490 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
491 unsigned bitWidth = a.getBitWidth();
492 APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
493 return fullProduct.extractBits(bitWidth, bitWidth);
495 assert(highAttr &&
"Unexpected constant-folding failure");
497 results.push_back(lowAttr);
498 results.push_back(highAttr);
505 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
507 patterns.
add<MulUIExtendedToMulI>(context);
514 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
521 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
522 [&](APInt a,
const APInt &b) {
543 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
549 bool overflowOrDiv0 =
false;
550 auto result = constFoldBinaryOp<IntegerAttr>(
551 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
552 if (overflowOrDiv0 || !b) {
553 overflowOrDiv0 = true;
556 return a.sdiv_ov(b, overflowOrDiv0);
559 return overflowOrDiv0 ?
Attribute() : result;
563 bool mayHaveUB =
true;
569 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
581 APInt one(a.getBitWidth(), 1,
true);
582 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
583 return val.sadd_ov(one, overflow);
590 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
595 bool overflowOrDiv0 =
false;
596 auto result = constFoldBinaryOp<IntegerAttr>(
597 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
598 if (overflowOrDiv0 || !b) {
599 overflowOrDiv0 = true;
602 APInt quotient = a.udiv(b);
605 APInt one(a.getBitWidth(), 1,
true);
606 return quotient.uadd_ov(one, overflowOrDiv0);
609 return overflowOrDiv0 ?
Attribute() : result;
622 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
628 bool overflowOrDiv0 =
false;
629 auto result = constFoldBinaryOp<IntegerAttr>(
630 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
631 if (overflowOrDiv0 || !b) {
632 overflowOrDiv0 = true;
638 unsigned bits = a.getBitWidth();
640 bool aGtZero = a.sgt(zero);
641 bool bGtZero = b.sgt(zero);
642 if (aGtZero && bGtZero) {
646 if (!aGtZero && !bGtZero) {
648 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
649 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
652 if (!aGtZero && bGtZero) {
654 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
655 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
656 return zero.ssub_ov(div, overflowOrDiv0);
659 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
660 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
661 return zero.ssub_ov(div, overflowOrDiv0);
664 return overflowOrDiv0 ?
Attribute() : result;
668 bool mayHaveUB =
true;
674 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
683 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
689 bool overflowOrDiv0 =
false;
690 auto result = constFoldBinaryOp<IntegerAttr>(
691 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
692 if (overflowOrDiv0 || !b) {
693 overflowOrDiv0 = true;
699 unsigned bits = a.getBitWidth();
701 bool aGtZero = a.sgt(zero);
702 bool bGtZero = b.sgt(zero);
703 if (aGtZero && bGtZero) {
705 return a.sdiv_ov(b, overflowOrDiv0);
707 if (!aGtZero && !bGtZero) {
709 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
710 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
711 return posA.sdiv_ov(posB, overflowOrDiv0);
713 if (!aGtZero && bGtZero) {
715 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
717 return zero.ssub_ov(
ceil, overflowOrDiv0);
720 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
722 return zero.ssub_ov(
ceil, overflowOrDiv0);
725 return overflowOrDiv0 ?
Attribute() : result;
732 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
739 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
740 [&](APInt a,
const APInt &b) {
741 if (div0 || b.isZero()) {
755 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
762 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
763 [&](APInt a,
const APInt &b) {
764 if (div0 || b.isZero()) {
780 for (
bool reversePrev : {
false,
true}) {
781 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
782 .getDefiningOp<arith::AndIOp>();
786 Value other = (reversePrev ? op.getLhs() : op.getRhs());
787 if (other != prev.getLhs() && other != prev.getRhs())
790 return prev.getResult();
802 intValue.isAllOnes())
807 intValue.isAllOnes())
812 intValue.isAllOnes())
819 return constFoldBinaryOp<IntegerAttr>(
820 adaptor.getOperands(),
821 [](APInt a,
const APInt &b) { return std::move(a) & b; });
834 if (rhsVal.isAllOnes())
835 return adaptor.getRhs();
842 intValue.isAllOnes())
843 return getRhs().getDefiningOp<XOrIOp>().getRhs();
847 intValue.isAllOnes())
848 return getLhs().getDefiningOp<XOrIOp>().getRhs();
850 return constFoldBinaryOp<IntegerAttr>(
851 adaptor.getOperands(),
852 [](APInt a,
const APInt &b) { return std::move(a) | b; });
864 if (getLhs() == getRhs())
868 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
869 if (prev.getRhs() == getRhs())
870 return prev.getLhs();
871 if (prev.getLhs() == getRhs())
872 return prev.getRhs();
876 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
877 if (prev.getRhs() == getLhs())
878 return prev.getLhs();
879 if (prev.getLhs() == getLhs())
880 return prev.getRhs();
883 return constFoldBinaryOp<IntegerAttr>(
884 adaptor.getOperands(),
885 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
890 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
899 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
901 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
902 [](
const APFloat &a) { return -a; });
914 return constFoldBinaryOp<FloatAttr>(
915 adaptor.getOperands(),
916 [](
const APFloat &a,
const APFloat &b) { return a + b; });
928 return constFoldBinaryOp<FloatAttr>(
929 adaptor.getOperands(),
930 [](
const APFloat &a,
const APFloat &b) { return a - b; });
937 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
939 if (getLhs() == getRhs())
946 return constFoldBinaryOp<FloatAttr>(
947 adaptor.getOperands(),
948 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
955 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
957 if (getLhs() == getRhs())
964 return constFoldBinaryOp<FloatAttr>(
965 adaptor.getOperands(),
966 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
976 if (getLhs() == getRhs())
982 if (intValue.isMaxSignedValue())
985 if (intValue.isMinSignedValue())
989 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
990 [](
const APInt &a,
const APInt &b) {
991 return llvm::APIntOps::smax(a, b);
1001 if (getLhs() == getRhs())
1007 if (intValue.isMaxValue())
1010 if (intValue.isMinValue())
1014 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1015 [](
const APInt &a,
const APInt &b) {
1016 return llvm::APIntOps::umax(a, b);
1024 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1026 if (getLhs() == getRhs())
1033 return constFoldBinaryOp<FloatAttr>(
1034 adaptor.getOperands(),
1035 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1042 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1044 if (getLhs() == getRhs())
1051 return constFoldBinaryOp<FloatAttr>(
1052 adaptor.getOperands(),
1053 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1062 if (getLhs() == getRhs())
1068 if (intValue.isMinSignedValue())
1071 if (intValue.isMaxSignedValue())
1075 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1076 [](
const APInt &a,
const APInt &b) {
1077 return llvm::APIntOps::smin(a, b);
1087 if (getLhs() == getRhs())
1093 if (intValue.isMinValue())
1096 if (intValue.isMaxValue())
1100 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1101 [](
const APInt &a,
const APInt &b) {
1102 return llvm::APIntOps::umin(a, b);
1110 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1115 return constFoldBinaryOp<FloatAttr>(
1116 adaptor.getOperands(),
1117 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1122 patterns.
add<MulFOfNegF>(context);
1129 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1134 return constFoldBinaryOp<FloatAttr>(
1135 adaptor.getOperands(),
1136 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1141 patterns.
add<DivFOfNegF>(context);
1148 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1149 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1150 [](
const APFloat &a,
const APFloat &b) {
1152 (void)result.remainder(b);
1161 template <
typename... Types>
1167 template <
typename... ShapedTypes,
typename... ElementTypes>
1170 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1174 if (!llvm::isa<ElementTypes...>(underlyingType))
1177 return underlyingType;
1181 template <
typename... ElementTypes>
1188 template <
typename... ElementTypes>
1197 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1198 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1199 if (!rankedTensorA || !rankedTensorB)
1201 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1205 if (inputs.size() != 1 || outputs.size() != 1)
1217 template <
typename ValType,
typename Op>
1222 if (llvm::cast<ValType>(srcType).getWidth() >=
1223 llvm::cast<ValType>(dstType).getWidth())
1225 << dstType <<
" must be wider than operand type " << srcType;
1231 template <
typename ValType,
typename Op>
1236 if (llvm::cast<ValType>(srcType).getWidth() <=
1237 llvm::cast<ValType>(dstType).getWidth())
1239 << dstType <<
" must be shorter than operand type " << srcType;
1245 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1250 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1251 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1252 if (!srcType || !dstType)
1255 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1256 srcType.getIntOrFloatBitWidth());
1263 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1264 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1265 getInMutable().assign(lhs.getIn());
1270 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1271 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1272 adaptor.getOperands(), getType(),
1273 [bitWidth](
const APInt &a,
bool &castStatus) {
1274 return a.zext(bitWidth);
1279 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1283 return verifyExtOp<IntegerType>(*
this);
1290 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1291 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1292 getInMutable().assign(lhs.getIn());
1297 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1298 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1299 adaptor.getOperands(), getType(),
1300 [bitWidth](
const APInt &a,
bool &castStatus) {
1301 return a.sext(bitWidth);
1306 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1311 patterns.
add<ExtSIOfExtUI>(context);
1315 return verifyExtOp<IntegerType>(*
this);
1323 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1324 auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
1329 return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
1333 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1342 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1343 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1350 if (llvm::cast<IntegerType>(srcType).getWidth() >
1351 llvm::cast<IntegerType>(dstType).getWidth()) {
1361 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1362 setOperand(getOperand().getDefiningOp()->getOperand(0));
1367 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1368 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1369 adaptor.getOperands(), getType(),
1370 [bitWidth](
const APInt &a,
bool &castStatus) {
1371 return a.trunc(bitWidth);
1376 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1381 patterns.
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1382 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1387 return verifyTruncateOp<IntegerType>(*
this);
1396 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1397 auto constOperand = adaptor.getIn();
1398 if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
1402 double sourceValue =
1403 llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
1407 if (sourceValue == targetAttr.getValue().convertToDouble())
1414 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1418 return verifyTruncateOp<FloatType>(*
this);
1427 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1436 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1443 template <
typename From,
typename To>
1448 auto srcType = getTypeIfLike<From>(inputs.front());
1449 auto dstType = getTypeIfLike<To>(outputs.back());
1451 return srcType && dstType;
1459 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1462 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1464 return constFoldCastOp<IntegerAttr, FloatAttr>(
1465 adaptor.getOperands(), getType(),
1466 [&resEleType](
const APInt &a,
bool &castStatus) {
1467 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1470 apf.convertFromAPInt(a,
false,
1471 APFloat::rmNearestTiesToEven);
1481 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1484 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1486 return constFoldCastOp<IntegerAttr, FloatAttr>(
1487 adaptor.getOperands(), getType(),
1488 [&resEleType](
const APInt &a,
bool &castStatus) {
1489 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1492 apf.convertFromAPInt(a,
true,
1493 APFloat::rmNearestTiesToEven);
1502 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1505 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1507 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1508 return constFoldCastOp<FloatAttr, IntegerAttr>(
1509 adaptor.getOperands(), getType(),
1510 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1512 APSInt api(bitWidth,
true);
1513 castStatus = APFloat::opInvalidOp !=
1514 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1524 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1527 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1529 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1530 return constFoldCastOp<FloatAttr, IntegerAttr>(
1531 adaptor.getOperands(), getType(),
1532 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1534 APSInt api(bitWidth,
false);
1535 castStatus = APFloat::opInvalidOp !=
1536 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1549 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1550 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1551 if (!srcType || !dstType)
1558 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1563 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1565 unsigned resultBitwidth = 64;
1567 resultBitwidth = intTy.getWidth();
1569 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1570 adaptor.getOperands(), getType(),
1571 [resultBitwidth](
const APInt &a,
bool & ) {
1572 return a.sextOrTrunc(resultBitwidth);
1576 void arith::IndexCastOp::getCanonicalizationPatterns(
1578 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1585 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1590 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1592 unsigned resultBitwidth = 64;
1594 resultBitwidth = intTy.getWidth();
1596 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1597 adaptor.getOperands(), getType(),
1598 [resultBitwidth](
const APInt &a,
bool & ) {
1599 return a.zextOrTrunc(resultBitwidth);
1603 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1605 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1617 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1619 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1620 if (!srcType || !dstType)
1626 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1627 auto resType = getType();
1628 auto operand = adaptor.getIn();
1633 if (
auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1634 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1636 if (llvm::isa<ShapedType>(resType))
1640 APInt bits = llvm::isa<FloatAttr>(operand)
1641 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1642 : llvm::cast<IntegerAttr>(operand).getValue();
1644 if (
auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1646 APFloat(resFloatType.getFloatSemantics(), bits));
1652 patterns.
add<BitcastOfBitcast>(context);
1662 const APInt &lhs,
const APInt &rhs) {
1663 switch (predicate) {
1664 case arith::CmpIPredicate::eq:
1666 case arith::CmpIPredicate::ne:
1668 case arith::CmpIPredicate::slt:
1669 return lhs.slt(rhs);
1670 case arith::CmpIPredicate::sle:
1671 return lhs.sle(rhs);
1672 case arith::CmpIPredicate::sgt:
1673 return lhs.sgt(rhs);
1674 case arith::CmpIPredicate::sge:
1675 return lhs.sge(rhs);
1676 case arith::CmpIPredicate::ult:
1677 return lhs.ult(rhs);
1678 case arith::CmpIPredicate::ule:
1679 return lhs.ule(rhs);
1680 case arith::CmpIPredicate::ugt:
1681 return lhs.ugt(rhs);
1682 case arith::CmpIPredicate::uge:
1683 return lhs.uge(rhs);
1685 llvm_unreachable(
"unknown cmpi predicate kind");
1690 switch (predicate) {
1691 case arith::CmpIPredicate::eq:
1692 case arith::CmpIPredicate::sle:
1693 case arith::CmpIPredicate::sge:
1694 case arith::CmpIPredicate::ule:
1695 case arith::CmpIPredicate::uge:
1697 case arith::CmpIPredicate::ne:
1698 case arith::CmpIPredicate::slt:
1699 case arith::CmpIPredicate::sgt:
1700 case arith::CmpIPredicate::ult:
1701 case arith::CmpIPredicate::ugt:
1704 llvm_unreachable(
"unknown cmpi predicate kind");
1708 if (
auto intType = llvm::dyn_cast<IntegerType>(t)) {
1709 return intType.getWidth();
1711 if (
auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1712 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1714 return std::nullopt;
1717 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1719 if (getLhs() == getRhs()) {
1725 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1727 std::optional<int64_t> integerWidth =
1729 if (integerWidth && integerWidth.value() == 1 &&
1730 getPredicate() == arith::CmpIPredicate::ne)
1731 return extOp.getOperand();
1733 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1735 std::optional<int64_t> integerWidth =
1737 if (integerWidth && integerWidth.value() == 1 &&
1738 getPredicate() == arith::CmpIPredicate::ne)
1739 return extOp.getOperand();
1744 if (adaptor.getLhs() && !adaptor.getRhs()) {
1746 using Pred = CmpIPredicate;
1747 const std::pair<Pred, Pred> invPreds[] = {
1748 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1749 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1750 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1751 {Pred::ne, Pred::ne},
1753 Pred origPred = getPredicate();
1754 for (
auto pred : invPreds) {
1755 if (origPred == pred.first) {
1756 setPredicate(pred.second);
1757 Value lhs = getLhs();
1758 Value rhs = getRhs();
1759 getLhsMutable().assign(rhs);
1760 getRhsMutable().assign(lhs);
1764 llvm_unreachable(
"unknown cmpi predicate kind");
1769 if (
auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1770 return constFoldBinaryOp<IntegerAttr>(
1772 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
1783 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1793 const APFloat &lhs,
const APFloat &rhs) {
1794 auto cmpResult = lhs.compare(rhs);
1795 switch (predicate) {
1796 case arith::CmpFPredicate::AlwaysFalse:
1798 case arith::CmpFPredicate::OEQ:
1799 return cmpResult == APFloat::cmpEqual;
1800 case arith::CmpFPredicate::OGT:
1801 return cmpResult == APFloat::cmpGreaterThan;
1802 case arith::CmpFPredicate::OGE:
1803 return cmpResult == APFloat::cmpGreaterThan ||
1804 cmpResult == APFloat::cmpEqual;
1805 case arith::CmpFPredicate::OLT:
1806 return cmpResult == APFloat::cmpLessThan;
1807 case arith::CmpFPredicate::OLE:
1808 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1809 case arith::CmpFPredicate::ONE:
1810 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1811 case arith::CmpFPredicate::ORD:
1812 return cmpResult != APFloat::cmpUnordered;
1813 case arith::CmpFPredicate::UEQ:
1814 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1815 case arith::CmpFPredicate::UGT:
1816 return cmpResult == APFloat::cmpUnordered ||
1817 cmpResult == APFloat::cmpGreaterThan;
1818 case arith::CmpFPredicate::UGE:
1819 return cmpResult == APFloat::cmpUnordered ||
1820 cmpResult == APFloat::cmpGreaterThan ||
1821 cmpResult == APFloat::cmpEqual;
1822 case arith::CmpFPredicate::ULT:
1823 return cmpResult == APFloat::cmpUnordered ||
1824 cmpResult == APFloat::cmpLessThan;
1825 case arith::CmpFPredicate::ULE:
1826 return cmpResult == APFloat::cmpUnordered ||
1827 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1828 case arith::CmpFPredicate::UNE:
1829 return cmpResult != APFloat::cmpEqual;
1830 case arith::CmpFPredicate::UNO:
1831 return cmpResult == APFloat::cmpUnordered;
1832 case arith::CmpFPredicate::AlwaysTrue:
1835 llvm_unreachable(
"unknown cmpf predicate kind");
1838 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1839 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1840 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1843 if (lhs && lhs.getValue().isNaN())
1845 if (rhs && rhs.getValue().isNaN())
1861 using namespace arith;
1863 case CmpFPredicate::UEQ:
1864 case CmpFPredicate::OEQ:
1865 return CmpIPredicate::eq;
1866 case CmpFPredicate::UGT:
1867 case CmpFPredicate::OGT:
1868 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1869 case CmpFPredicate::UGE:
1870 case CmpFPredicate::OGE:
1871 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1872 case CmpFPredicate::ULT:
1873 case CmpFPredicate::OLT:
1874 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1875 case CmpFPredicate::ULE:
1876 case CmpFPredicate::OLE:
1877 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1878 case CmpFPredicate::UNE:
1879 case CmpFPredicate::ONE:
1880 return CmpIPredicate::ne;
1882 llvm_unreachable(
"Unexpected predicate!");
1892 const APFloat &rhs = flt.getValue();
1900 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1902 if (mantissaWidth <= 0)
1908 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1910 intVal = si.getIn();
1911 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1913 intVal = ui.getIn();
1920 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
1921 auto intWidth = intTy.getWidth();
1924 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1929 if ((
int)intWidth > mantissaWidth) {
1931 int exponent = ilogb(rhs);
1932 if (exponent == APFloat::IEK_Inf) {
1933 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1934 if (maxExponent < (
int)valueBits) {
1941 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
1950 switch (op.getPredicate()) {
1951 case CmpFPredicate::ORD:
1956 case CmpFPredicate::UNO:
1969 APFloat signedMax(rhs.getSemantics());
1970 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
1971 APFloat::rmNearestTiesToEven);
1972 if (signedMax < rhs) {
1973 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1974 pred == CmpIPredicate::sle)
1985 APFloat unsignedMax(rhs.getSemantics());
1986 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
1987 APFloat::rmNearestTiesToEven);
1988 if (unsignedMax < rhs) {
1989 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1990 pred == CmpIPredicate::ule)
2002 APFloat signedMin(rhs.getSemantics());
2003 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2004 APFloat::rmNearestTiesToEven);
2005 if (signedMin > rhs) {
2006 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2007 pred == CmpIPredicate::sge)
2017 APFloat unsignedMin(rhs.getSemantics());
2018 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2019 APFloat::rmNearestTiesToEven);
2020 if (unsignedMin > rhs) {
2021 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2022 pred == CmpIPredicate::uge)
2037 APSInt rhsInt(intWidth, isUnsigned);
2038 if (APFloat::opInvalidOp ==
2039 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2045 if (!rhs.isZero()) {
2048 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2050 bool equal = apf == rhs;
2056 case CmpIPredicate::ne:
2060 case CmpIPredicate::eq:
2064 case CmpIPredicate::ule:
2067 if (rhs.isNegative()) {
2073 case CmpIPredicate::sle:
2076 if (rhs.isNegative())
2077 pred = CmpIPredicate::slt;
2079 case CmpIPredicate::ult:
2082 if (rhs.isNegative()) {
2087 pred = CmpIPredicate::ule;
2089 case CmpIPredicate::slt:
2092 if (!rhs.isNegative())
2093 pred = CmpIPredicate::sle;
2095 case CmpIPredicate::ugt:
2098 if (rhs.isNegative()) {
2104 case CmpIPredicate::sgt:
2107 if (rhs.isNegative())
2108 pred = CmpIPredicate::sge;
2110 case CmpIPredicate::uge:
2113 if (rhs.isNegative()) {
2118 pred = CmpIPredicate::ugt;
2120 case CmpIPredicate::sge:
2123 if (!rhs.isNegative())
2124 pred = CmpIPredicate::sgt;
2134 rewriter.
create<ConstantOp>(
2162 if (!op.getType().isInteger(1))
2165 Value falseConstant =
2166 rewriter.
create<arith::ConstantIntOp>(op.
getLoc(),
true, 1);
2167 Value notCondition = rewriter.
create<arith::XOrIOp>(
2168 op.
getLoc(), op.getCondition(), falseConstant);
2171 op.
getLoc(), op.getCondition(), op.getTrueValue());
2173 op.getFalseValue());
2186 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2202 rewriter.
create<arith::XOrIOp>(
2203 op.
getLoc(), op.getCondition(),
2204 rewriter.
create<arith::ConstantIntOp>(
2205 op.
getLoc(), 1, op.getCondition().getType())));
2216 SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
2220 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2221 Value trueVal = getTrueValue();
2222 Value falseVal = getFalseValue();
2223 if (trueVal == falseVal)
2226 Value condition = getCondition();
2237 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2240 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2248 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2249 auto pred = cmp.getPredicate();
2250 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2251 auto cmpLhs = cmp.getLhs();
2252 auto cmpRhs = cmp.getRhs();
2260 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2261 (cmpRhs == trueVal && cmpLhs == falseVal))
2262 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2269 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2271 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2273 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2275 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2276 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2278 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2280 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2283 for (
auto [condVal, lhsVal, rhsVal] :
2284 llvm::zip_equal(condVals, lhsVals, rhsVals))
2285 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2296 Type conditionType, resultType;
2305 conditionType = resultType;
2314 {conditionType, resultType, resultType},
2319 p <<
" " << getOperands();
2322 if (ShapedType condType =
2323 llvm::dyn_cast<ShapedType>(getCondition().getType()))
2324 p << condType <<
", ";
2329 Type conditionType = getCondition().getType();
2335 Type resultType = getType();
2336 if (!llvm::isa<TensorType, VectorType>(resultType))
2337 return emitOpError() <<
"expected condition to be a signless i1, but got "
2340 if (conditionType != shapedConditionType) {
2341 return emitOpError() <<
"expected condition type to have the same shape "
2342 "as the result type, expected "
2343 << shapedConditionType <<
", but got "
2352 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2357 bool bounded =
false;
2358 auto result = constFoldBinaryOp<IntegerAttr>(
2359 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2360 bounded = b.ule(b.getBitWidth());
2370 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2375 bool bounded =
false;
2376 auto result = constFoldBinaryOp<IntegerAttr>(
2377 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2378 bounded = b.ule(b.getBitWidth());
2388 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2393 bool bounded =
false;
2394 auto result = constFoldBinaryOp<IntegerAttr>(
2395 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2396 bounded = b.ule(b.getBitWidth());
2409 bool useOnlyFiniteValue) {
2411 case AtomicRMWKind::maximumf: {
2412 const llvm::fltSemantics &semantic =
2413 llvm::cast<FloatType>(resultType).getFloatSemantics();
2414 APFloat identity = useOnlyFiniteValue
2415 ? APFloat::getLargest(semantic,
true)
2416 : APFloat::getInf(semantic,
true);
2419 case AtomicRMWKind::addf:
2420 case AtomicRMWKind::addi:
2421 case AtomicRMWKind::maxu:
2422 case AtomicRMWKind::ori:
2424 case AtomicRMWKind::andi:
2427 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2428 case AtomicRMWKind::maxs:
2430 resultType, APInt::getSignedMinValue(
2431 llvm::cast<IntegerType>(resultType).getWidth()));
2432 case AtomicRMWKind::minimumf: {
2433 const llvm::fltSemantics &semantic =
2434 llvm::cast<FloatType>(resultType).getFloatSemantics();
2435 APFloat identity = useOnlyFiniteValue
2436 ? APFloat::getLargest(semantic,
false)
2437 : APFloat::getInf(semantic,
false);
2441 case AtomicRMWKind::mins:
2443 resultType, APInt::getSignedMaxValue(
2444 llvm::cast<IntegerType>(resultType).getWidth()));
2445 case AtomicRMWKind::minu:
2448 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2449 case AtomicRMWKind::muli:
2451 case AtomicRMWKind::mulf:
2463 std::optional<AtomicRMWKind> maybeKind =
2466 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2467 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2468 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2469 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2471 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2472 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2473 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::ori; })
2474 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2475 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2476 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2477 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2478 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2479 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2480 .Default([](
Operation *op) {
return std::nullopt; });
2482 op->
emitError() <<
"Unknown neutral element for: " << *op;
2483 return std::nullopt;
2486 bool useOnlyFiniteValue =
false;
2487 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2488 if (fmfOpInterface) {
2489 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2490 useOnlyFiniteValue =
2491 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2499 useOnlyFiniteValue);
2505 bool useOnlyFiniteValue) {
2508 return builder.
create<arith::ConstantOp>(loc, attr);
2516 case AtomicRMWKind::addf:
2517 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2518 case AtomicRMWKind::addi:
2519 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2520 case AtomicRMWKind::mulf:
2521 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2522 case AtomicRMWKind::muli:
2523 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2524 case AtomicRMWKind::maximumf:
2525 return builder.
create<arith::MaximumFOp>(loc, lhs, rhs);
2526 case AtomicRMWKind::minimumf:
2527 return builder.
create<arith::MinimumFOp>(loc, lhs, rhs);
2528 case AtomicRMWKind::maxnumf:
2529 return builder.
create<arith::MaxNumFOp>(loc, lhs, rhs);
2530 case AtomicRMWKind::minnumf:
2531 return builder.
create<arith::MinNumFOp>(loc, lhs, rhs);
2532 case AtomicRMWKind::maxs:
2533 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2534 case AtomicRMWKind::mins:
2535 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2536 case AtomicRMWKind::maxu:
2537 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2538 case AtomicRMWKind::minu:
2539 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2540 case AtomicRMWKind::ori:
2541 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2542 case AtomicRMWKind::andi:
2543 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2556 #define GET_OP_CLASSES
2557 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2563 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
MPInt ceil(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)