21 #include "llvm/ADT/APSInt.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/TypeSwitch.h"
35 lhs.
cast<IntegerAttr>().getInt() +
36 rhs.
cast<IntegerAttr>().getInt());
42 lhs.
cast<IntegerAttr>().getInt() -
43 rhs.
cast<IntegerAttr>().getInt());
49 case arith::CmpIPredicate::eq:
50 return arith::CmpIPredicate::ne;
51 case arith::CmpIPredicate::ne:
52 return arith::CmpIPredicate::eq;
53 case arith::CmpIPredicate::slt:
54 return arith::CmpIPredicate::sge;
55 case arith::CmpIPredicate::sle:
56 return arith::CmpIPredicate::sgt;
57 case arith::CmpIPredicate::sgt:
58 return arith::CmpIPredicate::sle;
59 case arith::CmpIPredicate::sge:
60 return arith::CmpIPredicate::slt;
61 case arith::CmpIPredicate::ult:
62 return arith::CmpIPredicate::uge;
63 case arith::CmpIPredicate::ule:
64 return arith::CmpIPredicate::ugt;
65 case arith::CmpIPredicate::ugt:
66 return arith::CmpIPredicate::ule;
67 case arith::CmpIPredicate::uge:
68 return arith::CmpIPredicate::ult;
70 llvm_unreachable(
"unknown cmpi predicate kind");
74 return arith::CmpIPredicateAttr::get(pred.getContext(),
91 if (
auto intAttr = attr.
dyn_cast<IntegerAttr>())
92 return intAttr.getValue();
95 if (splatAttr.getElementType().isa<IntegerType>())
96 return splatAttr.getSplatValue<APInt>();
106 #include "ArithCanonicalization.inc"
113 void arith::ConstantOp::getAsmResultNames(
115 auto type = getType();
116 if (
auto intCst = getValue().dyn_cast<IntegerAttr>()) {
117 auto intType = type.dyn_cast<IntegerType>();
120 if (intType && intType.getWidth() == 1)
121 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
125 llvm::raw_svector_ostream specialName(specialNameBuffer);
126 specialName <<
'c' << intCst.getValue();
128 specialName <<
'_' << type;
129 setNameFn(getResult(), specialName.str());
131 setNameFn(getResult(),
"cst");
138 auto type = getType();
140 if (getValue().getType() != type) {
141 return emitOpError() <<
"value type " << getValue().getType()
142 <<
" must match return type: " << type;
145 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
146 return emitOpError(
"integer return type must be signless");
148 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
150 "value must be an integer, float, or elements attribute");
155 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
157 auto typedAttr = value.
dyn_cast<TypedAttr>();
158 if (!typedAttr || typedAttr.getType() != type)
161 if (type.
isa<IntegerType>() && !type.
cast<IntegerType>().isSignless())
164 return value.
isa<IntegerAttr, FloatAttr, ElementsAttr>();
// Folder for arith.constant: the op folds to its own value attribute, so the
// folded result is exactly getValue(). The fold adaptor (constant operands)
// is not consulted.
167 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
170 int64_t value,
unsigned width) {
172 arith::ConstantOp::build(builder, result, type,
177 int64_t value,
Type type) {
179 "ConstantIntOp can only have signless integer type values");
180 arith::ConstantOp::build(builder, result, type,
185 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
186 return constOp.getType().isSignlessInteger();
192 arith::ConstantOp::build(builder, result, type,
197 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
198 return constOp.getType().isa<
FloatType>();
204 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
209 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
210 return constOp.getType().isIndex();
224 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
225 if (getRhs() == sub.getRhs())
229 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
230 if (getLhs() == sub.getRhs())
233 return constFoldBinaryOp<IntegerAttr>(
234 adaptor.getOperands(),
235 [](APInt a,
const APInt &b) { return std::move(a) + b; });
240 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
248 std::optional<SmallVector<int64_t, 4>>
249 arith::AddUIExtendedOp::getShapeForUnroll() {
250 if (
auto vt = getType(0).dyn_cast<VectorType>())
251 return llvm::to_vector<4>(vt.getShape());
258 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
262 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
264 Type overflowTy = getOverflow().getType();
268 auto falseValue = builder.getZeroAttr(overflowTy);
270 results.push_back(getLhs());
271 results.push_back(falseValue);
281 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
282 adaptor.getOperands(),
283 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
285 if (
auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
287 auto sum = sumAttr.
cast<IntegerAttr>();
288 overflowAttr = IntegerAttr::get(
295 lhs.getSplatValue<APInt>());
297 }
else if (
auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
299 auto sum = sumAttr.cast<ElementsAttr>();
300 const auto numElems =
static_cast<size_t>(sum.getNumElements());
302 overflowValues.reserve(numElems);
304 auto sumIt = sum.value_begin<APInt>();
305 auto lhsIt = lhs.value_begin<APInt>();
306 for (
size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
314 results.push_back(sumAttr);
315 results.push_back(overflowAttr);
322 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
324 patterns.
add<AddUIExtendedToAddI>(context);
333 if (getOperand(0) == getOperand(1))
339 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
341 if (getRhs() == add.getRhs())
344 if (getRhs() == add.getLhs())
348 return constFoldBinaryOp<IntegerAttr>(
349 adaptor.getOperands(),
350 [](APInt a,
const APInt &b) { return std::move(a) - b; });
355 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
356 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
357 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
370 return getOperand(0);
374 return constFoldBinaryOp<IntegerAttr>(
375 adaptor.getOperands(),
376 [](
const APInt &a,
const APInt &b) { return a * b; });
383 std::optional<SmallVector<int64_t, 4>>
384 arith::MulSIExtendedOp::getShapeForUnroll() {
385 if (
auto vt = getType(0).dyn_cast<VectorType>())
386 return llvm::to_vector<4>(vt.getShape());
391 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
396 results.push_back(zero);
397 results.push_back(zero);
402 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
403 adaptor.getOperands(),
404 [](
const APInt &a,
const APInt &b) { return a * b; })) {
406 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
407 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
408 unsigned bitWidth = a.getBitWidth();
409 APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
410 return fullProduct.extractBits(bitWidth, bitWidth);
412 assert(highAttr &&
"Unexpected constant-folding failure");
414 results.push_back(lowAttr);
415 results.push_back(highAttr);
422 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
424 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
431 std::optional<SmallVector<int64_t, 4>>
432 arith::MulUIExtendedOp::getShapeForUnroll() {
433 if (
auto vt = getType(0).dyn_cast<VectorType>())
434 return llvm::to_vector<4>(vt.getShape());
439 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
444 results.push_back(zero);
445 results.push_back(zero);
452 Attribute zero = builder.getZeroAttr(getLhs().getType());
453 results.push_back(getLhs());
454 results.push_back(zero);
459 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
460 adaptor.getOperands(),
461 [](
const APInt &a,
const APInt &b) { return a * b; })) {
463 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
464 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
465 unsigned bitWidth = a.getBitWidth();
466 APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
467 return fullProduct.extractBits(bitWidth, bitWidth);
469 assert(highAttr &&
"Unexpected constant-folding failure");
471 results.push_back(lowAttr);
472 results.push_back(highAttr);
479 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
481 patterns.
add<MulUIExtendedToMulI>(context);
488 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
495 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
496 [&](APInt a,
const APInt &b) {
517 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
523 bool overflowOrDiv0 =
false;
524 auto result = constFoldBinaryOp<IntegerAttr>(
525 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
526 if (overflowOrDiv0 || !b) {
527 overflowOrDiv0 = true;
530 return a.sdiv_ov(b, overflowOrDiv0);
533 return overflowOrDiv0 ?
Attribute() : result;
537 bool mayHaveUB =
true;
543 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
555 APInt one(a.getBitWidth(), 1,
true);
556 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
557 return val.sadd_ov(one, overflow);
564 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
569 bool overflowOrDiv0 =
false;
570 auto result = constFoldBinaryOp<IntegerAttr>(
571 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
572 if (overflowOrDiv0 || !b) {
573 overflowOrDiv0 = true;
576 APInt quotient = a.udiv(b);
579 APInt one(a.getBitWidth(), 1,
true);
580 return quotient.uadd_ov(one, overflowOrDiv0);
583 return overflowOrDiv0 ?
Attribute() : result;
596 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
602 bool overflowOrDiv0 =
false;
603 auto result = constFoldBinaryOp<IntegerAttr>(
604 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
605 if (overflowOrDiv0 || !b) {
606 overflowOrDiv0 = true;
612 unsigned bits = a.getBitWidth();
614 bool aGtZero = a.sgt(zero);
615 bool bGtZero = b.sgt(zero);
616 if (aGtZero && bGtZero) {
620 if (!aGtZero && !bGtZero) {
622 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
623 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
626 if (!aGtZero && bGtZero) {
628 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
629 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
630 return zero.ssub_ov(div, overflowOrDiv0);
633 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
634 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
635 return zero.ssub_ov(div, overflowOrDiv0);
638 return overflowOrDiv0 ?
Attribute() : result;
642 bool mayHaveUB =
true;
648 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
657 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
663 bool overflowOrDiv0 =
false;
664 auto result = constFoldBinaryOp<IntegerAttr>(
665 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
666 if (overflowOrDiv0 || !b) {
667 overflowOrDiv0 = true;
673 unsigned bits = a.getBitWidth();
675 bool aGtZero = a.sgt(zero);
676 bool bGtZero = b.sgt(zero);
677 if (aGtZero && bGtZero) {
679 return a.sdiv_ov(b, overflowOrDiv0);
681 if (!aGtZero && !bGtZero) {
683 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
684 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
685 return posA.sdiv_ov(posB, overflowOrDiv0);
687 if (!aGtZero && bGtZero) {
689 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
691 return zero.ssub_ov(
ceil, overflowOrDiv0);
694 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
696 return zero.ssub_ov(
ceil, overflowOrDiv0);
699 return overflowOrDiv0 ?
Attribute() : result;
706 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
713 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
714 [&](APInt a,
const APInt &b) {
715 if (div0 || b.isNullValue()) {
729 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
736 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
737 [&](APInt a,
const APInt &b) {
738 if (div0 || b.isNullValue()) {
754 for (
bool reversePrev : {
false,
true}) {
755 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
756 .getDefiningOp<arith::AndIOp>();
760 Value other = (reversePrev ? op.getLhs() : op.getRhs());
761 if (other != prev.getLhs() && other != prev.getRhs())
764 return prev.getResult();
780 intValue.isAllOnes())
785 intValue.isAllOnes())
792 return constFoldBinaryOp<IntegerAttr>(
793 adaptor.getOperands(),
794 [](APInt a,
const APInt &b) { return std::move(a) & b; });
806 if (
auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
807 if (rhsAttr.getValue().isAllOnes())
810 return constFoldBinaryOp<IntegerAttr>(
811 adaptor.getOperands(),
812 [](APInt a,
const APInt &b) { return std::move(a) | b; });
824 if (getLhs() == getRhs())
828 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
829 if (prev.getRhs() == getRhs())
830 return prev.getLhs();
831 if (prev.getLhs() == getRhs())
832 return prev.getRhs();
836 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
837 if (prev.getRhs() == getLhs())
838 return prev.getLhs();
839 if (prev.getLhs() == getLhs())
840 return prev.getRhs();
843 return constFoldBinaryOp<IntegerAttr>(
844 adaptor.getOperands(),
845 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
850 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
859 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
860 return op.getOperand();
861 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
862 [](
const APFloat &a) { return -a; });
874 return constFoldBinaryOp<FloatAttr>(
875 adaptor.getOperands(),
876 [](
const APFloat &a,
const APFloat &b) { return a + b; });
888 return constFoldBinaryOp<FloatAttr>(
889 adaptor.getOperands(),
890 [](
const APFloat &a,
const APFloat &b) { return a - b; });
899 if (getLhs() == getRhs())
906 return constFoldBinaryOp<FloatAttr>(
907 adaptor.getOperands(),
908 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
917 if (getLhs() == getRhs())
923 intValue.isMaxSignedValue())
928 intValue.isMinSignedValue())
931 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
932 [](
const APInt &a,
const APInt &b) {
933 return llvm::APIntOps::smax(a, b);
943 if (getLhs() == getRhs())
955 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
956 [](
const APInt &a,
const APInt &b) {
957 return llvm::APIntOps::umax(a, b);
967 if (getLhs() == getRhs())
974 return constFoldBinaryOp<FloatAttr>(
975 adaptor.getOperands(),
976 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
985 if (getLhs() == getRhs())
991 intValue.isMinSignedValue())
996 intValue.isMaxSignedValue())
999 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1000 [](
const APInt &a,
const APInt &b) {
1001 return llvm::APIntOps::smin(a, b);
1011 if (getLhs() == getRhs())
1023 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1024 [](
const APInt &a,
const APInt &b) {
1025 return llvm::APIntOps::umin(a, b);
1033 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1038 return constFoldBinaryOp<FloatAttr>(
1039 adaptor.getOperands(),
1040 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1045 patterns.
add<MulFOfNegF>(context);
1052 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1057 return constFoldBinaryOp<FloatAttr>(
1058 adaptor.getOperands(),
1059 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1064 patterns.
add<DivFOfNegF>(context);
1071 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1072 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1073 [](
const APFloat &a,
const APFloat &b) {
1075 (void)result.remainder(b);
1084 template <
typename... Types>
1090 template <
typename... ShapedTypes,
typename... ElementTypes>
1093 if (type.
isa<ShapedType>() && !type.
isa<ShapedTypes...>())
1097 if (!underlyingType.isa<ElementTypes...>())
1100 return underlyingType;
1104 template <
typename... ElementTypes>
1111 template <
typename... ElementTypes>
1119 return inputs.size() == 1 && outputs.size() == 1 &&
1128 template <
typename ValType,
typename Op>
1133 if (srcType.
cast<ValType>().getWidth() >= dstType.
cast<ValType>().getWidth())
1135 << dstType <<
" must be wider than operand type " << srcType;
1141 template <
typename ValType,
typename Op>
1146 if (srcType.
cast<ValType>().getWidth() <= dstType.
cast<ValType>().getWidth())
1148 << dstType <<
" must be shorter than operand type " << srcType;
1154 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1159 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1160 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1161 if (!srcType || !dstType)
1164 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1165 srcType.getIntOrFloatBitWidth());
1172 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1173 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1174 getInMutable().assign(lhs.getIn());
1179 unsigned bitWidth = resType.
cast<IntegerType>().getWidth();
1180 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1181 adaptor.getOperands(), getType(),
1182 [bitWidth](
const APInt &a,
bool &castStatus) {
1183 return a.zext(bitWidth);
1188 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1192 return verifyExtOp<IntegerType>(*
this);
1199 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1200 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1201 getInMutable().assign(lhs.getIn());
1206 unsigned bitWidth = resType.
cast<IntegerType>().getWidth();
1207 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1208 adaptor.getOperands(), getType(),
1209 [bitWidth](
const APInt &a,
bool &castStatus) {
1210 return a.sext(bitWidth);
1215 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1220 patterns.
add<ExtSIOfExtUI>(context);
1224 return verifyExtOp<IntegerType>(*
this);
1232 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1241 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1244 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1246 return getOperand().getDefiningOp()->getOperand(0);
1249 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1250 setOperand(getOperand().getDefiningOp()->getOperand(0));
1255 unsigned bitWidth = resType.
cast<IntegerType>().getWidth();
1256 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1257 adaptor.getOperands(), getType(),
1258 [bitWidth](
const APInt &a,
bool &castStatus) {
1259 return a.trunc(bitWidth);
1264 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1269 patterns.
add<TruncIShrSIToTrunciShrUI, TruncIShrUIMulIToMulSIExtended,
1270 TruncIShrUIMulIToMulUIExtended>(context);
1274 return verifyTruncateOp<IntegerType>(*
this);
1283 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1284 auto constOperand = adaptor.getIn();
1285 if (!constOperand || !constOperand.isa<FloatAttr>())
1289 double sourceValue =
1290 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1291 auto targetAttr = FloatAttr::get(getType(), sourceValue);
1294 if (sourceValue == targetAttr.getValue().convertToDouble())
1301 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1305 return verifyTruncateOp<FloatType>(*
this);
1314 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1323 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1330 template <
typename From,
typename To>
1335 auto srcType = getTypeIfLike<From>(inputs.front());
1336 auto dstType = getTypeIfLike<To>(outputs.back());
1338 return srcType && dstType;
1346 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1349 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1351 return constFoldCastOp<IntegerAttr, FloatAttr>(
1352 adaptor.getOperands(), getType(),
1353 [&resEleType](
const APInt &a,
bool &castStatus) {
1357 apf.convertFromAPInt(a,
false,
1358 APFloat::rmNearestTiesToEven);
1368 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1371 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1373 return constFoldCastOp<IntegerAttr, FloatAttr>(
1374 adaptor.getOperands(), getType(),
1375 [&resEleType](
const APInt &a,
bool &castStatus) {
1379 apf.convertFromAPInt(a,
true,
1380 APFloat::rmNearestTiesToEven);
1389 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1392 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1394 unsigned bitWidth = resType.
cast<IntegerType>().getWidth();
1395 return constFoldCastOp<FloatAttr, IntegerAttr>(
1396 adaptor.getOperands(), getType(),
1397 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1399 APSInt api(bitWidth,
true);
1400 castStatus = APFloat::opInvalidOp !=
1401 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1411 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1414 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1416 unsigned bitWidth = resType.
cast<IntegerType>().getWidth();
1417 return constFoldCastOp<FloatAttr, IntegerAttr>(
1418 adaptor.getOperands(), getType(),
1419 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1421 APSInt api(bitWidth,
false);
1422 castStatus = APFloat::opInvalidOp !=
1423 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1436 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1437 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1438 if (!srcType || !dstType)
1441 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1442 (srcType.isSignlessInteger() && dstType.isIndex());
1445 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1450 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1454 if (
auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
1455 return IntegerAttr::get(getType(), value.getInt());
1460 void arith::IndexCastOp::getCanonicalizationPatterns(
1462 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1469 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1474 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1478 if (
auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
1479 return IntegerAttr::get(getType(), value.getValue().getZExtValue());
1484 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1486 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1498 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1500 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1501 if (!srcType || !dstType)
1504 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1507 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1508 auto resType = getType();
1509 auto operand = adaptor.getIn();
1515 return denseAttr.bitcast(resType.
cast<ShapedType>().getElementType());
1517 if (resType.
isa<ShapedType>())
1521 APInt bits = operand.isa<FloatAttr>()
1522 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1523 : operand.cast<IntegerAttr>().getValue();
1526 return FloatAttr::get(resType,
1527 APFloat(resFloatType.getFloatSemantics(), bits));
1528 return IntegerAttr::get(resType, bits);
1533 patterns.
add<BitcastOfBitcast>(context);
1542 auto i1Type = IntegerType::get(type.
getContext(), 1);
1543 if (
auto tensorType = type.
dyn_cast<RankedTensorType>())
1544 return RankedTensorType::get(tensorType.getShape(), i1Type);
1545 if (type.
isa<UnrankedTensorType>())
1546 return UnrankedTensorType::get(i1Type);
1547 if (
auto vectorType = type.
dyn_cast<VectorType>())
1548 return VectorType::get(vectorType.getShape(), i1Type,
1549 vectorType.getNumScalableDims());
1560 const APInt &lhs,
const APInt &rhs) {
1561 switch (predicate) {
1562 case arith::CmpIPredicate::eq:
1564 case arith::CmpIPredicate::ne:
1566 case arith::CmpIPredicate::slt:
1567 return lhs.slt(rhs);
1568 case arith::CmpIPredicate::sle:
1569 return lhs.sle(rhs);
1570 case arith::CmpIPredicate::sgt:
1571 return lhs.sgt(rhs);
1572 case arith::CmpIPredicate::sge:
1573 return lhs.sge(rhs);
1574 case arith::CmpIPredicate::ult:
1575 return lhs.ult(rhs);
1576 case arith::CmpIPredicate::ule:
1577 return lhs.ule(rhs);
1578 case arith::CmpIPredicate::ugt:
1579 return lhs.ugt(rhs);
1580 case arith::CmpIPredicate::uge:
1581 return lhs.uge(rhs);
1583 llvm_unreachable(
"unknown cmpi predicate kind");
1588 switch (predicate) {
1589 case arith::CmpIPredicate::eq:
1590 case arith::CmpIPredicate::sle:
1591 case arith::CmpIPredicate::sge:
1592 case arith::CmpIPredicate::ule:
1593 case arith::CmpIPredicate::uge:
1595 case arith::CmpIPredicate::ne:
1596 case arith::CmpIPredicate::slt:
1597 case arith::CmpIPredicate::sgt:
1598 case arith::CmpIPredicate::ult:
1599 case arith::CmpIPredicate::ugt:
1602 llvm_unreachable(
"unknown cmpi predicate kind");
1606 auto boolAttr = BoolAttr::get(ctx, value);
1610 return DenseElementsAttr::get(shapedType, boolAttr);
1614 if (
auto intType = t.
dyn_cast<IntegerType>()) {
1615 return intType.getWidth();
1617 if (
auto vectorIntType = t.
dyn_cast<VectorType>()) {
1618 return vectorIntType.getElementType().cast<IntegerType>().getWidth();
1620 return std::nullopt;
1623 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1625 if (getLhs() == getRhs()) {
1631 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1633 std::optional<int64_t> integerWidth =
1635 if (integerWidth && integerWidth.value() == 1 &&
1636 getPredicate() == arith::CmpIPredicate::ne)
1637 return extOp.getOperand();
1639 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1641 std::optional<int64_t> integerWidth =
1643 if (integerWidth && integerWidth.value() == 1 &&
1644 getPredicate() == arith::CmpIPredicate::ne)
1645 return extOp.getOperand();
1650 if (adaptor.getLhs() && !adaptor.getRhs()) {
1652 using Pred = CmpIPredicate;
1653 const std::pair<Pred, Pred> invPreds[] = {
1654 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1655 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1656 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1657 {Pred::ne, Pred::ne},
1659 Pred origPred = getPredicate();
1660 for (
auto pred : invPreds) {
1661 if (origPred == pred.first) {
1662 setPredicate(pred.second);
1663 Value lhs = getLhs();
1664 Value rhs = getRhs();
1665 getLhsMutable().assign(rhs);
1666 getRhsMutable().assign(lhs);
1670 llvm_unreachable(
"unknown cmpi predicate kind");
1673 auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
1679 auto rhs = adaptor.getRhs().cast<IntegerAttr>();
1682 return BoolAttr::get(getContext(), val);
1687 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
1697 const APFloat &lhs,
const APFloat &rhs) {
1698 auto cmpResult = lhs.compare(rhs);
1699 switch (predicate) {
1700 case arith::CmpFPredicate::AlwaysFalse:
1702 case arith::CmpFPredicate::OEQ:
1703 return cmpResult == APFloat::cmpEqual;
1704 case arith::CmpFPredicate::OGT:
1705 return cmpResult == APFloat::cmpGreaterThan;
1706 case arith::CmpFPredicate::OGE:
1707 return cmpResult == APFloat::cmpGreaterThan ||
1708 cmpResult == APFloat::cmpEqual;
1709 case arith::CmpFPredicate::OLT:
1710 return cmpResult == APFloat::cmpLessThan;
1711 case arith::CmpFPredicate::OLE:
1712 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1713 case arith::CmpFPredicate::ONE:
1714 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1715 case arith::CmpFPredicate::ORD:
1716 return cmpResult != APFloat::cmpUnordered;
1717 case arith::CmpFPredicate::UEQ:
1718 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1719 case arith::CmpFPredicate::UGT:
1720 return cmpResult == APFloat::cmpUnordered ||
1721 cmpResult == APFloat::cmpGreaterThan;
1722 case arith::CmpFPredicate::UGE:
1723 return cmpResult == APFloat::cmpUnordered ||
1724 cmpResult == APFloat::cmpGreaterThan ||
1725 cmpResult == APFloat::cmpEqual;
1726 case arith::CmpFPredicate::ULT:
1727 return cmpResult == APFloat::cmpUnordered ||
1728 cmpResult == APFloat::cmpLessThan;
1729 case arith::CmpFPredicate::ULE:
1730 return cmpResult == APFloat::cmpUnordered ||
1731 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1732 case arith::CmpFPredicate::UNE:
1733 return cmpResult != APFloat::cmpEqual;
1734 case arith::CmpFPredicate::UNO:
1735 return cmpResult == APFloat::cmpUnordered;
1736 case arith::CmpFPredicate::AlwaysTrue:
1739 llvm_unreachable(
"unknown cmpf predicate kind");
1742 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1743 auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
1744 auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
1747 if (lhs && lhs.getValue().isNaN())
1749 if (rhs && rhs.getValue().isNaN())
1756 return BoolAttr::get(getContext(), val);
1765 using namespace arith;
1767 case CmpFPredicate::UEQ:
1768 case CmpFPredicate::OEQ:
1769 return CmpIPredicate::eq;
1770 case CmpFPredicate::UGT:
1771 case CmpFPredicate::OGT:
1772 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1773 case CmpFPredicate::UGE:
1774 case CmpFPredicate::OGE:
1775 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1776 case CmpFPredicate::ULT:
1777 case CmpFPredicate::OLT:
1778 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1779 case CmpFPredicate::ULE:
1780 case CmpFPredicate::OLE:
1781 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1782 case CmpFPredicate::UNE:
1783 case CmpFPredicate::ONE:
1784 return CmpIPredicate::ne;
1786 llvm_unreachable(
"Unexpected predicate!");
1796 const APFloat &rhs = flt.getValue();
1806 if (mantissaWidth <= 0)
1812 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1814 intVal = si.getIn();
1815 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1817 intVal = ui.getIn();
1824 auto intTy = intVal.
getType().
cast<IntegerType>();
1825 auto intWidth = intTy.getWidth();
1828 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1833 if ((
int)intWidth > mantissaWidth) {
1835 int exponent = ilogb(rhs);
1836 if (exponent == APFloat::IEK_Inf) {
1837 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1838 if (maxExponent < (
int)valueBits) {
1845 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
1854 switch (op.getPredicate()) {
1855 case CmpFPredicate::ORD:
1860 case CmpFPredicate::UNO:
1873 APFloat signedMax(rhs.getSemantics());
1874 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
1875 APFloat::rmNearestTiesToEven);
1876 if (signedMax < rhs) {
1877 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1878 pred == CmpIPredicate::sle)
1889 APFloat unsignedMax(rhs.getSemantics());
1890 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
1891 APFloat::rmNearestTiesToEven);
1892 if (unsignedMax < rhs) {
1893 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1894 pred == CmpIPredicate::ule)
1906 APFloat signedMin(rhs.getSemantics());
1907 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
1908 APFloat::rmNearestTiesToEven);
1909 if (signedMin > rhs) {
1910 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1911 pred == CmpIPredicate::sge)
1921 APFloat unsignedMin(rhs.getSemantics());
1922 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
1923 APFloat::rmNearestTiesToEven);
1924 if (unsignedMin > rhs) {
1925 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1926 pred == CmpIPredicate::uge)
1941 APSInt rhsInt(intWidth, isUnsigned);
1942 if (APFloat::opInvalidOp ==
1943 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1949 if (!rhs.isZero()) {
1952 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1954 bool equal = apf == rhs;
1960 case CmpIPredicate::ne:
1964 case CmpIPredicate::eq:
1968 case CmpIPredicate::ule:
1971 if (rhs.isNegative()) {
1977 case CmpIPredicate::sle:
1980 if (rhs.isNegative())
1981 pred = CmpIPredicate::slt;
1983 case CmpIPredicate::ult:
1986 if (rhs.isNegative()) {
1991 pred = CmpIPredicate::ule;
1993 case CmpIPredicate::slt:
1996 if (!rhs.isNegative())
1997 pred = CmpIPredicate::sle;
1999 case CmpIPredicate::ugt:
2002 if (rhs.isNegative()) {
2008 case CmpIPredicate::sgt:
2011 if (rhs.isNegative())
2012 pred = CmpIPredicate::sge;
2014 case CmpIPredicate::uge:
2017 if (rhs.isNegative()) {
2022 pred = CmpIPredicate::ugt;
2024 case CmpIPredicate::sge:
2027 if (!rhs.isNegative())
2028 pred = CmpIPredicate::sgt;
2038 rewriter.
create<ConstantOp>(
2039 op.getLoc(), intVal.
getType(),
2066 if (!op.getType().isInteger(1))
2069 Value falseConstant =
2070 rewriter.
create<arith::ConstantIntOp>(op.getLoc(),
true, 1);
2071 Value notCondition = rewriter.
create<arith::XOrIOp>(
2072 op.getLoc(), op.getCondition(), falseConstant);
2075 op.getLoc(), op.getCondition(), op.getTrueValue());
2076 Value falseVal = rewriter.
create<arith::AndIOp>(op.getLoc(), notCondition,
2077 op.getFalseValue());
2090 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
2106 rewriter.
create<arith::XOrIOp>(
2107 op.getLoc(), op.getCondition(),
2108 rewriter.
create<arith::ConstantIntOp>(
2109 op.getLoc(), 1, op.getCondition().getType())));
2122 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2123 Value trueVal = getTrueValue();
2124 Value falseVal = getFalseValue();
2125 if (trueVal == falseVal)
2128 Value condition = getCondition();
2143 if (
auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.
getDefiningOp())) {
2144 auto pred = cmp.getPredicate();
2145 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2146 auto cmpLhs = cmp.getLhs();
2147 auto cmpRhs = cmp.getRhs();
2155 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2156 (cmpRhs == trueVal && cmpLhs == falseVal))
2157 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2164 Type conditionType, resultType;
2173 conditionType = resultType;
2182 {conditionType, resultType, resultType},
2187 p <<
" " << getOperands();
2190 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
2191 p << condType <<
", ";
2196 Type conditionType = getCondition().getType();
2202 Type resultType = getType();
2204 return emitOpError() <<
"expected condition to be a signless i1, but got "
2207 if (conditionType != shapedConditionType) {
2208 return emitOpError() <<
"expected condition type to have the same shape "
2209 "as the result type, expected "
2210 << shapedConditionType <<
", but got "
2219 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2224 bool bounded =
false;
2225 auto result = constFoldBinaryOp<IntegerAttr>(
2226 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2227 bounded = b.ule(b.getBitWidth());
2237 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2242 bool bounded =
false;
2243 auto result = constFoldBinaryOp<IntegerAttr>(
2244 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2245 bounded = b.ule(b.getBitWidth());
2255 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2260 bool bounded =
false;
2261 auto result = constFoldBinaryOp<IntegerAttr>(
2262 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2263 bounded = b.ule(b.getBitWidth());
2277 case AtomicRMWKind::maxf:
2282 case AtomicRMWKind::addf:
2283 case AtomicRMWKind::addi:
2284 case AtomicRMWKind::maxu:
2285 case AtomicRMWKind::ori:
2287 case AtomicRMWKind::andi:
2290 APInt::getAllOnes(resultType.
cast<IntegerType>().getWidth()));
2291 case AtomicRMWKind::maxs:
2294 APInt::getSignedMinValue(resultType.
cast<IntegerType>().getWidth()));
2295 case AtomicRMWKind::minf:
2300 case AtomicRMWKind::mins:
2303 APInt::getSignedMaxValue(resultType.
cast<IntegerType>().getWidth()));
2304 case AtomicRMWKind::minu:
2307 APInt::getMaxValue(resultType.
cast<IntegerType>().getWidth()));
2308 case AtomicRMWKind::muli:
2310 case AtomicRMWKind::mulf:
2324 return builder.
create<arith::ConstantOp>(loc, attr);
2332 case AtomicRMWKind::addf:
2333 return builder.
create<arith::AddFOp>(loc, lhs, rhs);
2334 case AtomicRMWKind::addi:
2335 return builder.
create<arith::AddIOp>(loc, lhs, rhs);
2336 case AtomicRMWKind::mulf:
2337 return builder.
create<arith::MulFOp>(loc, lhs, rhs);
2338 case AtomicRMWKind::muli:
2339 return builder.
create<arith::MulIOp>(loc, lhs, rhs);
2340 case AtomicRMWKind::maxf:
2341 return builder.
create<arith::MaxFOp>(loc, lhs, rhs);
2342 case AtomicRMWKind::minf:
2343 return builder.
create<arith::MinFOp>(loc, lhs, rhs);
2344 case AtomicRMWKind::maxs:
2345 return builder.
create<arith::MaxSIOp>(loc, lhs, rhs);
2346 case AtomicRMWKind::mins:
2347 return builder.
create<arith::MinSIOp>(loc, lhs, rhs);
2348 case AtomicRMWKind::maxu:
2349 return builder.
create<arith::MaxUIOp>(loc, lhs, rhs);
2350 case AtomicRMWKind::minu:
2351 return builder.
create<arith::MinUIOp>(loc, lhs, rhs);
2352 case AtomicRMWKind::ori:
2353 return builder.
create<arith::OrIOp>(loc, lhs, rhs);
2354 case AtomicRMWKind::andi:
2355 return builder.
create<arith::AndIOp>(loc, lhs, rhs);
2368 #define GET_OP_CLASSES
2369 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2375 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
std::tuple< Types... > * type_list
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
bool isa() const
Casting utility functions.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Attribute getZeroAttr(Type type)
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
U dyn_cast_or_null() const
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
MPInt ceil(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)