26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.
getType(), value);
67static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
98 llvm_unreachable(
"unknown cmpi predicate kind");
107static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
121 llvm_unreachable(
"Unhandled rounding mode");
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
155 if (!shapedType.hasStaticShape())
165#include "ArithCanonicalization.inc"
174 auto i1Type = IntegerType::get(type.
getContext(), 1);
175 if (
auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
/// Suggests an assembly name for the constant's result via `setNameFn`.
/// From the statements visible here: integer constants of a 1-bit integer
/// type are named "true"/"false", other integer constants get a name built
/// from 'c' followed by the value (and a '_' + type suffix), and the remaining
/// case is named "cst".
/// NOTE(review): the parameter list, the definition of `type`, and the guard
/// around the type suffix are not part of this listing — confirm against the
/// full file.
186void arith::ConstantOp::getAsmResultNames(
// Integer-valued constants get a value-derived name.
189 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
// 1-bit integers are spelled as booleans.
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
// Build the name in a small stack buffer: 'c' followed by the constant value.
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName <<
'c' << intCst.getValue();
// Append '_' and the type to the name.
201 specialName <<
'_' << type;
202 setNameFn(getResult(), specialName.str());
// Generic name used on the other visible path.
204 setNameFn(getResult(),
"cst");
/// Verifier for the constant op. The checks visible in this listing concern:
///  - integer result types, which must be signless;
///  - the value attribute, which must be an integer, float, or elements
///    attribute;
///  - scalable-vector result types, whose value must be a splat.
/// NOTE(review): the definition of `type` and the statements that emit the
/// second and third diagnostics are not part of this listing.
210LogicalResult arith::ConstantOp::verify() {
// Reject integer result types that are not signless.
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError(
"integer return type must be signless");
// Only integer, float, and elements attributes are accepted as the value.
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
219 "value must be an integer, float, or elements attribute");
// A scalable vector type paired with a non-splat value is diagnosed.
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
/// Reports whether a constant op of `type` can be built from `value`.
/// NOTE(review): the statements executed when a guard below fires, and the
/// definition of `intType`, are not part of this listing; presumably each
/// guard rejects the pair — confirm against the full file.
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
// The attribute must be typed and its type must be exactly `type`.
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
// Guard on the signlessness of `intType`.
239 if (!intType.isSignless())
// Final answer: the attribute is an integer, float, or elements attribute.
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
/// Creates a constant op at `loc` from `value` when
/// isBuildableWith(value, type) holds; `value` is then cast to TypedAttr.
/// NOTE(review): the path taken when the pair is not buildable is not part of
/// this listing.
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
/// Folds the constant op to its own value attribute; the adaptor is unused.
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
 return getValue(); }
258 arith::ConstantOp::build(builder,
result, type,
268 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
269 assert(
result &&
"builder didn't return the right type");
281 arith::ConstantOp::build(builder,
result, type,
290 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
291 assert(
result &&
"builder didn't return the right type");
302 arith::ConstantOp::build(builder,
result, type,
308 const APInt &
value) {
311 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
312 assert(
result &&
"builder didn't return the right type");
318 const APInt &
value) {
323 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
329 FloatType type,
const APFloat &
value) {
330 arith::ConstantOp::build(builder,
result, type,
337 const APFloat &
value) {
340 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
341 assert(
result &&
"builder didn't return the right type");
347 const APFloat &
value) {
352 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
368 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
369 assert(
result &&
"builder didn't return the right type");
379 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
388 "type doesn't have a zero representation");
390 assert(zeroAttr &&
"unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
404 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
409 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
414 adaptor.getOperands(),
415 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
420 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
/// Returns a copy of the shape of `getType(0)` when that type is a VectorType.
/// NOTE(review): the result for the non-vector case is not part of this
/// listing.
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
450 results.push_back(getLhs());
451 results.push_back(falseValue);
460 adaptor.getOperands(),
461 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
463 ArrayRef({sumAttr, adaptor.getLhs()}),
469 results.push_back(sumAttr);
470 results.push_back(overflowAttr);
/// Adds the AddUIExtendedToAddI pattern, constructed with `context`, to
/// `patterns`.
477void arith::AddUIExtendedOp::getCanonicalizationPatterns(
478 RewritePatternSet &
patterns, MLIRContext *context) {
479 patterns.add<AddUIExtendedToAddI>(context);
486OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
488 if (getOperand(0) == getOperand(1)) {
489 auto shapedType = dyn_cast<ShapedType>(
getType());
491 if (!shapedType || shapedType.hasStaticShape())
498 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
500 if (getRhs() ==
add.getRhs())
503 if (getRhs() ==
add.getLhs())
508 adaptor.getOperands(),
509 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
/// Adds the seven SubI* patterns listed below, each constructed with
/// `context`, to `patterns`.
512void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
513 MLIRContext *context) {
514 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
515 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
516 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
523OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
534 adaptor.getOperands(),
535 [](
const APInt &a,
const APInt &
b) { return a * b; });
/// Suggests a result name of the form "c<N>_vscale" for a multiplication
/// in which one operand is produced by an op named "vector.vscale" and <N> is
/// the integer held in `baseValue`.
/// NOTE(review): the parameter list, the statements under the two guards, and
/// the part of `isVscaleExpr` that fills `baseValue` are not part of this
/// listing — confirm against the full file.
538void arith::MulIOp::getAsmResultNames(
// Guard on the result type being an index type.
540 if (!isa<IndexType>(
getType()))
// Matches the defining op by name string; a null op never matches.
545 auto isVscale = [](Operation *op) {
546 return op && op->getName().getStringRef() ==
"vector.vscale";
549 IntegerAttr baseValue;
// Tests an (a, b) operand ordering; `b` must be defined by a vscale op.
550 auto isVscaleExpr = [&](Value a, Value
b) {
552 isVscale(
b.getDefiningOp());
// Try both operand orders.
555 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
// Assemble the name in a small stack buffer and report it.
559 SmallString<32> specialNameBuffer;
560 llvm::raw_svector_ostream specialName(specialNameBuffer);
561 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
562 setNameFn(getResult(), specialName.str());
/// Adds the MulIMulIConstant pattern, constructed with `context`, to
/// `patterns`.
565void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
566 MLIRContext *context) {
567 patterns.add<MulIMulIConstant>(context);
/// Returns a copy of the shape of `getType(0)` when that type is a VectorType.
/// NOTE(review): the result for the non-vector case is not part of this
/// listing.
574std::optional<SmallVector<int64_t, 4>>
575arith::MulSIExtendedOp::getShapeForUnroll() {
576 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
577 return llvm::to_vector<4>(vt.getShape());
582arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
583 SmallVectorImpl<OpFoldResult> &results) {
586 Attribute zero = adaptor.getRhs();
587 results.push_back(zero);
588 results.push_back(zero);
594 adaptor.getOperands(),
595 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
598 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
599 return llvm::APIntOps::mulhs(a, b);
601 assert(highAttr &&
"Unexpected constant-folding failure");
603 results.push_back(lowAttr);
604 results.push_back(highAttr);
/// Adds the MulSIExtendedToMulI and MulSIExtendedRHSOne patterns, constructed
/// with `context`, to `patterns`.
611void arith::MulSIExtendedOp::getCanonicalizationPatterns(
612 RewritePatternSet &
patterns, MLIRContext *context) {
613 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
/// Returns a copy of the shape of `getType(0)` when that type is a VectorType.
/// NOTE(review): the result for the non-vector case is not part of this
/// listing.
620std::optional<SmallVector<int64_t, 4>>
621arith::MulUIExtendedOp::getShapeForUnroll() {
622 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
623 return llvm::to_vector<4>(vt.getShape());
628arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
629 SmallVectorImpl<OpFoldResult> &results) {
632 Attribute zero = adaptor.getRhs();
633 results.push_back(zero);
634 results.push_back(zero);
642 results.push_back(getLhs());
643 results.push_back(zero);
649 adaptor.getOperands(),
650 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
653 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
654 return llvm::APIntOps::mulhu(a, b);
656 assert(highAttr &&
"Unexpected constant-folding failure");
658 results.push_back(lowAttr);
659 results.push_back(highAttr);
/// Adds the MulUIExtendedToMulI pattern, constructed with `context`, to
/// `patterns`.
666void arith::MulUIExtendedOp::getCanonicalizationPatterns(
667 RewritePatternSet &
patterns, MLIRContext *context) {
668 patterns.add<MulUIExtendedToMulI>(context);
677 arith::IntegerOverflowFlags ovfFlags) {
678 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
679 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
691OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
697 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
703 [&](APInt a,
const APInt &
b) {
711 return div0 ? Attribute() :
result;
731OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
737 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
741 bool overflowOrDiv0 =
false;
743 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
744 if (overflowOrDiv0 || !b) {
745 overflowOrDiv0 = true;
748 return a.sdiv_ov(
b, overflowOrDiv0);
751 return overflowOrDiv0 ? Attribute() :
result;
778 APInt one(a.getBitWidth(), 1,
true);
779 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
780 return val.sadd_ov(one, overflow);
787OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
792 bool overflowOrDiv0 =
false;
794 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
795 if (overflowOrDiv0 || !b) {
796 overflowOrDiv0 = true;
799 APInt quotient = a.udiv(
b);
802 APInt one(a.getBitWidth(), 1,
true);
803 return quotient.uadd_ov(one, overflowOrDiv0);
806 return overflowOrDiv0 ? Attribute() :
result;
817OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
825 bool overflowOrDiv0 =
false;
827 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
828 if (overflowOrDiv0 || !b) {
829 overflowOrDiv0 = true;
835 unsigned bits = a.getBitWidth();
836 APInt zero = APInt::getZero(bits);
837 bool aGtZero = a.sgt(zero);
838 bool bGtZero =
b.sgt(zero);
839 if (aGtZero && bGtZero) {
846 bool overflowNegA =
false;
847 bool overflowNegB =
false;
848 bool overflowDiv =
false;
849 bool overflowNegRes =
false;
850 if (!aGtZero && !bGtZero) {
852 APInt posA = zero.ssub_ov(a, overflowNegA);
853 APInt posB = zero.ssub_ov(
b, overflowNegB);
855 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
858 if (!aGtZero && bGtZero) {
860 APInt posA = zero.ssub_ov(a, overflowNegA);
861 APInt
div = posA.sdiv_ov(
b, overflowDiv);
862 APInt res = zero.ssub_ov(
div, overflowNegRes);
863 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
867 APInt posB = zero.ssub_ov(
b, overflowNegB);
868 APInt
div = a.sdiv_ov(posB, overflowDiv);
869 APInt res = zero.ssub_ov(
div, overflowNegRes);
871 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
875 return overflowOrDiv0 ? Attribute() :
result;
886OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
892 bool overflowOrDiv =
false;
894 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
896 overflowOrDiv = true;
899 return a.sfloordiv_ov(
b, overflowOrDiv);
902 return overflowOrDiv ? Attribute() :
result;
909OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
917 [&](APInt a,
const APInt &
b) {
918 if (div0 || b.isZero()) {
925 return div0 ? Attribute() :
result;
932OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
940 [&](APInt a,
const APInt &
b) {
941 if (div0 || b.isZero()) {
948 return div0 ? Attribute() :
result;
957 for (
bool reversePrev : {
false,
true}) {
958 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
959 .getDefiningOp<arith::AndIOp>();
963 Value other = (reversePrev ? op.getLhs() : op.getRhs());
964 if (other != prev.getLhs() && other != prev.getRhs())
967 return prev.getResult();
972OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
979 intValue.isAllOnes())
984 intValue.isAllOnes())
989 intValue.isAllOnes())
997 adaptor.getOperands(),
998 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1005OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1008 if (rhsVal.isZero())
1011 if (rhsVal.isAllOnes())
1012 return adaptor.getRhs();
1019 intValue.isAllOnes())
1020 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1024 intValue.isAllOnes())
1025 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1028 adaptor.getOperands(),
1029 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
/// Folder for integer XOR. The visible rules cancel a repeated operand
/// through a nested XOR: when one operand is itself an XOrIOp that shares an
/// operand with this op, the nested op's other operand is returned.
/// NOTE(review): the result of the identical-operands case and the head of
/// the call that receives the lambda at the end are not part of this listing.
1036OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
// Both operands are the same SSA value.
1041 if (getLhs() == getRhs())
// The left operand is an XOR: (a ^ b) ^ b -> a and (a ^ b) ^ a -> b.
1045 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1046 if (prev.getRhs() == getRhs())
1047 return prev.getLhs();
1048 if (prev.getLhs() == getRhs())
1049 return prev.getRhs();
// The right operand is an XOR: b ^ (a ^ b) -> a and a ^ (a ^ b) -> b.
1053 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1054 if (prev.getRhs() == getLhs())
1055 return prev.getLhs();
1056 if (prev.getLhs() == getLhs())
1057 return prev.getRhs();
// Callback over the adaptor operands that XORs two APInt values; `a` is
// taken by value and moved into the result.
1061 adaptor.getOperands(),
1062 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
/// Adds the XOrINotCmpI, XOrIOfExtUI and XOrIOfExtSI patterns, constructed
/// with `context`, to `patterns`.
1065void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1066 MLIRContext *context) {
1067 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
/// Folder for floating-point negation.
/// NOTE(review): the head of the call that receives the lambda below is not
/// part of this listing.
1074OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
// A negation of a negation folds to the inner operand.
1076 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1077 return op.getOperand();
// Callback that negates a single APFloat value.
1079 [](
const APFloat &a) { return -a; });
1086OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1092 adaptor.getOperands(),
1093 [](
const APFloat &a,
const APFloat &
b) { return a + b; });
1100OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1106 adaptor.getOperands(),
1107 [](
const APFloat &a,
const APFloat &
b) { return a - b; });
1114OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1116 if (getLhs() == getRhs())
1124 adaptor.getOperands(),
1125 [](
const APFloat &a,
const APFloat &
b) { return llvm::maximum(a, b); });
1132OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1134 if (getLhs() == getRhs())
1148OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1150 if (getLhs() == getRhs())
1156 if (intValue.isMaxSignedValue())
1159 if (intValue.isMinSignedValue())
1164 [](
const APInt &a,
const APInt &
b) {
1165 return llvm::APIntOps::smax(a, b);
1173OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1175 if (getLhs() == getRhs())
1181 if (intValue.isMaxValue())
1184 if (intValue.isMinValue())
1189 [](
const APInt &a,
const APInt &
b) {
1190 return llvm::APIntOps::umax(a, b);
1198OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1200 if (getLhs() == getRhs())
1208 adaptor.getOperands(),
1209 [](
const APFloat &a,
const APFloat &
b) { return llvm::minimum(a, b); });
1216OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1218 if (getLhs() == getRhs())
1226 adaptor.getOperands(),
1227 [](
const APFloat &a,
const APFloat &
b) { return llvm::minnum(a, b); });
1234OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1236 if (getLhs() == getRhs())
1242 if (intValue.isMinSignedValue())
1245 if (intValue.isMaxSignedValue())
1250 [](
const APInt &a,
const APInt &
b) {
1251 return llvm::APIntOps::smin(a, b);
1259OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1261 if (getLhs() == getRhs())
1267 if (intValue.isMinValue())
1270 if (intValue.isMaxValue())
1275 [](
const APInt &a,
const APInt &
b) {
1276 return llvm::APIntOps::umin(a, b);
1284OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1289 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1290 arith::FastMathFlags::nsz)) {
1297 adaptor.getOperands(),
1298 [](
const APFloat &a,
const APFloat &
b) { return a * b; });
1301void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1302 MLIRContext *context) {
1310OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1316 adaptor.getOperands(),
1317 [](
const APFloat &a,
const APFloat &
b) { return a / b; });
1320void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1321 MLIRContext *context) {
1329OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1331 [](
const APFloat &a,
const APFloat &
b) {
1336 (void)result.mod(b);
1345template <
typename... Types>
1351template <
typename... ShapedTypes,
typename... ElementTypes>
1354 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1358 if (!llvm::isa<ElementTypes...>(underlyingType))
1361 return underlyingType;
1365template <
typename... ElementTypes>
1372template <
typename... ElementTypes>
1381 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1382 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1383 if (!rankedTensorA || !rankedTensorB)
1385 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1389 if (inputs.size() != 1 || outputs.size() != 1)
1401template <
typename ValType,
typename Op>
1406 if (llvm::cast<ValType>(srcType).getWidth() >=
1407 llvm::cast<ValType>(dstType).getWidth())
1409 << dstType <<
" must be wider than operand type " << srcType;
1415template <
typename ValType,
typename Op>
1420 if (llvm::cast<ValType>(srcType).getWidth() <=
1421 llvm::cast<ValType>(dstType).getWidth())
1423 << dstType <<
" must be shorter than operand type " << srcType;
1429template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1434 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1435 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1436 if (!srcType || !dstType)
1439 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1440 srcType.getIntOrFloatBitWidth());
1446 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1447 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1448 bool losesInfo =
false;
1449 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1450 if (losesInfo || status != APFloat::opOK)
1460OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1461 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1462 getInMutable().assign(
lhs.getIn());
1467 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1469 adaptor.getOperands(),
getType(),
1470 [bitWidth](
const APInt &a,
bool &castStatus) {
1471 return a.zext(bitWidth);
1479LogicalResult arith::ExtUIOp::verify() {
1487OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1488 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1489 getInMutable().assign(
lhs.getIn());
1494 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1496 adaptor.getOperands(),
getType(),
1497 [bitWidth](
const APInt &a,
bool &castStatus) {
1498 return a.sext(bitWidth);
/// Adds the ExtSIOfExtUI pattern, constructed with `context`, to `patterns`.
1506void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1507 MLIRContext *context) {
1508 patterns.add<ExtSIOfExtUI>(context);
1511LogicalResult arith::ExtSIOp::verify() {
/// Folder for floating-point extension. The visible rule removes an
/// extension of a truncation: when the operand is a TruncFOp whose own
/// operand already has this op's result type, and both ops carry the
/// `contract` fast-math flag, the truncation's operand is returned.
/// NOTE(review): the definition of `resElemType`, the head of the call that
/// receives the lambda, and the lambda body are not part of this listing.
1521OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1522 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
// The round trip must start from this op's result type.
1523 if (truncFOp.getOperand().getType() ==
getType()) {
// An absent fast-math attribute is treated as `none` on either op.
1524 arith::FastMathFlags truncFMF =
1525 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1526 bool isTruncContract =
1527 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1528 arith::FastMathFlags extFMF =
1529 getFastmath().value_or(arith::FastMathFlags::none);
1530 bool isExtContract =
1531 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
// Both ops must carry `contract` for the fold to apply.
1532 if (isTruncContract && isExtContract) {
1533 return truncFOp.getOperand();
// Float semantics of the result element type, captured by the lambda below.
1539 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1541 adaptor.getOperands(),
getType(),
1542 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1562bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1567LogicalResult arith::ScalingExtFOp::verify() {
1575OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1578 Value src = getOperand().getDefiningOp()->getOperand(0);
1583 if (llvm::cast<IntegerType>(srcType).getWidth() >
1584 llvm::cast<IntegerType>(dstType).getWidth()) {
1591 if (srcType == dstType)
1597 setOperand(getOperand().getDefiningOp()->getOperand(0));
1602 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1604 adaptor.getOperands(),
getType(),
1605 [bitWidth](
const APInt &a,
bool &castStatus) {
1606 return a.trunc(bitWidth);
1614void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1615 MLIRContext *context) {
1617 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1621LogicalResult arith::TruncIOp::verify() {
1631OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1633 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1634 Value src = extOp.getIn();
1636 auto intermediateType =
1639 if (llvm::APFloatBase::isRepresentableBy(
1640 srcType.getFloatSemantics(),
1641 intermediateType.getFloatSemantics())) {
1643 if (srcType.getWidth() > resElemType.getWidth()) {
1649 if (srcType == resElemType)
1654 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1656 adaptor.getOperands(),
getType(),
1657 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1658 RoundingMode roundingMode =
1659 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1660 llvm::RoundingMode llvmRoundingMode =
1662 FailureOr<APFloat>
result =
/// Adds the TruncFSIToFPToSIToFP and TruncFUIToFPToUIToFP patterns,
/// constructed with `context`, to `patterns`.
1672void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1673 MLIRContext *context) {
1674 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1681LogicalResult arith::TruncFOp::verify() {
1689bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1694LogicalResult arith::ScalingTruncFOp::verify() {
/// Adds the AndOfExtUI and AndOfExtSI patterns, constructed with `context`,
/// to `patterns`.
1702void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1703 MLIRContext *context) {
1704 patterns.add<AndOfExtUI, AndOfExtSI>(context);
/// Adds the OrOfExtUI and OrOfExtSI patterns, constructed with `context`, to
/// `patterns`.
1711void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1712 MLIRContext *context) {
1713 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1720template <
typename From,
typename To>
1728 return srcType && dstType;
1739OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1742 adaptor.getOperands(),
getType(),
1743 [&resEleType](
const APInt &a,
bool &castStatus) {
1744 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1745 APFloat apf(floatTy.getFloatSemantics(),
1746 APInt::getZero(floatTy.getWidth()));
1747 apf.convertFromAPInt(a,
false,
1748 APFloat::rmNearestTiesToEven);
1761OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1764 adaptor.getOperands(),
getType(),
1765 [&resEleType](
const APInt &a,
bool &castStatus) {
1766 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1767 APFloat apf(floatTy.getFloatSemantics(),
1768 APInt::getZero(floatTy.getWidth()));
1769 apf.convertFromAPInt(a,
true,
1770 APFloat::rmNearestTiesToEven);
1783OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1785 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1787 adaptor.getOperands(),
getType(),
1788 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1790 APSInt api(bitWidth,
true);
1791 castStatus = APFloat::opInvalidOp !=
1792 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1805OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1807 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1809 adaptor.getOperands(),
getType(),
1810 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1812 APSInt api(bitWidth,
false);
1813 castStatus = APFloat::opInvalidOp !=
1814 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1829 if (!srcType || !dstType)
1836bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1841OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1843 unsigned resultBitwidth = 64;
1845 resultBitwidth = intTy.getWidth();
1848 adaptor.getOperands(),
getType(),
1849 [resultBitwidth](
const APInt &a,
bool & ) {
1850 return a.sextOrTrunc(resultBitwidth);
/// Adds the IndexCastOfIndexCast and IndexCastOfExtSI patterns, constructed
/// with `context`, to `patterns`.
1854void arith::IndexCastOp::getCanonicalizationPatterns(
1855 RewritePatternSet &
patterns, MLIRContext *context) {
1856 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1863bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1868OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1870 unsigned resultBitwidth = 64;
1872 resultBitwidth = intTy.getWidth();
1875 adaptor.getOperands(),
getType(),
1876 [resultBitwidth](
const APInt &a,
bool & ) {
1877 return a.zextOrTrunc(resultBitwidth);
/// Adds the IndexCastUIOfIndexCastUI and IndexCastUIOfExtUI patterns,
/// constructed with `context`, to `patterns`.
1881void arith::IndexCastUIOp::getCanonicalizationPatterns(
1882 RewritePatternSet &
patterns, MLIRContext *context) {
1883 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1896 if (!srcType || !dstType)
/// Folder for bitcast of a constant operand. Dense-elements constants are
/// bitcast to the result's element type. For scalar constants the raw bits
/// are read from the float or integer attribute and re-wrapped as a FloatAttr
/// (when the result type is a float type) or an IntegerAttr of the result
/// type.
/// NOTE(review): the definition of `resType`, the statements under the
/// shaped-type and poison guards, and the head of the assertion are not part
/// of this listing.
1902OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1904 auto operand = adaptor.getIn();
// Dense constant: reinterpret the elements with the result element type.
1909 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1910 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
// Guard for shaped result types not handled by the dense case above.
1912 if (llvm::isa<ShapedType>(resType))
// Guard for a poison operand.
1916 if (llvm::isa<ub::PoisonAttr>(operand))
// Scalar constant: take the bit pattern of the float, or the integer value.
1920 APInt bits = llvm::isa<FloatAttr>(operand)
1921 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1922 : llvm::cast<IntegerAttr>(operand).getValue();
1924 "trying to fold on broken IR: operands have incompatible types");
// Rebuild the attribute with the result type from the same bits.
1926 if (
auto resFloatType = dyn_cast<FloatType>(resType))
1927 return FloatAttr::get(resType,
1928 APFloat(resFloatType.getFloatSemantics(), bits));
1929 return IntegerAttr::get(resType, bits);
/// Adds the BitcastOfBitcast pattern, constructed with `context`, to
/// `patterns`.
1932void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
1933 MLIRContext *context) {
1934 patterns.add<BitcastOfBitcast>(context);
1944 const APInt &
lhs,
const APInt &
rhs) {
1945 switch (predicate) {
1946 case arith::CmpIPredicate::eq:
1948 case arith::CmpIPredicate::ne:
1950 case arith::CmpIPredicate::slt:
1952 case arith::CmpIPredicate::sle:
1954 case arith::CmpIPredicate::sgt:
1956 case arith::CmpIPredicate::sge:
1958 case arith::CmpIPredicate::ult:
1960 case arith::CmpIPredicate::ule:
1962 case arith::CmpIPredicate::ugt:
1964 case arith::CmpIPredicate::uge:
1967 llvm_unreachable(
"unknown cmpi predicate kind");
1972 switch (predicate) {
1973 case arith::CmpIPredicate::eq:
1974 case arith::CmpIPredicate::sle:
1975 case arith::CmpIPredicate::sge:
1976 case arith::CmpIPredicate::ule:
1977 case arith::CmpIPredicate::uge:
1979 case arith::CmpIPredicate::ne:
1980 case arith::CmpIPredicate::slt:
1981 case arith::CmpIPredicate::sgt:
1982 case arith::CmpIPredicate::ult:
1983 case arith::CmpIPredicate::ugt:
1986 llvm_unreachable(
"unknown cmpi predicate kind");
1990 if (
auto intType = dyn_cast<IntegerType>(t)) {
1991 return intType.getWidth();
1993 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
1994 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1996 return std::nullopt;
1999OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2001 if (getLhs() == getRhs()) {
2007 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2009 std::optional<int64_t> integerWidth =
2011 if (integerWidth && integerWidth.value() == 1 &&
2012 getPredicate() == arith::CmpIPredicate::ne)
2013 return extOp.getOperand();
2015 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2017 std::optional<int64_t> integerWidth =
2019 if (integerWidth && integerWidth.value() == 1 &&
2020 getPredicate() == arith::CmpIPredicate::ne)
2021 return extOp.getOperand();
2026 getPredicate() == arith::CmpIPredicate::ne)
2033 getPredicate() == arith::CmpIPredicate::eq)
2038 if (adaptor.getLhs() && !adaptor.getRhs()) {
2040 using Pred = CmpIPredicate;
2041 const std::pair<Pred, Pred> invPreds[] = {
2042 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2043 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2044 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2045 {Pred::ne, Pred::ne},
2047 Pred origPred = getPredicate();
2048 for (
auto pred : invPreds) {
2049 if (origPred == pred.first) {
2050 setPredicate(pred.second);
2051 Value
lhs = getLhs();
2052 Value
rhs = getRhs();
2053 getLhsMutable().assign(
rhs);
2054 getRhsMutable().assign(
lhs);
2058 llvm_unreachable(
"unknown cmpi predicate kind");
2063 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2066 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
/// Inserts the CmpIExtSI and CmpIExtUI patterns, constructed with `context`,
/// into `patterns`.
2075void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
2076 MLIRContext *context) {
2077 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
// Evaluates a floating-point comparison predicate on two APFloat values.
// The ordered predicates (O*) never accept cmpUnordered, while the unordered
// predicates (U*) accept cmpUnordered in addition to their ordered relation.
// NOTE(review): the first line of the signature and the return statements of
// the AlwaysFalse / AlwaysTrue cases are not visible in this listing.
2087 const APFloat &
lhs,
const APFloat &
rhs) {
// Compare once; every case below only inspects this result.
2088 auto cmpResult =
lhs.compare(
rhs);
2089 switch (predicate) {
2090 case arith::CmpFPredicate::AlwaysFalse:
2092 case arith::CmpFPredicate::OEQ:
2093 return cmpResult == APFloat::cmpEqual;
2094 case arith::CmpFPredicate::OGT:
2095 return cmpResult == APFloat::cmpGreaterThan;
2096 case arith::CmpFPredicate::OGE:
2097 return cmpResult == APFloat::cmpGreaterThan ||
2098 cmpResult == APFloat::cmpEqual;
2099 case arith::CmpFPredicate::OLT:
2100 return cmpResult == APFloat::cmpLessThan;
2101 case arith::CmpFPredicate::OLE:
2102 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
// "Ordered and not equal": both unordered and equal results are rejected.
2103 case arith::CmpFPredicate::ONE:
2104 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2105 case arith::CmpFPredicate::ORD:
2106 return cmpResult != APFloat::cmpUnordered;
2107 case arith::CmpFPredicate::UEQ:
2108 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2109 case arith::CmpFPredicate::UGT:
2110 return cmpResult == APFloat::cmpUnordered ||
2111 cmpResult == APFloat::cmpGreaterThan;
2112 case arith::CmpFPredicate::UGE:
2113 return cmpResult == APFloat::cmpUnordered ||
2114 cmpResult == APFloat::cmpGreaterThan ||
2115 cmpResult == APFloat::cmpEqual;
2116 case arith::CmpFPredicate::ULT:
2117 return cmpResult == APFloat::cmpUnordered ||
2118 cmpResult == APFloat::cmpLessThan;
2119 case arith::CmpFPredicate::ULE:
2120 return cmpResult == APFloat::cmpUnordered ||
2121 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
// "Unordered or not equal": everything except an equal result is accepted.
2122 case arith::CmpFPredicate::UNE:
2123 return cmpResult != APFloat::cmpEqual;
2124 case arith::CmpFPredicate::UNO:
2125 return cmpResult == APFloat::cmpUnordered;
2126 case arith::CmpFPredicate::AlwaysTrue:
2129 llvm_unreachable(
"unknown cmpf predicate kind");
// Constant operands as FloatAttr; each is null when the corresponding operand
// attribute is absent or is not a FloatAttr.
2133 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2134 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
// A NaN constant on either side is special-cased.
// NOTE(review): the statements guarded by these two conditions are not
// visible in this listing.
2137 if (
lhs &&
lhs.getValue().isNaN())
2139 if (
rhs &&
rhs.getValue().isNaN())
// Maps a float comparison predicate to an integer comparison predicate.
// The ordered and unordered variants of a relation map to the same integer
// predicate. For the ordering relations, isUnsigned selects the unsigned
// (u*) or signed (s*) form; eq and ne carry no signedness.
// NOTE(review): the function header and the switch statement line are not
// visible in this listing.
2155 using namespace arith;
2157 case CmpFPredicate::UEQ:
2158 case CmpFPredicate::OEQ:
2159 return CmpIPredicate::eq;
2160 case CmpFPredicate::UGT:
2161 case CmpFPredicate::OGT:
2162 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2163 case CmpFPredicate::UGE:
2164 case CmpFPredicate::OGE:
2165 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2166 case CmpFPredicate::ULT:
2167 case CmpFPredicate::OLT:
2168 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2169 case CmpFPredicate::ULE:
2170 case CmpFPredicate::OLE:
2171 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2172 case CmpFPredicate::UNE:
2173 case CmpFPredicate::ONE:
2174 return CmpIPredicate::ne;
// None of the predicates handled above matched.
2176 llvm_unreachable(
"Unexpected predicate!");
// Body fragment of a rewrite on a float comparison whose lhs comes from an
// integer-to-float conversion and whose rhs is a float constant.
// NOTE(review): many lines of this definition are absent from this listing;
// the comments below describe only what the visible lines do.

// rhs is the APFloat value of the constant attribute `flt`.
2186 const APFloat &
rhs = flt.getValue();
// Mantissa width of the rhs float type; a non-positive width is special-cased
// (the guarded statement is not visible here).
2194 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2195 int mantissaWidth = floatTy.getFPMantissaWidth();
2196 if (mantissaWidth <= 0)
// intVal is the integer input of the SIToFPOp or UIToFPOp that defines lhs.
2202 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2204 intVal = si.getIn();
2205 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2207 intVal = ui.getIn();
2214 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2215 auto intWidth = intTy.getWidth();
// Bits available for the magnitude: the full width when unsigned, one bit
// fewer when signed.
2218 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
// The integer is wider than the mantissa: inspect the exponent of rhs. For an
// infinite rhs (IEK_Inf) the exponent of the largest finite value of the
// format is compared against valueBits instead.
2223 if ((
int)intWidth > mantissaWidth) {
2225 int exponent = ilogb(
rhs);
2226 if (exponent == APFloat::IEK_Inf) {
2227 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2228 if (maxExponent < (
int)valueBits) {
2235 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
// ORD and UNO get dedicated handling (bodies not visible here).
2244 switch (op.getPredicate()) {
2245 case CmpFPredicate::ORD:
2250 case CmpFPredicate::UNO:
// signedMax: largest signed intWidth-bit value converted to the float format.
// When rhs exceeds it, the ne / slt / sle predicates are singled out.
2263 APFloat signedMax(
rhs.getSemantics());
2264 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2265 APFloat::rmNearestTiesToEven);
2266 if (signedMax <
rhs) {
2267 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2268 pred == CmpIPredicate::sle)
// unsignedMax: largest unsigned intWidth-bit value converted to the float
// format. When rhs exceeds it, the ne / ult / ule predicates are singled out.
2279 APFloat unsignedMax(
rhs.getSemantics());
2280 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2281 APFloat::rmNearestTiesToEven);
2282 if (unsignedMax <
rhs) {
2283 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2284 pred == CmpIPredicate::ule)
// signedMin: smallest signed intWidth-bit value converted to the float
// format. When rhs is below it, the ne / sgt / sge predicates are singled out.
2296 APFloat signedMin(
rhs.getSemantics());
2297 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2298 APFloat::rmNearestTiesToEven);
2299 if (signedMin >
rhs) {
2300 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2301 pred == CmpIPredicate::sge)
// unsignedMin: smallest unsigned intWidth-bit value converted to the float
// format. When rhs is below it, the ne / ugt / uge predicates are singled out.
2311 APFloat unsignedMin(
rhs.getSemantics());
2312 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2313 APFloat::rmNearestTiesToEven);
2314 if (unsignedMin >
rhs) {
2315 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2316 pred == CmpIPredicate::uge)
// Convert rhs to an intWidth-bit integer, rounding toward zero; an
// opInvalidOp status is checked explicitly.
2331 APSInt rhsInt(intWidth, isUnsigned);
2332 if (APFloat::opInvalidOp ==
2333 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
// For a nonzero rhs, convert rhsInt back to the float format and record in
// `equal` whether that round trip reproduces rhs exactly.
2339 if (!
rhs.isZero()) {
2340 APFloat apf(floatTy.getFloatSemantics(),
2341 APInt::getZero(floatTy.getWidth()));
2342 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2344 bool equal = apf ==
rhs;
// Per-predicate adjustment; the visible branches depend on the sign of rhs
// and, for some predicates, tighten or loosen the comparison.
2350 case CmpIPredicate::ne:
2354 case CmpIPredicate::eq:
2358 case CmpIPredicate::ule:
2361 if (
rhs.isNegative()) {
2367 case CmpIPredicate::sle:
2370 if (
rhs.isNegative())
2371 pred = CmpIPredicate::slt;
2373 case CmpIPredicate::ult:
2376 if (
rhs.isNegative()) {
2381 pred = CmpIPredicate::ule;
2383 case CmpIPredicate::slt:
2386 if (!
rhs.isNegative())
2387 pred = CmpIPredicate::sle;
2389 case CmpIPredicate::ugt:
2392 if (
rhs.isNegative()) {
2398 case CmpIPredicate::sgt:
2401 if (
rhs.isNegative())
2402 pred = CmpIPredicate::sge;
2404 case CmpIPredicate::uge:
2407 if (
rhs.isNegative()) {
2412 pred = CmpIPredicate::ugt;
2414 case CmpIPredicate::sge:
2417 if (!
rhs.isNegative())
2418 pred = CmpIPredicate::sgt;
// A constant of intVal's integer type is created at the op's location.
2428 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
/// Adds the CmpFIntToFPConst rewrite pattern, constructed with `context`, to
/// `patterns`.
2434void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
2435 MLIRContext *context) {
2436 patterns.insert<CmpFIntToFPConst>(context);
// Guard on the result type of the select: it is not an IntegerType, or it is
// exactly a 1-bit integer.
// NOTE(review): the statement guarded by this condition is not visible here.
2450 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
// XOR of the select condition with a value built from the condition's type
// and the integer 1 (the call that builds it is on a line missing from this
// listing).
2466 arith::XOrIOp::create(
2467 rewriter, op.getLoc(), op.getCondition(),
2469 op.getCondition().
getType(), 1)));
/// Adds the select canonicalization patterns (RedundantSelectFalse,
/// RedundantSelectTrue, SelectNotCond, SelectI1ToNot, SelectToExtUI),
/// constructed with `context`, to `results`.
2477void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2478 MLIRContext *context) {
2479 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2480 SelectI1ToNot, SelectToExtUI>(context);
/// Folder for arith::SelectOp.
/// NOTE(review): several lines of this definition (mostly the results of the
/// individual checks) are absent from this listing; only visible steps are
/// annotated.
2483OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2484 Value trueVal = getTrueValue();
2485 Value falseVal = getFalseValue();
// Both selected values are the same SSA value.
2486 if (trueVal == falseVal)
2489 Value condition = getCondition();
// One of the selected values is a constant ub::PoisonAttr.
2500 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2503 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
// Special case for a signless i1 result (rest of the condition not visible).
2507 if (
getType().isSignlessInteger(1) &&
// `cmp` is an eq / ne comparison whose two operands are exactly the two
// selected values, in either order. For `ne` the result is trueVal, for `eq`
// it is falseVal.
2513 auto pred = cmp.getPredicate();
2514 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2515 auto cmpLhs = cmp.getLhs();
2516 auto cmpRhs = cmp.getRhs();
2524 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2525 (cmpRhs == trueVal && cmpLhs == falseVal))
2526 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
// Element-wise constant fold when the condition and both selected values are
// DenseElementsAttr (bound to cond, lhs and rhs on lines not fully visible
// here).
2533 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2535 assert(cond.getType().hasStaticShape() &&
2536 "DenseElementsAttr must have static shape");
2538 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2540 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
// One output attribute per condition element.
2541 SmallVector<Attribute> results;
2542 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2543 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2544 cond.value_end<BoolAttr>());
2545 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2546 lhs.value_end<Attribute>());
2547 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2548 rhs.value_end<Attribute>());
// Walk the three ranges in lockstep and pick from lhs where the condition
// element is true, otherwise from rhs.
2550 for (
auto [condVal, lhsVal, rhsVal] :
2551 llvm::zip_equal(condVals, lhsVals, rhsVals))
2552 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
/// Custom parser for the select op.
/// NOTE(review): the parser calls themselves are on lines absent from this
/// listing; only the type bookkeeping is visible.
2562ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2563 Type conditionType, resultType;
// Up to three operands are collected here.
2564 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
// On this path the first parsed type is stored in conditionType
// (the guarding condition is not visible).
2572 conditionType = resultType;
// The op has a single result of type resultType.
2579 result.addTypes(resultType);
// Operand types in order: the condition, then the two selected values, which
// both use the result type.
2581 {conditionType, resultType, resultType},
/// Custom printer for the select op.
/// NOTE(review): the lines printing the attribute dictionary and the trailing
/// result type are not visible in this listing.
2585void arith::SelectOp::print(OpAsmPrinter &p) {
// Operands first, separated from the op name by a space.
2586 p <<
" " << getOperands();
// The condition type is printed, followed by ", ", only when it is a
// ShapedType.
2589 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2590 p << condType <<
", ";
/// Verifier for the select op: checks the condition type against the result
/// type.
/// NOTE(review): the definitions of resultType and shapedConditionType, and
/// the early-success path, are not visible in this listing.
2594LogicalResult arith::SelectOp::verify() {
2595 Type conditionType = getCondition().getType();
// A result that is neither a tensor nor a vector is reported with the
// "signless i1" diagnostic below.
2602 if (!llvm::isa<TensorType, VectorType>(resultType))
2603 return emitOpError() <<
"expected condition to be a signless i1, but got "
// Otherwise the condition type must be exactly shapedConditionType.
2606 if (conditionType != shapedConditionType) {
2607 return emitOpError() <<
"expected condition type to have the same shape "
2608 "as the result type, expected "
2609 << shapedConditionType <<
", but got "
/// Folder for arith::ShLIOp.
/// NOTE(review): the lines producing `result` and the value returned by the
/// lambda are not visible in this listing.
2618OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda: true only when the shift amount is (unsigned) strictly
// less than its bit width.
2623 bool bounded =
false;
2625 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2626 bounded = b.ult(b.getBitWidth());
// An out-of-range shift amount yields a null Attribute instead of `result`.
2629 return bounded ?
result : Attribute();
/// Folder for arith::ShRUIOp.
/// NOTE(review): the lines producing `result` and the value returned by the
/// lambda are not visible in this listing.
2636OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda: true only when the shift amount is (unsigned) strictly
// less than its bit width.
2641 bool bounded =
false;
2643 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2644 bounded = b.ult(b.getBitWidth());
// An out-of-range shift amount yields a null Attribute instead of `result`.
2647 return bounded ?
result : Attribute();
/// Folder for arith::ShRSIOp.
/// NOTE(review): the lines producing `result` and the value returned by the
/// lambda are not visible in this listing.
2654OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda: true only when the shift amount is (unsigned) strictly
// less than its bit width.
2659 bool bounded =
false;
2661 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2662 bounded = b.ult(b.getBitWidth());
// An out-of-range shift amount yields a null Attribute instead of `result`.
2665 return bounded ?
result : Attribute();
2675 bool useOnlyFiniteValue) {
// Selects the identity element for each AtomicRMWKind.
// NOTE(review): the switch line and most return statements are absent from
// this listing; comments describe only the values that are visibly built.

// maximumf: negative largest finite value when useOnlyFiniteValue is set,
// otherwise negative infinity (second argument `true` = negative).
2677 case AtomicRMWKind::maximumf: {
2678 const llvm::fltSemantics &semantic =
2679 llvm::cast<FloatType>(resultType).getFloatSemantics();
2680 APFloat identity = useOnlyFiniteValue
2681 ? APFloat::getLargest(semantic,
true)
2682 : APFloat::getInf(semantic,
true);
// maxnumf: a negative NaN.
2685 case AtomicRMWKind::maxnumf: {
2686 const llvm::fltSemantics &semantic =
2687 llvm::cast<FloatType>(resultType).getFloatSemantics();
2688 APFloat identity = APFloat::getNaN(semantic,
true);
// These kinds share one case body (not visible here).
2691 case AtomicRMWKind::addf:
2692 case AtomicRMWKind::addi:
2693 case AtomicRMWKind::maxu:
2694 case AtomicRMWKind::ori:
2695 case AtomicRMWKind::xori:
// andi: all bits set, at the width of the integer result type.
2697 case AtomicRMWKind::andi:
2700 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
// maxs: the smallest signed value of the result width.
2701 case AtomicRMWKind::maxs:
2703 resultType, APInt::getSignedMinValue(
2704 llvm::cast<IntegerType>(resultType).getWidth()));
// minimumf: positive largest finite value when useOnlyFiniteValue is set,
// otherwise positive infinity.
2705 case AtomicRMWKind::minimumf: {
2706 const llvm::fltSemantics &semantic =
2707 llvm::cast<FloatType>(resultType).getFloatSemantics();
2708 APFloat identity = useOnlyFiniteValue
2709 ? APFloat::getLargest(semantic,
false)
2710 : APFloat::getInf(semantic,
false);
// minnumf: a positive NaN.
2714 case AtomicRMWKind::minnumf: {
2715 const llvm::fltSemantics &semantic =
2716 llvm::cast<FloatType>(resultType).getFloatSemantics();
2717 APFloat identity = APFloat::getNaN(semantic,
false);
// mins: the largest signed value of the result width.
2720 case AtomicRMWKind::mins:
2722 resultType, APInt::getSignedMaxValue(
2723 llvm::cast<IntegerType>(resultType).getWidth()));
// minu: the largest unsigned value of the result width.
2724 case AtomicRMWKind::minu:
2727 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
// muli / mulf: case bodies are not visible in this listing.
2728 case AtomicRMWKind::muli:
2730 case AtomicRMWKind::mulf:
// Map the concrete arith op class to the matching AtomicRMWKind; op classes
// without an entry yield std::nullopt through the Default case.
// NOTE(review): the function header and the TypeSwitch construction line are
// not visible in this listing.
2742 std::optional<AtomicRMWKind> maybeKind =
// Floating-point operations.
2745 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2746 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2747 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2748 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2749 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2750 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
// Integer operations.
2752 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2753 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2754 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2755 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2756 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2757 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2758 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2759 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2760 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2761 .Default(std::nullopt);
// No result for this op (the guarding condition is not visible here).
2763 return std::nullopt;
// useOnlyFiniteValue becomes true only when the op implements
// ArithFastMathInterface and its fast-math flags contain `ninf`.
2766 bool useOnlyFiniteValue =
false;
2767 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2768 if (fmfOpInterface) {
2769 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2770 useOnlyFiniteValue =
2771 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
// The flag is forwarded as the last argument of a call whose beginning is not
// visible in this listing.
2779 useOnlyFiniteValue);
2785 bool useOnlyFiniteValue) {
// Materialize the attribute `attr` as an arith::ConstantOp at `loc`.
// NOTE(review): the line that computes `attr` is not visible in this listing.
2788 return arith::ConstantOp::create(builder, loc, attr);
// One case per AtomicRMWKind: each creates the binary arith op of the same
// name on (lhs, rhs) at `loc` and returns it.
// NOTE(review): the function header, the switch line and any default handling
// are not visible in this listing.
2796 case AtomicRMWKind::addf:
2797 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
2798 case AtomicRMWKind::addi:
2799 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
2800 case AtomicRMWKind::mulf:
2801 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
2802 case AtomicRMWKind::muli:
2803 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
2804 case AtomicRMWKind::maximumf:
2805 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
2806 case AtomicRMWKind::minimumf:
2807 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
2808 case AtomicRMWKind::maxnumf:
2809 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
2810 case AtomicRMWKind::minnumf:
2811 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
2812 case AtomicRMWKind::maxs:
2813 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
2814 case AtomicRMWKind::mins:
2815 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
2816 case AtomicRMWKind::maxu:
2817 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
2818 case AtomicRMWKind::minu:
2819 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
2820 case AtomicRMWKind::ori:
2821 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
2822 case AtomicRMWKind::andi:
2823 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
2824 case AtomicRMWKind::xori:
2825 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
2838#define GET_OP_CLASSES
2839#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2845#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.