26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
39 llvm::RoundingMode::NearestTiesToEven;
48 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.
getType(), value);
71static IntegerOverflowFlagsAttr
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
102 llvm_unreachable(
"unknown cmpi predicate kind");
111static llvm::RoundingMode
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
127 llvm_unreachable(
"Unhandled rounding mode");
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
161 if (!shapedType.hasStaticShape())
171#include "ArithCanonicalization.inc"
180 auto i1Type = IntegerType::get(type.
getContext(), 1);
181 if (
auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
192void arith::ConstantOp::getAsmResultNames(
195 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName <<
'c' << intCst.getValue();
207 specialName <<
'_' << type;
208 setNameFn(getResult(), specialName.str());
210 setNameFn(getResult(),
"cst");
216LogicalResult arith::ConstantOp::verify() {
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError(
"integer return type must be signless");
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
225 "value must be an integer, float, or elements attribute");
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
245 if (!intType.isSignless())
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
264 arith::ConstantOp::build(builder,
result, type,
274 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
275 assert(
result &&
"builder didn't return the right type");
287 arith::ConstantOp::build(builder,
result, type,
296 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
297 assert(
result &&
"builder didn't return the right type");
308 arith::ConstantOp::build(builder,
result, type,
314 const APInt &
value) {
317 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
318 assert(
result &&
"builder didn't return the right type");
324 const APInt &
value) {
329 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
335 FloatType type,
const APFloat &
value) {
336 arith::ConstantOp::build(builder,
result, type,
343 const APFloat &
value) {
346 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
347 assert(
result &&
"builder didn't return the right type");
353 const APFloat &
value) {
358 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
374 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
375 assert(
result &&
"builder didn't return the right type");
385 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
394 "type doesn't have a zero representation");
396 assert(zeroAttr &&
"unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
410 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
415 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
420 adaptor.getOperands(),
421 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
426 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
456 results.push_back(getLhs());
457 results.push_back(falseValue);
466 adaptor.getOperands(),
467 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
475 ArrayRef({sumAttr, adaptor.getLhs()}),
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.
add<AddUIExtendedToAddI>(context);
498OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
500 if (getOperand(0) == getOperand(1)) {
501 auto shapedType = dyn_cast<ShapedType>(
getType());
503 if (!shapedType || shapedType.hasStaticShape())
510 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
512 if (getRhs() ==
add.getRhs())
515 if (getRhs() ==
add.getLhs())
520 adaptor.getOperands(),
521 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
524void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
525 MLIRContext *context) {
526 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
527 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
528 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
535OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
546 adaptor.getOperands(),
547 [](
const APInt &a,
const APInt &
b) { return a * b; });
550void arith::MulIOp::getAsmResultNames(
552 if (!isa<IndexType>(
getType()))
557 auto isVscale = [](Operation *op) {
558 return op && op->getName().getStringRef() ==
"vector.vscale";
561 IntegerAttr baseValue;
562 auto isVscaleExpr = [&](Value a, Value
b) {
564 isVscale(
b.getDefiningOp());
567 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
571 SmallString<32> specialNameBuffer;
572 llvm::raw_svector_ostream specialName(specialNameBuffer);
573 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
574 setNameFn(getResult(), specialName.str());
577void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
578 MLIRContext *context) {
579 patterns.
add<MulIMulIConstant>(context);
586std::optional<SmallVector<int64_t, 4>>
587arith::MulSIExtendedOp::getShapeForUnroll() {
588 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
589 return llvm::to_vector<4>(vt.getShape());
594arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
595 SmallVectorImpl<OpFoldResult> &results) {
598 Attribute zero = adaptor.getRhs();
599 results.push_back(zero);
600 results.push_back(zero);
606 adaptor.getOperands(),
607 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
610 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
611 return llvm::APIntOps::mulhs(a, b);
613 assert(highAttr &&
"Unexpected constant-folding failure");
615 results.push_back(lowAttr);
616 results.push_back(highAttr);
623void arith::MulSIExtendedOp::getCanonicalizationPatterns(
624 RewritePatternSet &patterns, MLIRContext *context) {
625 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
632std::optional<SmallVector<int64_t, 4>>
633arith::MulUIExtendedOp::getShapeForUnroll() {
634 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
635 return llvm::to_vector<4>(vt.getShape());
640arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
641 SmallVectorImpl<OpFoldResult> &results) {
644 Attribute zero = adaptor.getRhs();
645 results.push_back(zero);
646 results.push_back(zero);
654 results.push_back(getLhs());
655 results.push_back(zero);
661 adaptor.getOperands(),
662 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
665 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
666 return llvm::APIntOps::mulhu(a, b);
668 assert(highAttr &&
"Unexpected constant-folding failure");
670 results.push_back(lowAttr);
671 results.push_back(highAttr);
678void arith::MulUIExtendedOp::getCanonicalizationPatterns(
679 RewritePatternSet &patterns, MLIRContext *context) {
680 patterns.
add<MulUIExtendedToMulI>(context);
689 arith::IntegerOverflowFlags ovfFlags) {
690 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
691 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
703OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
709 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
715 [&](APInt a,
const APInt &
b) {
723 return div0 ? Attribute() :
result;
743OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
749 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
753 bool overflowOrDiv0 =
false;
755 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
756 if (overflowOrDiv0 || !b) {
757 overflowOrDiv0 = true;
760 return a.sdiv_ov(
b, overflowOrDiv0);
763 return overflowOrDiv0 ? Attribute() :
result;
790 APInt one(a.getBitWidth(), 1,
true);
791 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
792 return val.sadd_ov(one, overflow);
799OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
804 bool overflowOrDiv0 =
false;
806 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
807 if (overflowOrDiv0 || !b) {
808 overflowOrDiv0 = true;
811 APInt quotient = a.udiv(
b);
814 APInt one(a.getBitWidth(), 1,
true);
815 return quotient.uadd_ov(one, overflowOrDiv0);
818 return overflowOrDiv0 ? Attribute() :
result;
829OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
837 bool overflowOrDiv0 =
false;
839 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
840 if (overflowOrDiv0 || !b) {
841 overflowOrDiv0 = true;
847 unsigned bits = a.getBitWidth();
848 APInt zero = APInt::getZero(bits);
849 bool aGtZero = a.sgt(zero);
850 bool bGtZero =
b.sgt(zero);
851 if (aGtZero && bGtZero) {
858 bool overflowNegA =
false;
859 bool overflowNegB =
false;
860 bool overflowDiv =
false;
861 bool overflowNegRes =
false;
862 if (!aGtZero && !bGtZero) {
864 APInt posA = zero.ssub_ov(a, overflowNegA);
865 APInt posB = zero.ssub_ov(
b, overflowNegB);
867 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
870 if (!aGtZero && bGtZero) {
872 APInt posA = zero.ssub_ov(a, overflowNegA);
873 APInt
div = posA.sdiv_ov(
b, overflowDiv);
874 APInt res = zero.ssub_ov(
div, overflowNegRes);
875 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
879 APInt posB = zero.ssub_ov(
b, overflowNegB);
880 APInt
div = a.sdiv_ov(posB, overflowDiv);
881 APInt res = zero.ssub_ov(
div, overflowNegRes);
883 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
887 return overflowOrDiv0 ? Attribute() :
result;
898OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
904 bool overflowOrDiv =
false;
906 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
908 overflowOrDiv = true;
911 return a.sfloordiv_ov(
b, overflowOrDiv);
914 return overflowOrDiv ? Attribute() :
result;
921OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
929 [&](APInt a,
const APInt &
b) {
930 if (div0 || b.isZero()) {
937 return div0 ? Attribute() :
result;
948OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
956 [&](APInt a,
const APInt &
b) {
957 if (div0 || b.isZero()) {
964 return div0 ? Attribute() :
result;
982 for (
bool reversePrev : {
false,
true}) {
983 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
984 .getDefiningOp<arith::AndIOp>();
988 Value other = (reversePrev ? op.getLhs() : op.getRhs());
989 if (other != prev.getLhs() && other != prev.getRhs())
992 return prev.getResult();
997OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1004 intValue.isAllOnes())
1009 intValue.isAllOnes())
1014 intValue.isAllOnes())
1022 adaptor.getOperands(),
1023 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1030OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1033 if (rhsVal.isZero())
1036 if (rhsVal.isAllOnes())
1037 return adaptor.getRhs();
1044 intValue.isAllOnes())
1045 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1049 intValue.isAllOnes())
1050 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1053 adaptor.getOperands(),
1054 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
1061OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1066 if (getLhs() == getRhs())
1070 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1071 if (prev.getRhs() == getRhs())
1072 return prev.getLhs();
1073 if (prev.getLhs() == getRhs())
1074 return prev.getRhs();
1078 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1079 if (prev.getRhs() == getLhs())
1080 return prev.getLhs();
1081 if (prev.getLhs() == getLhs())
1082 return prev.getRhs();
1086 adaptor.getOperands(),
1087 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
1090void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1091 MLIRContext *context) {
1092 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1099OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1101 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1102 return op.getOperand();
1104 [](
const APFloat &a) { return -a; });
1111OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1117 if (
auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1118 return op.getResult();
1122 adaptor.getOperands(), [](
const APFloat &a) {
1124 return APFloat::getZero(a.getSemantics(), a.isNegative());
1133OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1138 auto rm = getRoundingmode();
1140 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1142 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1151OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1156 auto rm = getRoundingmode();
1158 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1160 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1169OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1171 if (getLhs() == getRhs())
1179 adaptor.getOperands(),
1180 [](
const APFloat &a,
const APFloat &
b) { return llvm::maximum(a, b); });
1187OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1189 if (getLhs() == getRhs())
1203OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1205 if (getLhs() == getRhs())
1211 if (intValue.isMaxSignedValue())
1214 if (intValue.isMinSignedValue())
1219 [](
const APInt &a,
const APInt &
b) {
1220 return llvm::APIntOps::smax(a, b);
1228OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1230 if (getLhs() == getRhs())
1236 if (intValue.isMaxValue())
1239 if (intValue.isMinValue())
1244 [](
const APInt &a,
const APInt &
b) {
1245 return llvm::APIntOps::umax(a, b);
1253OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1255 if (getLhs() == getRhs())
1263 adaptor.getOperands(),
1264 [](
const APFloat &a,
const APFloat &
b) { return llvm::minimum(a, b); });
1271OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1273 if (getLhs() == getRhs())
1281 adaptor.getOperands(),
1282 [](
const APFloat &a,
const APFloat &
b) { return llvm::minnum(a, b); });
1289OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1291 if (getLhs() == getRhs())
1297 if (intValue.isMinSignedValue())
1300 if (intValue.isMaxSignedValue())
1305 [](
const APInt &a,
const APInt &
b) {
1306 return llvm::APIntOps::smin(a, b);
1314OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1316 if (getLhs() == getRhs())
1322 if (intValue.isMinValue())
1325 if (intValue.isMaxValue())
1330 [](
const APInt &a,
const APInt &
b) {
1331 return llvm::APIntOps::umin(a, b);
1339OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1344 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1345 arith::FastMathFlags::nsz)) {
1351 auto rm = getRoundingmode();
1353 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1355 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1360void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1361 MLIRContext *context) {
1362 patterns.
add<MulFOfNegF>(context);
1369OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1374 auto rm = getRoundingmode();
1376 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1378 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1383void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1384 MLIRContext *context) {
1385 patterns.
add<DivFOfNegF>(context);
1392OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1394 [](
const APFloat &a,
const APFloat &
b) {
1399 (void)result.mod(b);
1408template <
typename... Types>
1414template <
typename... ShapedTypes,
typename... ElementTypes>
1417 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1421 if (!llvm::isa<ElementTypes...>(underlyingType))
1424 return underlyingType;
1428template <
typename... ElementTypes>
1435template <
typename... ElementTypes>
1444 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1445 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1446 if (!rankedTensorA || !rankedTensorB)
1448 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1452 if (inputs.size() != 1 || outputs.size() != 1)
1464template <
typename ValType,
typename Op>
1469 if (llvm::cast<ValType>(srcType).getWidth() >=
1470 llvm::cast<ValType>(dstType).getWidth())
1472 << dstType <<
" must be wider than operand type " << srcType;
1478template <
typename ValType,
typename Op>
1483 if (llvm::cast<ValType>(srcType).getWidth() <=
1484 llvm::cast<ValType>(dstType).getWidth())
1486 << dstType <<
" must be shorter than operand type " << srcType;
1492template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1497 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1498 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1499 if (!srcType || !dstType)
1502 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1503 srcType.getIntOrFloatBitWidth());
1508static FailureOr<APFloat>
1510 const llvm::fltSemantics &targetSemantics,
1514 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1515 if (sourceValue.isInfinity() &&
1516 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1517 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1519 if (sourceValue.isNaN() &&
1520 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1523 bool losesInfo =
false;
1524 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1525 if (losesInfo || status != APFloat::opOK)
1535OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1536 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1537 getInMutable().assign(
lhs.getIn());
1542 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1544 adaptor.getOperands(),
getType(),
1545 [bitWidth](
const APInt &a,
bool &castStatus) {
1546 return a.zext(bitWidth);
1554LogicalResult arith::ExtUIOp::verify() {
1562OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1563 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1564 getInMutable().assign(
lhs.getIn());
1569 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1571 adaptor.getOperands(),
getType(),
1572 [bitWidth](
const APInt &a,
bool &castStatus) {
1573 return a.sext(bitWidth);
1581void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1582 MLIRContext *context) {
1583 patterns.
add<ExtSIOfExtUI>(context);
1586LogicalResult arith::ExtSIOp::verify() {
1596OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1597 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1598 if (truncFOp.getOperand().getType() ==
getType()) {
1599 arith::FastMathFlags truncFMF =
1600 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1601 bool isTruncContract =
1602 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1603 arith::FastMathFlags extFMF =
1604 getFastmath().value_or(arith::FastMathFlags::none);
1605 bool isExtContract =
1606 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1607 if (isTruncContract && isExtContract) {
1608 return truncFOp.getOperand();
1614 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1616 adaptor.getOperands(),
getType(),
1617 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1637bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1642LogicalResult arith::ScalingExtFOp::verify() {
1650OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1653 Value src = getOperand().getDefiningOp()->getOperand(0);
1658 if (llvm::cast<IntegerType>(srcType).getWidth() >
1659 llvm::cast<IntegerType>(dstType).getWidth()) {
1666 if (srcType == dstType)
1672 setOperand(getOperand().getDefiningOp()->getOperand(0));
1677 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1679 adaptor.getOperands(),
getType(),
1680 [bitWidth](
const APInt &a,
bool &castStatus) {
1681 return a.trunc(bitWidth);
1689void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1690 MLIRContext *context) {
1692 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1696LogicalResult arith::TruncIOp::verify() {
1706OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1708 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1709 Value src = extOp.getIn();
1711 auto intermediateType =
1714 if (llvm::APFloatBase::isRepresentableBy(
1715 srcType.getFloatSemantics(),
1716 intermediateType.getFloatSemantics())) {
1718 if (srcType.getWidth() > resElemType.getWidth()) {
1724 if (srcType == resElemType)
1729 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1731 adaptor.getOperands(),
getType(),
1732 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1733 llvm::RoundingMode llvmRoundingMode =
1735 FailureOr<APFloat>
result =
1745void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1746 MLIRContext *context) {
1747 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1754LogicalResult arith::TruncFOp::verify() {
1762OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1764 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1766 adaptor.getOperands(),
getType(),
1767 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1768 llvm::RoundingMode llvmRoundingMode =
1770 FailureOr<APFloat>
result =
1785 if (!srcType || !dstType)
1787 return srcType != dstType &&
1791LogicalResult arith::ConvertFOp::verify() {
1794 if (srcType == dstType)
1795 return emitError(
"result element type ")
1796 << dstType <<
" must be different from operand element type "
1798 if (srcType.getWidth() != dstType.getWidth())
1799 return emitError(
"result element type ")
1800 << dstType <<
" must have the same bitwidth as operand element type "
1809bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1814LogicalResult arith::ScalingTruncFOp::verify() {
1822void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1823 MLIRContext *context) {
1824 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1831void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1832 MLIRContext *context) {
1833 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1840template <
typename From,
typename To>
1848 return srcType && dstType;
1859OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1862 adaptor.getOperands(),
getType(),
1863 [&resEleType](
const APInt &a,
bool &castStatus) {
1864 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1865 APFloat apf(floatTy.getFloatSemantics(),
1866 APInt::getZero(floatTy.getWidth()));
1867 apf.convertFromAPInt(a,
false,
1868 APFloat::rmNearestTiesToEven);
1873void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1874 MLIRContext *context) {
1875 patterns.
add<UIToFPOfExtUI>(context);
1886OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1889 adaptor.getOperands(),
getType(),
1890 [&resEleType](
const APInt &a,
bool &castStatus) {
1891 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1892 APFloat apf(floatTy.getFloatSemantics(),
1893 APInt::getZero(floatTy.getWidth()));
1894 apf.convertFromAPInt(a,
true,
1895 APFloat::rmNearestTiesToEven);
1900void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1901 MLIRContext *context) {
1902 patterns.
add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1913OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1915 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1917 adaptor.getOperands(),
getType(),
1918 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1920 APSInt api(bitWidth,
true);
1921 castStatus = APFloat::opInvalidOp !=
1922 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1935OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1937 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1939 adaptor.getOperands(),
getType(),
1940 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1942 APSInt api(bitWidth,
false);
1943 castStatus = APFloat::opInvalidOp !=
1944 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1958 return intTy.getWidth();
1959 return IndexType::kInternalStorageBitWidth;
1968 if (!srcType || !dstType)
1972 (srcType.isSignlessInteger() && dstType.
isIndex());
1975bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1980OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1982 unsigned resultBitwidth = 64;
1984 resultBitwidth = intTy.getWidth();
1987 adaptor.getOperands(),
getType(),
1988 [resultBitwidth](
const APInt &a,
bool & ) {
1989 return a.sextOrTrunc(resultBitwidth);
1996 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
1997 Value x = inner.getOperand();
2006void arith::IndexCastOp::getCanonicalizationPatterns(
2007 RewritePatternSet &patterns, MLIRContext *context) {
2008 patterns.
add<IndexCastOfExtSI>(context);
2015bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
2020OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2022 unsigned resultBitwidth = 64;
2024 resultBitwidth = intTy.getWidth();
2027 adaptor.getOperands(),
getType(),
2028 [resultBitwidth](
const APInt &a,
bool & ) {
2029 return a.zextOrTrunc(resultBitwidth);
2036 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2037 Value x = inner.getOperand();
2046void arith::IndexCastUIOp::getCanonicalizationPatterns(
2047 RewritePatternSet &patterns, MLIRContext *context) {
2048 patterns.
add<IndexCastUIOfExtUI>(context);
2061 if (!srcType || !dstType)
2067OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2069 auto operand = adaptor.getIn();
2074 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2075 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
2077 if (llvm::isa<ShapedType>(resType))
2085 APInt bits = llvm::isa<FloatAttr>(operand)
2086 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2087 : llvm::cast<IntegerAttr>(operand).getValue();
2089 "trying to fold on broken IR: operands have incompatible types");
2091 if (
auto resFloatType = dyn_cast<FloatType>(resType))
2092 return FloatAttr::get(resType,
2093 APFloat(resFloatType.getFloatSemantics(), bits));
2094 return IntegerAttr::get(resType, bits);
2097void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2098 MLIRContext *context) {
2099 patterns.
add<BitcastOfBitcast>(context);
2109 const APInt &
lhs,
const APInt &
rhs) {
2110 switch (predicate) {
2111 case arith::CmpIPredicate::eq:
2113 case arith::CmpIPredicate::ne:
2115 case arith::CmpIPredicate::slt:
2117 case arith::CmpIPredicate::sle:
2119 case arith::CmpIPredicate::sgt:
2121 case arith::CmpIPredicate::sge:
2123 case arith::CmpIPredicate::ult:
2125 case arith::CmpIPredicate::ule:
2127 case arith::CmpIPredicate::ugt:
2129 case arith::CmpIPredicate::uge:
2132 llvm_unreachable(
"unknown cmpi predicate kind");
2137 switch (predicate) {
2138 case arith::CmpIPredicate::eq:
2139 case arith::CmpIPredicate::sle:
2140 case arith::CmpIPredicate::sge:
2141 case arith::CmpIPredicate::ule:
2142 case arith::CmpIPredicate::uge:
2144 case arith::CmpIPredicate::ne:
2145 case arith::CmpIPredicate::slt:
2146 case arith::CmpIPredicate::sgt:
2147 case arith::CmpIPredicate::ult:
2148 case arith::CmpIPredicate::ugt:
2151 llvm_unreachable(
"unknown cmpi predicate kind");
2155 if (
auto intType = dyn_cast<IntegerType>(t)) {
2156 return intType.getWidth();
2158 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2159 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2161 return std::nullopt;
2164OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2166 if (getLhs() == getRhs()) {
2172 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2174 std::optional<int64_t> integerWidth =
2176 if (integerWidth && integerWidth.value() == 1 &&
2177 getPredicate() == arith::CmpIPredicate::ne)
2178 return extOp.getOperand();
2180 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2182 std::optional<int64_t> integerWidth =
2184 if (integerWidth && integerWidth.value() == 1 &&
2185 getPredicate() == arith::CmpIPredicate::ne)
2186 return extOp.getOperand();
2191 getPredicate() == arith::CmpIPredicate::ne)
2198 getPredicate() == arith::CmpIPredicate::eq)
2203 if (adaptor.getLhs() && !adaptor.getRhs()) {
2205 using Pred = CmpIPredicate;
2206 const std::pair<Pred, Pred> invPreds[] = {
2207 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2208 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2209 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2210 {Pred::ne, Pred::ne},
2212 Pred origPred = getPredicate();
2213 for (
auto pred : invPreds) {
2214 if (origPred == pred.first) {
2215 setPredicate(pred.second);
2216 Value
lhs = getLhs();
2217 Value
rhs = getRhs();
2218 getLhsMutable().assign(
rhs);
2219 getRhsMutable().assign(
lhs);
2223 llvm_unreachable(
"unknown cmpi predicate kind");
2228 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2231 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
2240void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2241 MLIRContext *context) {
2242 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2252 const APFloat &
lhs,
const APFloat &
rhs) {
2253 auto cmpResult =
lhs.compare(
rhs);
2254 switch (predicate) {
2255 case arith::CmpFPredicate::AlwaysFalse:
2257 case arith::CmpFPredicate::OEQ:
2258 return cmpResult == APFloat::cmpEqual;
2259 case arith::CmpFPredicate::OGT:
2260 return cmpResult == APFloat::cmpGreaterThan;
2261 case arith::CmpFPredicate::OGE:
2262 return cmpResult == APFloat::cmpGreaterThan ||
2263 cmpResult == APFloat::cmpEqual;
2264 case arith::CmpFPredicate::OLT:
2265 return cmpResult == APFloat::cmpLessThan;
2266 case arith::CmpFPredicate::OLE:
2267 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2268 case arith::CmpFPredicate::ONE:
2269 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2270 case arith::CmpFPredicate::ORD:
2271 return cmpResult != APFloat::cmpUnordered;
2272 case arith::CmpFPredicate::UEQ:
2273 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2274 case arith::CmpFPredicate::UGT:
2275 return cmpResult == APFloat::cmpUnordered ||
2276 cmpResult == APFloat::cmpGreaterThan;
2277 case arith::CmpFPredicate::UGE:
2278 return cmpResult == APFloat::cmpUnordered ||
2279 cmpResult == APFloat::cmpGreaterThan ||
2280 cmpResult == APFloat::cmpEqual;
2281 case arith::CmpFPredicate::ULT:
2282 return cmpResult == APFloat::cmpUnordered ||
2283 cmpResult == APFloat::cmpLessThan;
2284 case arith::CmpFPredicate::ULE:
2285 return cmpResult == APFloat::cmpUnordered ||
2286 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2287 case arith::CmpFPredicate::UNE:
2288 return cmpResult != APFloat::cmpEqual;
2289 case arith::CmpFPredicate::UNO:
2290 return cmpResult == APFloat::cmpUnordered;
2291 case arith::CmpFPredicate::AlwaysTrue:
2294 llvm_unreachable(
"unknown cmpf predicate kind");
2298 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2299 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2302 if (
lhs &&
lhs.getValue().isNaN())
2304 if (
rhs &&
rhs.getValue().isNaN())
2320 using namespace arith;
2322 case CmpFPredicate::UEQ:
2323 case CmpFPredicate::OEQ:
2324 return CmpIPredicate::eq;
2325 case CmpFPredicate::UGT:
2326 case CmpFPredicate::OGT:
2327 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2328 case CmpFPredicate::UGE:
2329 case CmpFPredicate::OGE:
2330 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2331 case CmpFPredicate::ULT:
2332 case CmpFPredicate::OLT:
2333 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2334 case CmpFPredicate::ULE:
2335 case CmpFPredicate::OLE:
2336 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2337 case CmpFPredicate::UNE:
2338 case CmpFPredicate::ONE:
2339 return CmpIPredicate::ne;
2341 llvm_unreachable(
"Unexpected predicate!");
2351 const APFloat &
rhs = flt.getValue();
2359 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2360 int mantissaWidth = floatTy.getFPMantissaWidth();
2361 if (mantissaWidth <= 0)
2367 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2369 intVal = si.getIn();
2370 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2372 intVal = ui.getIn();
2379 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2380 auto intWidth = intTy.getWidth();
2383 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2388 if ((
int)intWidth > mantissaWidth) {
2390 int exponent = ilogb(
rhs);
2391 if (exponent == APFloat::IEK_Inf) {
2392 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2393 if (maxExponent < (
int)valueBits) {
2400 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2409 switch (op.getPredicate()) {
2410 case CmpFPredicate::ORD:
2415 case CmpFPredicate::UNO:
2428 APFloat signedMax(
rhs.getSemantics());
2429 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2430 APFloat::rmNearestTiesToEven);
2431 if (signedMax <
rhs) {
2432 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2433 pred == CmpIPredicate::sle)
2444 APFloat unsignedMax(
rhs.getSemantics());
2445 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2446 APFloat::rmNearestTiesToEven);
2447 if (unsignedMax <
rhs) {
2448 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2449 pred == CmpIPredicate::ule)
2461 APFloat signedMin(
rhs.getSemantics());
2462 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2463 APFloat::rmNearestTiesToEven);
2464 if (signedMin >
rhs) {
2465 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2466 pred == CmpIPredicate::sge)
2476 APFloat unsignedMin(
rhs.getSemantics());
2477 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2478 APFloat::rmNearestTiesToEven);
2479 if (unsignedMin >
rhs) {
2480 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2481 pred == CmpIPredicate::uge)
2496 APSInt rhsInt(intWidth, isUnsigned);
2497 if (APFloat::opInvalidOp ==
2498 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2504 if (!
rhs.isZero()) {
2505 APFloat apf(floatTy.getFloatSemantics(),
2506 APInt::getZero(floatTy.getWidth()));
2507 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2509 bool equal = apf ==
rhs;
2515 case CmpIPredicate::ne:
2519 case CmpIPredicate::eq:
2523 case CmpIPredicate::ule:
2526 if (
rhs.isNegative()) {
2532 case CmpIPredicate::sle:
2535 if (
rhs.isNegative())
2536 pred = CmpIPredicate::slt;
2538 case CmpIPredicate::ult:
2541 if (
rhs.isNegative()) {
2546 pred = CmpIPredicate::ule;
2548 case CmpIPredicate::slt:
2551 if (!
rhs.isNegative())
2552 pred = CmpIPredicate::sle;
2554 case CmpIPredicate::ugt:
2557 if (
rhs.isNegative()) {
2563 case CmpIPredicate::sgt:
2566 if (
rhs.isNegative())
2567 pred = CmpIPredicate::sge;
2569 case CmpIPredicate::uge:
2572 if (
rhs.isNegative()) {
2577 pred = CmpIPredicate::ugt;
2579 case CmpIPredicate::sge:
2582 if (!
rhs.isNegative())
2583 pred = CmpIPredicate::sgt;
2593 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2599void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2600 MLIRContext *context) {
2601 patterns.
insert<CmpFIntToFPConst>(context);
2615 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2631 arith::XOrIOp::create(
2632 rewriter, op.getLoc(), op.getCondition(),
2634 op.getCondition().
getType(), 1)));
2642void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2643 MLIRContext *context) {
2644 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2645 SelectI1ToNot, SelectToExtUI>(context);
2648OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2649 Value trueVal = getTrueValue();
2650 Value falseVal = getFalseValue();
2651 if (trueVal == falseVal)
2654 Value condition = getCondition();
2672 if (
getType().isSignlessInteger(1) &&
2678 auto pred = cmp.getPredicate();
2679 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2680 auto cmpLhs = cmp.getLhs();
2681 auto cmpRhs = cmp.getRhs();
2689 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2690 (cmpRhs == trueVal && cmpLhs == falseVal))
2691 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2698 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2700 assert(cond.getType().hasStaticShape() &&
2701 "DenseElementsAttr must have static shape");
2703 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2705 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2706 SmallVector<Attribute> results;
2707 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2708 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2709 cond.value_end<BoolAttr>());
2710 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2711 lhs.value_end<Attribute>());
2712 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2713 rhs.value_end<Attribute>());
2715 for (
auto [condVal, lhsVal, rhsVal] :
2716 llvm::zip_equal(condVals, lhsVals, rhsVals))
2717 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2727ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2728 Type conditionType, resultType;
2729 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2737 conditionType = resultType;
2744 result.addTypes(resultType);
2746 {conditionType, resultType, resultType},
2750void arith::SelectOp::print(OpAsmPrinter &p) {
2751 p <<
" " << getOperands();
2754 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2755 p << condType <<
", ";
2759LogicalResult arith::SelectOp::verify() {
2760 Type conditionType = getCondition().getType();
2767 if (!llvm::isa<TensorType, VectorType>(resultType))
2768 return emitOpError() <<
"expected condition to be a signless i1, but got "
2771 if (conditionType != shapedConditionType) {
2772 return emitOpError() <<
"expected condition type to have the same shape "
2773 "as the result type, expected "
2774 << shapedConditionType <<
", but got "
2783OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2788 bool bounded =
false;
2790 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2791 bounded = b.ult(b.getBitWidth());
2794 return bounded ?
result : Attribute();
2801OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2806 bool bounded =
false;
2808 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2809 bounded = b.ult(b.getBitWidth());
2812 return bounded ?
result : Attribute();
2819OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2824 bool bounded =
false;
2826 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2827 bounded = b.ult(b.getBitWidth());
2830 return bounded ?
result : Attribute();
2840 bool useOnlyFiniteValue) {
2842 case AtomicRMWKind::maximumf: {
2843 const llvm::fltSemantics &semantic =
2844 llvm::cast<FloatType>(resultType).getFloatSemantics();
2845 APFloat identity = useOnlyFiniteValue
2846 ? APFloat::getLargest(semantic,
true)
2847 : APFloat::getInf(semantic,
true);
2850 case AtomicRMWKind::maxnumf: {
2851 const llvm::fltSemantics &semantic =
2852 llvm::cast<FloatType>(resultType).getFloatSemantics();
2853 APFloat identity = APFloat::getNaN(semantic,
true);
2856 case AtomicRMWKind::addf:
2857 case AtomicRMWKind::addi:
2858 case AtomicRMWKind::maxu:
2859 case AtomicRMWKind::ori:
2860 case AtomicRMWKind::xori:
2862 case AtomicRMWKind::andi:
2865 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2866 case AtomicRMWKind::maxs:
2868 resultType, APInt::getSignedMinValue(
2869 llvm::cast<IntegerType>(resultType).getWidth()));
2870 case AtomicRMWKind::minimumf: {
2871 const llvm::fltSemantics &semantic =
2872 llvm::cast<FloatType>(resultType).getFloatSemantics();
2873 APFloat identity = useOnlyFiniteValue
2874 ? APFloat::getLargest(semantic,
false)
2875 : APFloat::getInf(semantic,
false);
2879 case AtomicRMWKind::minnumf: {
2880 const llvm::fltSemantics &semantic =
2881 llvm::cast<FloatType>(resultType).getFloatSemantics();
2882 APFloat identity = APFloat::getNaN(semantic,
false);
2885 case AtomicRMWKind::mins:
2887 resultType, APInt::getSignedMaxValue(
2888 llvm::cast<IntegerType>(resultType).getWidth()));
2889 case AtomicRMWKind::minu:
2892 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2893 case AtomicRMWKind::muli:
2895 case AtomicRMWKind::mulf:
2907 std::optional<AtomicRMWKind> maybeKind =
2910 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2911 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2912 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2913 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2914 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2915 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2917 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2918 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2919 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2920 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2921 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2922 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2923 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2924 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2925 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2926 .Default(std::nullopt);
2928 return std::nullopt;
2931 bool useOnlyFiniteValue =
false;
2932 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2933 if (fmfOpInterface) {
2934 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2935 useOnlyFiniteValue =
2936 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2944 useOnlyFiniteValue);
2950 bool useOnlyFiniteValue) {
2953 return arith::ConstantOp::create(builder, loc, attr);
2962 case AtomicRMWKind::addf:
2963 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
2964 case AtomicRMWKind::addi:
2965 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
2966 case AtomicRMWKind::mulf:
2967 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
2968 case AtomicRMWKind::muli:
2969 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
2970 case AtomicRMWKind::maximumf:
2971 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
2972 case AtomicRMWKind::minimumf:
2973 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
2974 case AtomicRMWKind::maxnumf:
2975 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
2976 case AtomicRMWKind::minnumf:
2977 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
2978 case AtomicRMWKind::maxs:
2979 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
2980 case AtomicRMWKind::mins:
2981 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
2982 case AtomicRMWKind::maxu:
2983 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
2984 case AtomicRMWKind::minu:
2985 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
2986 case AtomicRMWKind::ori:
2987 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
2988 case AtomicRMWKind::andi:
2989 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
2990 case AtomicRMWKind::xori:
2991 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
3004#define GET_OP_CLASSES
3005#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3011#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.