#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res,
                                       Attribute lhs, Attribute rhs,
                                       function_ref<APInt(const APInt &,
                                                          const APInt &)> binFn) {
  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
  APInt value = binFn(lhsVal, rhsVal);
  return IntegerAttr::get(res.getType(), value);
}
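// The addIntegerAttrs/subIntegerAttrs/mulIntegerAttrs helpers declared
// elsewhere in this file are thin wrappers over applyToIntegerAttrs. The
// following is a minimal sketch based on their declared signatures
// (PatternRewriter &, Value, Attribute, Attribute); the upstream bodies may
// differ in detail.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
}

static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
}

static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}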
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {
  return IntegerOverflowFlagsAttr::get(val1.getContext(),
                                       val1.getValue() & val2.getValue());
}
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
static llvm::RoundingMode
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
  switch (roundingMode) {
  case RoundingMode::downward:
    return llvm::RoundingMode::TowardNegative;
  case RoundingMode::to_nearest_away:
    return llvm::RoundingMode::NearestTiesToAway;
  case RoundingMode::to_nearest_even:
    return llvm::RoundingMode::NearestTiesToEven;
  case RoundingMode::toward_zero:
    return llvm::RoundingMode::TowardZero;
  case RoundingMode::upward:
    return llvm::RoundingMode::TowardPositive;
  }
  llvm_unreachable("Unhandled rounding mode");
}
  return arith::CmpIPredicateAttr::get(pred.getContext(),
                                       invertPredicate(pred.getValue()));
  ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
  if (!shapedType.hasStaticShape())
#include "ArithCanonicalization.inc"
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto shapedType = dyn_cast<ShapedType>(type))
    return shapedType.cloneWith(std::nullopt, i1Type);
  if (llvm::isa<UnrankedTensorType>(type))
    return UnrankedTensorType::get(i1Type);
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
    auto intType = dyn_cast<IntegerType>(type);

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a name from the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getValue();
    specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    setNameFn(getResult(), "cst");
  }
}
LogicalResult arith::ConstantOp::verify() {
  auto type = getType();
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return emitOpError("integer return type must be signless");

  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
    return emitOpError(
        "value must be an integer, float, or elements attribute");
  }

  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not supported"
        " unless it's a vector splat");
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must match the provided type.
  auto typedAttr = dyn_cast<TypedAttr>(value);
  if (!typedAttr || typedAttr.getType() != type)
    return false;
  // Integer types must be signless.
  if (llvm::isa<IntegerType>(type) &&
      !llvm::cast<IntegerType>(type).isSignless())
    return false;
  // Integer, float, and elements attributes are buildable.
  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  if (isBuildableWith(value, type))
    return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
  arith::ConstantOp::build(builder, result, type,

  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

  arith::ConstantOp::build(builder, result, type,

  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

  arith::ConstantOp::build(builder, result, type,

                                     const APInt &value) {
  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

                                     const APInt &value) {

  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isSignlessInteger();
                                   FloatType type, const APFloat &value) {
  arith::ConstantOp::build(builder, result, type,

                                     const APFloat &value) {
  auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

                                     const APFloat &value) {

  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return llvm::isa<FloatType>(constOp.getType());

  auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isIndex();
387 "type doesn't have a zero representation");
389 assert(zeroAttr &&
"unsupported type for zero attribute");
390 return arith::ConstantOp::create(builder, loc, zeroAttr);
  if (auto sub = getLhs().getDefiningOp<SubIOp>())
    if (getRhs() == sub.getRhs())

  if (auto sub = getRhs().getDefiningOp<SubIOp>())
    if (getLhs() == sub.getRhs())

      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });

  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());

  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
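// calculateUnsignedOverflow: with wrap-around arithmetic, `sum = a + b` has
// overflowed exactly when `sum` is (unsigned) smaller than either addend, so
// `sum.ult(operand)` detects the carry. The result is a 1-bit APInt (all-ones
// on overflow, zero otherwise) to match the i1 overflow result of
// addui_extended.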
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
    results.push_back(getLhs());
    results.push_back(falseValue);

          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
          ArrayRef({sumAttr, adaptor.getLhs()}),
    results.push_back(sumAttr);
    results.push_back(overflowAttr);

void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    if (!shapedType || shapedType.hasStaticShape())

  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    if (getRhs() == add.getRhs())
    if (getRhs() == add.getLhs())

      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });

void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });

void arith::MulIOp::getAsmResultNames(
  if (!isa<IndexType>(getType()))

  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";

  IntegerAttr baseValue;
  auto isVscaleExpr = [&](Value a, Value b) {
           isVscale(b.getDefiningOp());

  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());

void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());

arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);

          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
    assert(highAttr && "Unexpected constant-folding failure");
    results.push_back(lowAttr);
    results.push_back(highAttr);

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  if (auto vt = dyn_cast<VectorType>(getType(0)))
    return llvm::to_vector<4>(vt.getShape());

arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);

    results.push_back(getLhs());
    results.push_back(zero);

          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
    assert(highAttr && "Unexpected constant-folding failure");
    results.push_back(lowAttr);
    results.push_back(highAttr);

void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
                        arith::IntegerOverflowFlags ovfFlags) {
  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))

OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))

      [&](APInt a, const APInt &b) {

  return div0 ? Attribute() : result;
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))

  bool overflowOrDiv0 = false;
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
        return a.sdiv_ov(b, overflowOrDiv0);

  return overflowOrDiv0 ? Attribute() : result;

  APInt one(a.getBitWidth(), 1, true);
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
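// signedCeilNonnegInputs relies on the identity ceil(a / b) == (a - 1) / b + 1
// for strictly positive a and b (signed division truncates toward zero); the
// *_ov variants accumulate any overflow into `overflow` so the caller can
// refuse to fold.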
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  bool overflowOrDiv0 = false;
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
        APInt quotient = a.udiv(b);
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);

  return overflowOrDiv0 ? Attribute() : result;
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  bool overflowOrDiv0 = false;
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
        if (!aGtZero && bGtZero) {
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);
        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);

  return overflowOrDiv0 ? Attribute() : result;
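// The signed ceildiv fold above handles the four sign combinations separately:
// both operands positive use signedCeilNonnegInputs; mixed signs negate one
// operand, perform a truncating division on nonnegative values, and negate the
// result (ceil(a/b) == -((-a)/b) when the quotient is non-positive); both
// negative reduces to the positive case. Any overflow in the negations or the
// division (e.g. negating the minimum signed value) sets overflowOrDiv0 and
// suppresses the fold.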
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  bool overflowOrDiv = false;
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
          overflowOrDiv = true;
        return a.sfloordiv_ov(b, overflowOrDiv);

  return overflowOrDiv ? Attribute() : result;
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
      [&](APInt a, const APInt &b) {
        if (div0 || b.isZero()) {

  return div0 ? Attribute() : result;

OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
      [&](APInt a, const APInt &b) {
        if (div0 || b.isZero()) {

  return div0 ? Attribute() : result;
  for (bool reversePrev : {false, true}) {
    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
                    .getDefiningOp<arith::AndIOp>();

    Value other = (reversePrev ? op.getLhs() : op.getRhs());
    if (other != prev.getLhs() && other != prev.getRhs())

    return prev.getResult();
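// foldAndIofAndI: and(a, and(a, b)) (with the nested op on either side and in
// either operand order) is redundant, so the fold simply reuses the inner
// and(a, b) result.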
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
      intValue.isAllOnes())
      intValue.isAllOnes())
      intValue.isAllOnes())

      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
    if (rhsVal.isZero())
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();

        intValue.isAllOnes())
      return getRhs().getDefiningOp<XOrIOp>().getRhs();
        intValue.isAllOnes())
      return getLhs().getDefiningOp<XOrIOp>().getRhs();

      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();

  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();

      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
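// xor is an involution: x ^ x == 0 handles the equal-operands case above, and
// (a ^ b) ^ b == a is what the two `prev` blocks implement for either operand
// order.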
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);

OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
    return op.getOperand();

      [](const APFloat &a) { return -a; });
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });

OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });

OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });

OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())
OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

  if (intValue.isMaxSignedValue())
  if (intValue.isMinSignedValue())

      [](const APInt &a, const APInt &b) {
        return llvm::APIntOps::smax(a, b);

OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

  if (intValue.isMaxValue())
  if (intValue.isMinValue())

      [](const APInt &a, const APInt &b) {
        return llvm::APIntOps::umax(a, b);
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });

OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });

OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

  if (intValue.isMinSignedValue())
  if (intValue.isMaxSignedValue())

      [](const APInt &a, const APInt &b) {
        return llvm::APIntOps::smin(a, b);

OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs())

  if (intValue.isMinValue())
  if (intValue.isMaxValue())

      [](const APInt &a, const APInt &b) {
        return llvm::APIntOps::umin(a, b);
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
                                                   arith::FastMathFlags::nsz)) {

      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });

void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {

OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });

void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {

OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
      [](const APFloat &a, const APFloat &b) {
        (void)result.mod(b);
template <typename... Types>

template <typename... ShapedTypes, typename... ElementTypes>
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
  if (!llvm::isa<ElementTypes...>(underlyingType))
  return underlyingType;

template <typename... ElementTypes>

template <typename... ElementTypes>

  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
  if (!rankedTensorA || !rankedTensorB)
  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();

  if (inputs.size() != 1 || outputs.size() != 1)

template <typename ValType, typename Op>
  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
           << dstType << " must be wider than operand type " << srcType;
template <typename ValType, typename Op>
  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
           << dstType << " must be shorter than operand type " << srcType;

template <template <typename> class WidthComparator, typename... ElementTypes>
  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());

static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
  if (losesInfo || status != APFloat::opOK)
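// convertFloatValue rejects conversions that are inexact (losesInfo) or that
// report a non-OK status, so the extf/truncf folders below only materialize a
// constant when the source value is exactly representable in the target
// semantics.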
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());

  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
LogicalResult arith::ExtUIOp::verify() {

OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());

  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);

void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);

LogicalResult arith::ExtSIOp::verify() {
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();

  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {

bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,

LogicalResult arith::ScalingExtFOp::verify() {
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  Value src = getOperand().getDefiningOp()->getOperand(0);
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
    if (srcType == dstType)
    setOperand(getOperand().getDefiningOp()->getOperand(0));

  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(

LogicalResult arith::TruncIOp::verify() {
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto intermediateType =
    if (llvm::APFloatBase::isRepresentableBy(
            srcType.getFloatSemantics(),
            intermediateType.getFloatSemantics())) {
      if (srcType.getWidth() > resElemType.getWidth()) {
      if (srcType == resElemType)

  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
        FailureOr<APFloat> result =

void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);

LogicalResult arith::TruncFOp::verify() {

bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,

LogicalResult arith::ScalingTruncFOp::verify() {
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);

void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);

template <typename From, typename To>
  return srcType && dstType;
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, false,
                             APFloat::rmNearestTiesToEven);

OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, true,
                             APFloat::rmNearestTiesToEven);
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        APSInt api(bitWidth, true);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);

OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        APSInt api(bitWidth, false);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);

  if (!srcType || !dstType)
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,

OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  unsigned resultBitwidth = 64;
    resultBitwidth = intTy.getWidth();

      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool &) {
        return a.sextOrTrunc(resultBitwidth);
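// index_cast folding assumes a 64-bit representation for `index`:
// resultBitwidth stays at 64 unless the result is a fixed-width integer type,
// and the value is then sign-extended or truncated to that width (zero-extended
// in the index_castui case below).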
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);

bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,

OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  unsigned resultBitwidth = 64;
    resultBitwidth = intTy.getWidth();

      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool &) {
        return a.zextOrTrunc(resultBitwidth);

void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
  if (!srcType || !dstType)

OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto operand = adaptor.getIn();

  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());

  if (llvm::isa<ShapedType>(resType))

  if (llvm::isa<ub::PoisonAttr>(operand))

  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);

void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto intType = dyn_cast<IntegerType>(t)) {
    return intType.getWidth();
  }
  if (auto vectorIntType = dyn_cast<VectorType>(t)) {
    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
  }
  return std::nullopt;
}
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  if (getLhs() == getRhs()) {

  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
    std::optional<int64_t> integerWidth =
    if (integerWidth && integerWidth.value() == 1 &&
        getPredicate() == arith::CmpIPredicate::ne)
      return extOp.getOperand();
  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
    std::optional<int64_t> integerWidth =
    if (integerWidth && integerWidth.value() == 1 &&
        getPredicate() == arith::CmpIPredicate::ne)
      return extOp.getOperand();

      getPredicate() == arith::CmpIPredicate::ne)
      getPredicate() == arith::CmpIPredicate::eq)

  if (adaptor.getLhs() && !adaptor.getRhs()) {
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
    llvm_unreachable("unknown cmpi predicate kind");
  if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
                             const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
  auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  if (lhs && lhs.getValue().isNaN())
  if (rhs && rhs.getValue().isNaN())
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                               bool isUnsigned) {
  using namespace arith;
  switch (pred) {
  case CmpFPredicate::UEQ:
  case CmpFPredicate::OEQ:
    return CmpIPredicate::eq;
  case CmpFPredicate::UGT:
  case CmpFPredicate::OGT:
    return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
  case CmpFPredicate::UGE:
  case CmpFPredicate::OGE:
    return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
  case CmpFPredicate::ULT:
  case CmpFPredicate::OLT:
    return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
  case CmpFPredicate::ULE:
  case CmpFPredicate::OLE:
    return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
  case CmpFPredicate::UNE:
  case CmpFPredicate::ONE:
    return CmpIPredicate::ne;
  default:
    llvm_unreachable("Unexpected predicate!");
  }
}
    const APFloat &rhs = flt.getValue();
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)

    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      intVal = ui.getIn();

    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    if ((int)intWidth > mantissaWidth) {
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {

      if (mantissaWidth <= exponent && exponent <= (int)valueBits) {

    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
    case CmpFPredicate::UNO:

    APFloat signedMax(rhs.getSemantics());
    signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                               APFloat::rmNearestTiesToEven);
    if (signedMax < rhs) {
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
          pred == CmpIPredicate::sle)

    APFloat unsignedMax(rhs.getSemantics());
    unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                 APFloat::rmNearestTiesToEven);
    if (unsignedMax < rhs) {
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
          pred == CmpIPredicate::ule)

    APFloat signedMin(rhs.getSemantics());
    signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                               APFloat::rmNearestTiesToEven);
    if (signedMin > rhs) {
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
          pred == CmpIPredicate::sge)

    APFloat unsignedMin(rhs.getSemantics());
    unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                 APFloat::rmNearestTiesToEven);
    if (unsignedMin > rhs) {
      if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
          pred == CmpIPredicate::uge)
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;

    case CmpIPredicate::ne:
    case CmpIPredicate::eq:
    case CmpIPredicate::ule:
      if (rhs.isNegative()) {
    case CmpIPredicate::sle:
      if (rhs.isNegative())
        pred = CmpIPredicate::slt;
    case CmpIPredicate::ult:
      if (rhs.isNegative()) {
      pred = CmpIPredicate::ule;
    case CmpIPredicate::slt:
      if (!rhs.isNegative())
        pred = CmpIPredicate::sle;
    case CmpIPredicate::ugt:
      if (rhs.isNegative()) {
    case CmpIPredicate::sgt:
      if (rhs.isNegative())
        pred = CmpIPredicate::sge;
    case CmpIPredicate::uge:
      if (rhs.isNegative()) {
      pred = CmpIPredicate::ugt;
    case CmpIPredicate::sge:
      if (!rhs.isNegative())
        pred = CmpIPredicate::sgt;

        ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),

void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))

        arith::XOrIOp::create(
            rewriter, op.getLoc(), op.getCondition(),
                op.getCondition().getType(), 1)));

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)

  Value condition = getCondition();

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))

  if (getType().isSignlessInteger(1) &&

    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;

          dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    assert(cond.getType().hasStaticShape() &&
           "DenseElementsAttr must have static shape");
            dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
              dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;

    conditionType = resultType;

  result.addTypes(resultType);
                                {conditionType, resultType, resultType},

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();

    if (!llvm::isa<TensorType, VectorType>(resultType))
      return emitOpError() << "expected condition to be a signless i1, but got "

    if (conditionType != shapedConditionType) {
      return emitOpError() << "expected condition type to have the same shape "
                              "as the result type, expected "
                           << shapedConditionType << ", but got "
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  bool bounded = false;
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());

  return bounded ? result : Attribute();

OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  bool bounded = false;
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());

  return bounded ? result : Attribute();

OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  bool bounded = false;
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        bounded = b.ult(b.getBitWidth());

  return bounded ? result : Attribute();
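// All three shift folds only produce a constant when the shift amount is
// provably smaller than the bit-width (`b.ult(b.getBitWidth())`); shifting by
// the bit-width or more has no well-defined result, so `bounded` guards the
// fold.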
                                     bool useOnlyFiniteValue) {
  case AtomicRMWKind::maximumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, true)
                           : APFloat::getInf(semantic, true);
  case AtomicRMWKind::maxnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, true);
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
  case AtomicRMWKind::xori:
  case AtomicRMWKind::andi:
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, false)
                           : APFloat::getInf(semantic, false);
  case AtomicRMWKind::minnumf: {
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, false);
  case AtomicRMWKind::mins:
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
  case AtomicRMWKind::mulf:
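// Identity values per reduction kind: addf/addi, ori, xori and maxu use 0;
// andi uses all-ones; maxs uses the minimum signed value and mins the maximum;
// minu uses the maximum unsigned value; maximumf/minimumf use -inf/+inf (or
// the largest finite magnitude when only finite values are allowed);
// maxnumf/minnumf use NaN; the elided muli/mulf cases return the
// multiplicative identity, 1.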
  std::optional<AtomicRMWKind> maybeKind =
          .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
          .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
          .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
          .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
          .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
          .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
          .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
          .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
          .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
          .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
          .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
          .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
          .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
          .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
          .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
          .Default(std::nullopt);

    return std::nullopt;
  bool useOnlyFiniteValue = false;
  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
  if (fmfOpInterface) {
    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
    useOnlyFiniteValue =
        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);

                               useOnlyFiniteValue);

                               bool useOnlyFiniteValue) {
  return arith::ConstantOp::create(builder, loc, attr);
  case AtomicRMWKind::addf:
    return arith::AddFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return arith::AddIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return arith::MulFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return arith::MulIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maximumf:
    return arith::MaximumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minimumf:
    return arith::MinimumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxnumf:
    return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minnumf:
    return arith::MinNumFOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return arith::MaxSIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return arith::MinSIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return arith::MaxUIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return arith::MinUIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return arith::OrIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return arith::AndIOp::create(builder, loc, lhs, rhs);
  case AtomicRMWKind::xori:
    return arith::XOrIOp::create(builder, loc, lhs, rhs);
#define GET_OP_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")