26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 const APInt &lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
46 const APInt &rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.
getType(), value);
// Combines two overflow-flag attributes by taking the bitwise AND of their
// values, so the result keeps only the flags that are set on both `val1` and
// `val2`. The new attribute is created in `val1`'s MLIRContext.
67static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
// Negation table: each CmpIPredicate maps to the predicate that holds exactly
// when the original one does not (eq <-> ne, slt <-> sge, sle <-> sgt,
// ult <-> uge, ule <-> ugt).
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
// Every predicate listed above returns; reaching this point means the value
// was not one of the enumerated cases.
98 llvm_unreachable(
"unknown cmpi predicate kind");
// Translates an arith dialect RoundingMode enumerator into the equivalent
// llvm::RoundingMode: downward -> TowardNegative, upward -> TowardPositive,
// toward_zero -> TowardZero, to_nearest_even -> NearestTiesToEven,
// to_nearest_away -> NearestTiesToAway.
107static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
// Every case listed above returns; any other value is a programming error.
121 llvm_unreachable(
"Unhandled rounding mode");
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
155 if (!shapedType.hasStaticShape())
165#include "ArithCanonicalization.inc"
174 auto i1Type = IntegerType::get(type.
getContext(), 1);
175 if (
auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
186void arith::ConstantOp::getAsmResultNames(
189 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName <<
'c' << intCst.getValue();
201 specialName <<
'_' << type;
202 setNameFn(getResult(), specialName.str());
204 setNameFn(getResult(),
"cst");
210LogicalResult arith::ConstantOp::verify() {
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError(
"integer return type must be signless");
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
219 "value must be an integer, float, or elements attribute");
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
239 if (!intType.isSignless())
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
// Folding a constant op yields its own `value` attribute unchanged; the
// adaptor's operands are not consulted.
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
258 arith::ConstantOp::build(builder,
result, type,
268 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
269 assert(
result &&
"builder didn't return the right type");
281 arith::ConstantOp::build(builder,
result, type,
290 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
291 assert(
result &&
"builder didn't return the right type");
302 arith::ConstantOp::build(builder,
result, type,
308 const APInt &
value) {
311 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
312 assert(
result &&
"builder didn't return the right type");
318 const APInt &
value) {
323 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
329 FloatType type,
const APFloat &
value) {
330 arith::ConstantOp::build(builder,
result, type,
337 const APFloat &
value) {
340 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
341 assert(
result &&
"builder didn't return the right type");
347 const APFloat &
value) {
352 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
368 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
369 assert(
result &&
"builder didn't return the right type");
379 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
388 "type doesn't have a zero representation");
390 assert(zeroAttr &&
"unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
404 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
409 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
414 adaptor.getOperands(),
415 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
420 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
450 results.push_back(getLhs());
451 results.push_back(falseValue);
460 adaptor.getOperands(),
461 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
464 results.push_back(sumAttr);
465 results.push_back(sumAttr);
469 ArrayRef({sumAttr, adaptor.getLhs()}),
475 results.push_back(sumAttr);
476 results.push_back(overflowAttr);
483void arith::AddUIExtendedOp::getCanonicalizationPatterns(
484 RewritePatternSet &patterns, MLIRContext *context) {
485 patterns.
add<AddUIExtendedToAddI>(context);
492OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
494 if (getOperand(0) == getOperand(1)) {
495 auto shapedType = dyn_cast<ShapedType>(
getType());
497 if (!shapedType || shapedType.hasStaticShape())
504 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
506 if (getRhs() ==
add.getRhs())
509 if (getRhs() ==
add.getLhs())
514 adaptor.getOperands(),
515 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
518void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
519 MLIRContext *context) {
520 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
521 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
522 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
529OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
540 adaptor.getOperands(),
541 [](
const APInt &a,
const APInt &
b) { return a * b; });
544void arith::MulIOp::getAsmResultNames(
546 if (!isa<IndexType>(
getType()))
551 auto isVscale = [](Operation *op) {
552 return op && op->getName().getStringRef() ==
"vector.vscale";
555 IntegerAttr baseValue;
556 auto isVscaleExpr = [&](Value a, Value
b) {
558 isVscale(
b.getDefiningOp());
561 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
565 SmallString<32> specialNameBuffer;
566 llvm::raw_svector_ostream specialName(specialNameBuffer);
567 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
568 setNameFn(getResult(), specialName.str());
571void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
572 MLIRContext *context) {
573 patterns.
add<MulIMulIConstant>(context);
580std::optional<SmallVector<int64_t, 4>>
581arith::MulSIExtendedOp::getShapeForUnroll() {
582 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
583 return llvm::to_vector<4>(vt.getShape());
588arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
589 SmallVectorImpl<OpFoldResult> &results) {
592 Attribute zero = adaptor.getRhs();
593 results.push_back(zero);
594 results.push_back(zero);
600 adaptor.getOperands(),
601 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
604 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
605 return llvm::APIntOps::mulhs(a, b);
607 assert(highAttr &&
"Unexpected constant-folding failure");
609 results.push_back(lowAttr);
610 results.push_back(highAttr);
617void arith::MulSIExtendedOp::getCanonicalizationPatterns(
618 RewritePatternSet &patterns, MLIRContext *context) {
619 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
626std::optional<SmallVector<int64_t, 4>>
627arith::MulUIExtendedOp::getShapeForUnroll() {
628 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
629 return llvm::to_vector<4>(vt.getShape());
634arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
635 SmallVectorImpl<OpFoldResult> &results) {
638 Attribute zero = adaptor.getRhs();
639 results.push_back(zero);
640 results.push_back(zero);
648 results.push_back(getLhs());
649 results.push_back(zero);
655 adaptor.getOperands(),
656 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
659 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
660 return llvm::APIntOps::mulhu(a, b);
662 assert(highAttr &&
"Unexpected constant-folding failure");
664 results.push_back(lowAttr);
665 results.push_back(highAttr);
672void arith::MulUIExtendedOp::getCanonicalizationPatterns(
673 RewritePatternSet &patterns, MLIRContext *context) {
674 patterns.
add<MulUIExtendedToMulI>(context);
683 arith::IntegerOverflowFlags ovfFlags) {
684 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
685 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
697OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
703 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
709 [&](APInt a,
const APInt &
b) {
717 return div0 ? Attribute() :
result;
737OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
743 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
747 bool overflowOrDiv0 =
false;
749 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
750 if (overflowOrDiv0 || !b) {
751 overflowOrDiv0 = true;
754 return a.sdiv_ov(
b, overflowOrDiv0);
757 return overflowOrDiv0 ? Attribute() :
result;
// Computes (a - 1) / b + 1 using overflow-reporting signed APInt operations.
// With truncating division this equals ceil(a / b) when a > 0 and b > 0;
// presumably the caller guarantees positive inputs (TODO confirm — the
// enclosing signature is not visible here).
// NOTE(review): LLVM's APInt *_ov methods assign the flag rather than OR into
// it, so only the last step's status survives in `overflow`. That is harmless
// under the positive-input precondition, since none of the three steps can
// overflow then.
784 APInt one(a.getBitWidth(), 1,
true);
785 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
786 return val.sadd_ov(one, overflow);
793OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
798 bool overflowOrDiv0 =
false;
800 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
801 if (overflowOrDiv0 || !b) {
802 overflowOrDiv0 = true;
805 APInt quotient = a.udiv(
b);
808 APInt one(a.getBitWidth(), 1,
true);
809 return quotient.uadd_ov(one, overflowOrDiv0);
812 return overflowOrDiv0 ? Attribute() :
result;
823OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
831 bool overflowOrDiv0 =
false;
833 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
834 if (overflowOrDiv0 || !b) {
835 overflowOrDiv0 = true;
841 unsigned bits = a.getBitWidth();
842 APInt zero = APInt::getZero(bits);
843 bool aGtZero = a.sgt(zero);
844 bool bGtZero =
b.sgt(zero);
845 if (aGtZero && bGtZero) {
852 bool overflowNegA =
false;
853 bool overflowNegB =
false;
854 bool overflowDiv =
false;
855 bool overflowNegRes =
false;
856 if (!aGtZero && !bGtZero) {
858 APInt posA = zero.ssub_ov(a, overflowNegA);
859 APInt posB = zero.ssub_ov(
b, overflowNegB);
861 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
864 if (!aGtZero && bGtZero) {
866 APInt posA = zero.ssub_ov(a, overflowNegA);
867 APInt
div = posA.sdiv_ov(
b, overflowDiv);
868 APInt res = zero.ssub_ov(
div, overflowNegRes);
869 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
873 APInt posB = zero.ssub_ov(
b, overflowNegB);
874 APInt
div = a.sdiv_ov(posB, overflowDiv);
875 APInt res = zero.ssub_ov(
div, overflowNegRes);
877 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
881 return overflowOrDiv0 ? Attribute() :
result;
892OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
898 bool overflowOrDiv =
false;
900 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
902 overflowOrDiv = true;
905 return a.sfloordiv_ov(
b, overflowOrDiv);
908 return overflowOrDiv ? Attribute() :
result;
915OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
923 [&](APInt a,
const APInt &
b) {
924 if (div0 || b.isZero()) {
931 return div0 ? Attribute() :
result;
942OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
950 [&](APInt a,
const APInt &
b) {
951 if (div0 || b.isZero()) {
958 return div0 ? Attribute() :
result;
976 for (
bool reversePrev : {
false,
true}) {
977 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
978 .getDefiningOp<arith::AndIOp>();
982 Value other = (reversePrev ? op.getLhs() : op.getRhs());
983 if (other != prev.getLhs() && other != prev.getRhs())
986 return prev.getResult();
991OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
998 intValue.isAllOnes())
1003 intValue.isAllOnes())
1008 intValue.isAllOnes())
1016 adaptor.getOperands(),
1017 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1024OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1027 if (rhsVal.isZero())
1030 if (rhsVal.isAllOnes())
1031 return adaptor.getRhs();
1038 intValue.isAllOnes())
1039 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1043 intValue.isAllOnes())
1044 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1047 adaptor.getOperands(),
1048 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
1055OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1060 if (getLhs() == getRhs())
1064 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1065 if (prev.getRhs() == getRhs())
1066 return prev.getLhs();
1067 if (prev.getLhs() == getRhs())
1068 return prev.getRhs();
1072 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1073 if (prev.getRhs() == getLhs())
1074 return prev.getLhs();
1075 if (prev.getLhs() == getLhs())
1076 return prev.getRhs();
1080 adaptor.getOperands(),
1081 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
1084void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1085 MLIRContext *context) {
1086 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1093OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1095 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1096 return op.getOperand();
1098 [](
const APFloat &a) { return -a; });
1105OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1111 adaptor.getOperands(),
1112 [](
const APFloat &a,
const APFloat &
b) { return a + b; });
1119OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1125 adaptor.getOperands(),
1126 [](
const APFloat &a,
const APFloat &
b) { return a - b; });
1133OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1135 if (getLhs() == getRhs())
1143 adaptor.getOperands(),
1144 [](
const APFloat &a,
const APFloat &
b) { return llvm::maximum(a, b); });
1151OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1153 if (getLhs() == getRhs())
1167OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1169 if (getLhs() == getRhs())
1175 if (intValue.isMaxSignedValue())
1178 if (intValue.isMinSignedValue())
1183 [](
const APInt &a,
const APInt &
b) {
1184 return llvm::APIntOps::smax(a, b);
1192OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1194 if (getLhs() == getRhs())
1200 if (intValue.isMaxValue())
1203 if (intValue.isMinValue())
1208 [](
const APInt &a,
const APInt &
b) {
1209 return llvm::APIntOps::umax(a, b);
1217OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1219 if (getLhs() == getRhs())
1227 adaptor.getOperands(),
1228 [](
const APFloat &a,
const APFloat &
b) { return llvm::minimum(a, b); });
1235OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1237 if (getLhs() == getRhs())
1245 adaptor.getOperands(),
1246 [](
const APFloat &a,
const APFloat &
b) { return llvm::minnum(a, b); });
1253OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1255 if (getLhs() == getRhs())
1261 if (intValue.isMinSignedValue())
1264 if (intValue.isMaxSignedValue())
1269 [](
const APInt &a,
const APInt &
b) {
1270 return llvm::APIntOps::smin(a, b);
1278OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1280 if (getLhs() == getRhs())
1286 if (intValue.isMinValue())
1289 if (intValue.isMaxValue())
1294 [](
const APInt &a,
const APInt &
b) {
1295 return llvm::APIntOps::umin(a, b);
1303OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1308 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1309 arith::FastMathFlags::nsz)) {
1316 adaptor.getOperands(),
1317 [](
const APFloat &a,
const APFloat &
b) { return a * b; });
1320void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1321 MLIRContext *context) {
1322 patterns.
add<MulFOfNegF>(context);
1329OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1335 adaptor.getOperands(),
1336 [](
const APFloat &a,
const APFloat &
b) { return a / b; });
1339void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1340 MLIRContext *context) {
1341 patterns.
add<DivFOfNegF>(context);
1348OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1350 [](
const APFloat &a,
const APFloat &
b) {
1355 (void)result.mod(b);
1364template <
typename... Types>
1370template <
typename... ShapedTypes,
typename... ElementTypes>
1373 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1377 if (!llvm::isa<ElementTypes...>(underlyingType))
1380 return underlyingType;
1384template <
typename... ElementTypes>
1391template <
typename... ElementTypes>
1400 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1401 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1402 if (!rankedTensorA || !rankedTensorB)
1404 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1408 if (inputs.size() != 1 || outputs.size() != 1)
1420template <
typename ValType,
typename Op>
1425 if (llvm::cast<ValType>(srcType).getWidth() >=
1426 llvm::cast<ValType>(dstType).getWidth())
1428 << dstType <<
" must be wider than operand type " << srcType;
1434template <
typename ValType,
typename Op>
1439 if (llvm::cast<ValType>(srcType).getWidth() <=
1440 llvm::cast<ValType>(dstType).getWidth())
1442 << dstType <<
" must be shorter than operand type " << srcType;
1448template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1453 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1454 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1455 if (!srcType || !dstType)
1458 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1459 srcType.getIntOrFloatBitWidth());
1465 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1466 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1469 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1470 if (sourceValue.isInfinity() &&
1471 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1472 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1474 if (sourceValue.isNaN() &&
1475 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1478 bool losesInfo =
false;
1479 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1480 if (losesInfo || status != APFloat::opOK)
1490OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1491 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1492 getInMutable().assign(
lhs.getIn());
1497 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1499 adaptor.getOperands(),
getType(),
1500 [bitWidth](
const APInt &a,
bool &castStatus) {
1501 return a.zext(bitWidth);
1509LogicalResult arith::ExtUIOp::verify() {
1517OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1518 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1519 getInMutable().assign(
lhs.getIn());
1524 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1526 adaptor.getOperands(),
getType(),
1527 [bitWidth](
const APInt &a,
bool &castStatus) {
1528 return a.sext(bitWidth);
1536void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1537 MLIRContext *context) {
1538 patterns.
add<ExtSIOfExtUI>(context);
1541LogicalResult arith::ExtSIOp::verify() {
1551OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1552 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1553 if (truncFOp.getOperand().getType() ==
getType()) {
1554 arith::FastMathFlags truncFMF =
1555 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1556 bool isTruncContract =
1557 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1558 arith::FastMathFlags extFMF =
1559 getFastmath().value_or(arith::FastMathFlags::none);
1560 bool isExtContract =
1561 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1562 if (isTruncContract && isExtContract) {
1563 return truncFOp.getOperand();
1569 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1571 adaptor.getOperands(),
getType(),
1572 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1592bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1597LogicalResult arith::ScalingExtFOp::verify() {
1605OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1608 Value src = getOperand().getDefiningOp()->getOperand(0);
1613 if (llvm::cast<IntegerType>(srcType).getWidth() >
1614 llvm::cast<IntegerType>(dstType).getWidth()) {
1621 if (srcType == dstType)
1627 setOperand(getOperand().getDefiningOp()->getOperand(0));
1632 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1634 adaptor.getOperands(),
getType(),
1635 [bitWidth](
const APInt &a,
bool &castStatus) {
1636 return a.trunc(bitWidth);
1644void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1645 MLIRContext *context) {
1647 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1651LogicalResult arith::TruncIOp::verify() {
1661OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1663 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1664 Value src = extOp.getIn();
1666 auto intermediateType =
1669 if (llvm::APFloatBase::isRepresentableBy(
1670 srcType.getFloatSemantics(),
1671 intermediateType.getFloatSemantics())) {
1673 if (srcType.getWidth() > resElemType.getWidth()) {
1679 if (srcType == resElemType)
1684 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1686 adaptor.getOperands(),
getType(),
1687 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1688 RoundingMode roundingMode =
1689 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1690 llvm::RoundingMode llvmRoundingMode =
1692 FailureOr<APFloat>
result =
1702void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1703 MLIRContext *context) {
1704 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1711LogicalResult arith::TruncFOp::verify() {
1719OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1721 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1723 adaptor.getOperands(),
getType(),
1724 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1725 RoundingMode roundingMode =
1726 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1727 llvm::RoundingMode llvmRoundingMode =
1729 FailureOr<APFloat>
result =
1744 if (!srcType || !dstType)
1746 return srcType != dstType &&
1750LogicalResult arith::ConvertFOp::verify() {
1753 if (srcType == dstType)
1754 return emitError(
"result element type ")
1755 << dstType <<
" must be different from operand element type "
1757 if (srcType.getWidth() != dstType.getWidth())
1758 return emitError(
"result element type ")
1759 << dstType <<
" must have the same bitwidth as operand element type "
1768bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1773LogicalResult arith::ScalingTruncFOp::verify() {
1781void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1782 MLIRContext *context) {
1783 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1790void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1791 MLIRContext *context) {
1792 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1799template <
typename From,
typename To>
1807 return srcType && dstType;
1818OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1821 adaptor.getOperands(),
getType(),
1822 [&resEleType](
const APInt &a,
bool &castStatus) {
1823 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1824 APFloat apf(floatTy.getFloatSemantics(),
1825 APInt::getZero(floatTy.getWidth()));
1826 apf.convertFromAPInt(a,
false,
1827 APFloat::rmNearestTiesToEven);
1832void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1833 MLIRContext *context) {
1834 patterns.
add<UIToFPOfExtUI>(context);
1845OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1848 adaptor.getOperands(),
getType(),
1849 [&resEleType](
const APInt &a,
bool &castStatus) {
1850 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1851 APFloat apf(floatTy.getFloatSemantics(),
1852 APInt::getZero(floatTy.getWidth()));
1853 apf.convertFromAPInt(a,
true,
1854 APFloat::rmNearestTiesToEven);
1859void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1860 MLIRContext *context) {
1861 patterns.
add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1872OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1874 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1876 adaptor.getOperands(),
getType(),
1877 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1879 APSInt api(bitWidth,
true);
1880 castStatus = APFloat::opInvalidOp !=
1881 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1894OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1896 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1898 adaptor.getOperands(),
getType(),
1899 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1901 APSInt api(bitWidth,
false);
1902 castStatus = APFloat::opInvalidOp !=
1903 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1918 if (!srcType || !dstType)
1922 (srcType.isSignlessInteger() && dstType.
isIndex());
1925bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1930OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1932 unsigned resultBitwidth = 64;
1934 resultBitwidth = intTy.getWidth();
1937 adaptor.getOperands(),
getType(),
1938 [resultBitwidth](
const APInt &a,
bool & ) {
1939 return a.sextOrTrunc(resultBitwidth);
1943void arith::IndexCastOp::getCanonicalizationPatterns(
1944 RewritePatternSet &patterns, MLIRContext *context) {
1945 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1952bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1957OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1959 unsigned resultBitwidth = 64;
1961 resultBitwidth = intTy.getWidth();
1964 adaptor.getOperands(),
getType(),
1965 [resultBitwidth](
const APInt &a,
bool & ) {
1966 return a.zextOrTrunc(resultBitwidth);
1970void arith::IndexCastUIOp::getCanonicalizationPatterns(
1971 RewritePatternSet &patterns, MLIRContext *context) {
1972 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1985 if (!srcType || !dstType)
1991OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1993 auto operand = adaptor.getIn();
1998 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1999 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
2001 if (llvm::isa<ShapedType>(resType))
2009 APInt bits = llvm::isa<FloatAttr>(operand)
2010 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2011 : llvm::cast<IntegerAttr>(operand).getValue();
2013 "trying to fold on broken IR: operands have incompatible types");
2015 if (
auto resFloatType = dyn_cast<FloatType>(resType))
2016 return FloatAttr::get(resType,
2017 APFloat(resFloatType.getFloatSemantics(), bits));
2018 return IntegerAttr::get(resType, bits);
2021void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2022 MLIRContext *context) {
2023 patterns.
add<BitcastOfBitcast>(context);
2033 const APInt &
lhs,
const APInt &
rhs) {
2034 switch (predicate) {
2035 case arith::CmpIPredicate::eq:
2037 case arith::CmpIPredicate::ne:
2039 case arith::CmpIPredicate::slt:
2041 case arith::CmpIPredicate::sle:
2043 case arith::CmpIPredicate::sgt:
2045 case arith::CmpIPredicate::sge:
2047 case arith::CmpIPredicate::ult:
2049 case arith::CmpIPredicate::ule:
2051 case arith::CmpIPredicate::ugt:
2053 case arith::CmpIPredicate::uge:
2056 llvm_unreachable(
"unknown cmpi predicate kind");
2061 switch (predicate) {
2062 case arith::CmpIPredicate::eq:
2063 case arith::CmpIPredicate::sle:
2064 case arith::CmpIPredicate::sge:
2065 case arith::CmpIPredicate::ule:
2066 case arith::CmpIPredicate::uge:
2068 case arith::CmpIPredicate::ne:
2069 case arith::CmpIPredicate::slt:
2070 case arith::CmpIPredicate::sgt:
2071 case arith::CmpIPredicate::ult:
2072 case arith::CmpIPredicate::ugt:
2075 llvm_unreachable(
"unknown cmpi predicate kind");
2079 if (
auto intType = dyn_cast<IntegerType>(t)) {
2080 return intType.getWidth();
2082 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2083 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2085 return std::nullopt;
2088OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2090 if (getLhs() == getRhs()) {
2096 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2098 std::optional<int64_t> integerWidth =
2100 if (integerWidth && integerWidth.value() == 1 &&
2101 getPredicate() == arith::CmpIPredicate::ne)
2102 return extOp.getOperand();
2104 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2106 std::optional<int64_t> integerWidth =
2108 if (integerWidth && integerWidth.value() == 1 &&
2109 getPredicate() == arith::CmpIPredicate::ne)
2110 return extOp.getOperand();
2115 getPredicate() == arith::CmpIPredicate::ne)
2122 getPredicate() == arith::CmpIPredicate::eq)
2127 if (adaptor.getLhs() && !adaptor.getRhs()) {
2129 using Pred = CmpIPredicate;
2130 const std::pair<Pred, Pred> invPreds[] = {
2131 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2132 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2133 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2134 {Pred::ne, Pred::ne},
2136 Pred origPred = getPredicate();
2137 for (
auto pred : invPreds) {
2138 if (origPred == pred.first) {
2139 setPredicate(pred.second);
2140 Value
lhs = getLhs();
2141 Value
rhs = getRhs();
2142 getLhsMutable().assign(
rhs);
2143 getRhsMutable().assign(
lhs);
2147 llvm_unreachable(
"unknown cmpi predicate kind");
2152 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2155 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
2164void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2165 MLIRContext *context) {
2166 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2176 const APFloat &
lhs,
const APFloat &
rhs) {
2177 auto cmpResult =
lhs.compare(
rhs);
2178 switch (predicate) {
2179 case arith::CmpFPredicate::AlwaysFalse:
2181 case arith::CmpFPredicate::OEQ:
2182 return cmpResult == APFloat::cmpEqual;
2183 case arith::CmpFPredicate::OGT:
2184 return cmpResult == APFloat::cmpGreaterThan;
2185 case arith::CmpFPredicate::OGE:
2186 return cmpResult == APFloat::cmpGreaterThan ||
2187 cmpResult == APFloat::cmpEqual;
2188 case arith::CmpFPredicate::OLT:
2189 return cmpResult == APFloat::cmpLessThan;
2190 case arith::CmpFPredicate::OLE:
2191 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2192 case arith::CmpFPredicate::ONE:
2193 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2194 case arith::CmpFPredicate::ORD:
2195 return cmpResult != APFloat::cmpUnordered;
2196 case arith::CmpFPredicate::UEQ:
2197 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2198 case arith::CmpFPredicate::UGT:
2199 return cmpResult == APFloat::cmpUnordered ||
2200 cmpResult == APFloat::cmpGreaterThan;
2201 case arith::CmpFPredicate::UGE:
2202 return cmpResult == APFloat::cmpUnordered ||
2203 cmpResult == APFloat::cmpGreaterThan ||
2204 cmpResult == APFloat::cmpEqual;
2205 case arith::CmpFPredicate::ULT:
2206 return cmpResult == APFloat::cmpUnordered ||
2207 cmpResult == APFloat::cmpLessThan;
2208 case arith::CmpFPredicate::ULE:
2209 return cmpResult == APFloat::cmpUnordered ||
2210 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2211 case arith::CmpFPredicate::UNE:
2212 return cmpResult != APFloat::cmpEqual;
2213 case arith::CmpFPredicate::UNO:
2214 return cmpResult == APFloat::cmpUnordered;
2215 case arith::CmpFPredicate::AlwaysTrue:
2218 llvm_unreachable(
"unknown cmpf predicate kind");
2222 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2223 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2226 if (
lhs &&
lhs.getValue().isNaN())
2228 if (
rhs &&
rhs.getValue().isNaN())
2244 using namespace arith;
2246 case CmpFPredicate::UEQ:
2247 case CmpFPredicate::OEQ:
2248 return CmpIPredicate::eq;
2249 case CmpFPredicate::UGT:
2250 case CmpFPredicate::OGT:
2251 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2252 case CmpFPredicate::UGE:
2253 case CmpFPredicate::OGE:
2254 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2255 case CmpFPredicate::ULT:
2256 case CmpFPredicate::OLT:
2257 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2258 case CmpFPredicate::ULE:
2259 case CmpFPredicate::OLE:
2260 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2261 case CmpFPredicate::UNE:
2262 case CmpFPredicate::ONE:
2263 return CmpIPredicate::ne;
2265 llvm_unreachable(
"Unexpected predicate!");
2275 const APFloat &
rhs = flt.getValue();
2283 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2284 int mantissaWidth = floatTy.getFPMantissaWidth();
2285 if (mantissaWidth <= 0)
2291 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2293 intVal = si.getIn();
2294 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2296 intVal = ui.getIn();
2303 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2304 auto intWidth = intTy.getWidth();
2307 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2312 if ((
int)intWidth > mantissaWidth) {
2314 int exponent = ilogb(
rhs);
2315 if (exponent == APFloat::IEK_Inf) {
2316 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2317 if (maxExponent < (
int)valueBits) {
2324 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2333 switch (op.getPredicate()) {
2334 case CmpFPredicate::ORD:
2339 case CmpFPredicate::UNO:
2352 APFloat signedMax(
rhs.getSemantics());
2353 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2354 APFloat::rmNearestTiesToEven);
2355 if (signedMax <
rhs) {
2356 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2357 pred == CmpIPredicate::sle)
2368 APFloat unsignedMax(
rhs.getSemantics());
2369 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2370 APFloat::rmNearestTiesToEven);
2371 if (unsignedMax <
rhs) {
2372 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2373 pred == CmpIPredicate::ule)
2385 APFloat signedMin(
rhs.getSemantics());
2386 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2387 APFloat::rmNearestTiesToEven);
2388 if (signedMin >
rhs) {
2389 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2390 pred == CmpIPredicate::sge)
2400 APFloat unsignedMin(
rhs.getSemantics());
2401 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2402 APFloat::rmNearestTiesToEven);
2403 if (unsignedMin >
rhs) {
2404 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2405 pred == CmpIPredicate::uge)
2420 APSInt rhsInt(intWidth, isUnsigned);
2421 if (APFloat::opInvalidOp ==
2422 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2428 if (!
rhs.isZero()) {
2429 APFloat apf(floatTy.getFloatSemantics(),
2430 APInt::getZero(floatTy.getWidth()));
2431 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2433 bool equal = apf ==
rhs;
2439 case CmpIPredicate::ne:
2443 case CmpIPredicate::eq:
2447 case CmpIPredicate::ule:
2450 if (
rhs.isNegative()) {
2456 case CmpIPredicate::sle:
2459 if (
rhs.isNegative())
2460 pred = CmpIPredicate::slt;
2462 case CmpIPredicate::ult:
2465 if (
rhs.isNegative()) {
2470 pred = CmpIPredicate::ule;
2472 case CmpIPredicate::slt:
2475 if (!
rhs.isNegative())
2476 pred = CmpIPredicate::sle;
2478 case CmpIPredicate::ugt:
2481 if (
rhs.isNegative()) {
2487 case CmpIPredicate::sgt:
2490 if (
rhs.isNegative())
2491 pred = CmpIPredicate::sge;
2493 case CmpIPredicate::uge:
2496 if (
rhs.isNegative()) {
2501 pred = CmpIPredicate::ugt;
2503 case CmpIPredicate::sge:
2506 if (!
rhs.isNegative())
2507 pred = CmpIPredicate::sgt;
2517 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2523void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2524 MLIRContext *context) {
2525 patterns.
insert<CmpFIntToFPConst>(context);
2539 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2555 arith::XOrIOp::create(
2556 rewriter, op.getLoc(), op.getCondition(),
2558 op.getCondition().
getType(), 1)));
2566void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2567 MLIRContext *context) {
2568 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2569 SelectI1ToNot, SelectToExtUI>(context);
2572OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2573 Value trueVal = getTrueValue();
2574 Value falseVal = getFalseValue();
2575 if (trueVal == falseVal)
2578 Value condition = getCondition();
2596 if (
getType().isSignlessInteger(1) &&
2602 auto pred = cmp.getPredicate();
2603 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2604 auto cmpLhs = cmp.getLhs();
2605 auto cmpRhs = cmp.getRhs();
2613 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2614 (cmpRhs == trueVal && cmpLhs == falseVal))
2615 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2622 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2624 assert(cond.getType().hasStaticShape() &&
2625 "DenseElementsAttr must have static shape");
2627 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2629 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2630 SmallVector<Attribute> results;
2631 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2632 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2633 cond.value_end<BoolAttr>());
2634 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2635 lhs.value_end<Attribute>());
2636 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2637 rhs.value_end<Attribute>());
2639 for (
auto [condVal, lhsVal, rhsVal] :
2640 llvm::zip_equal(condVals, lhsVals, rhsVals))
2641 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2651ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2652 Type conditionType, resultType;
2653 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2661 conditionType = resultType;
2668 result.addTypes(resultType);
2670 {conditionType, resultType, resultType},
2674void arith::SelectOp::print(OpAsmPrinter &p) {
2675 p <<
" " << getOperands();
2678 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2679 p << condType <<
", ";
2683LogicalResult arith::SelectOp::verify() {
2684 Type conditionType = getCondition().getType();
2691 if (!llvm::isa<TensorType, VectorType>(resultType))
2692 return emitOpError() <<
"expected condition to be a signless i1, but got "
2695 if (conditionType != shapedConditionType) {
2696 return emitOpError() <<
"expected condition type to have the same shape "
2697 "as the result type, expected "
2698 << shapedConditionType <<
", but got "
2707OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2712 bool bounded =
false;
2714 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2715 bounded = b.ult(b.getBitWidth());
2718 return bounded ?
result : Attribute();
2725OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2730 bool bounded =
false;
2732 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2733 bounded = b.ult(b.getBitWidth());
2736 return bounded ?
result : Attribute();
2743OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2748 bool bounded =
false;
2750 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2751 bounded = b.ult(b.getBitWidth());
2754 return bounded ?
result : Attribute();
2764 bool useOnlyFiniteValue) {
2766 case AtomicRMWKind::maximumf: {
2767 const llvm::fltSemantics &semantic =
2768 llvm::cast<FloatType>(resultType).getFloatSemantics();
2769 APFloat identity = useOnlyFiniteValue
2770 ? APFloat::getLargest(semantic,
true)
2771 : APFloat::getInf(semantic,
true);
2774 case AtomicRMWKind::maxnumf: {
2775 const llvm::fltSemantics &semantic =
2776 llvm::cast<FloatType>(resultType).getFloatSemantics();
2777 APFloat identity = APFloat::getNaN(semantic,
true);
2780 case AtomicRMWKind::addf:
2781 case AtomicRMWKind::addi:
2782 case AtomicRMWKind::maxu:
2783 case AtomicRMWKind::ori:
2784 case AtomicRMWKind::xori:
2786 case AtomicRMWKind::andi:
2789 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2790 case AtomicRMWKind::maxs:
2792 resultType, APInt::getSignedMinValue(
2793 llvm::cast<IntegerType>(resultType).getWidth()));
2794 case AtomicRMWKind::minimumf: {
2795 const llvm::fltSemantics &semantic =
2796 llvm::cast<FloatType>(resultType).getFloatSemantics();
2797 APFloat identity = useOnlyFiniteValue
2798 ? APFloat::getLargest(semantic,
false)
2799 : APFloat::getInf(semantic,
false);
2803 case AtomicRMWKind::minnumf: {
2804 const llvm::fltSemantics &semantic =
2805 llvm::cast<FloatType>(resultType).getFloatSemantics();
2806 APFloat identity = APFloat::getNaN(semantic,
false);
2809 case AtomicRMWKind::mins:
2811 resultType, APInt::getSignedMaxValue(
2812 llvm::cast<IntegerType>(resultType).getWidth()));
2813 case AtomicRMWKind::minu:
2816 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2817 case AtomicRMWKind::muli:
2819 case AtomicRMWKind::mulf:
2831 std::optional<AtomicRMWKind> maybeKind =
2834 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2835 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2836 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2837 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2838 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2839 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2841 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2842 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2843 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2844 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2845 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2846 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2847 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2848 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2849 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2850 .Default(std::nullopt);
2852 return std::nullopt;
2855 bool useOnlyFiniteValue =
false;
2856 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2857 if (fmfOpInterface) {
2858 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2859 useOnlyFiniteValue =
2860 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2868 useOnlyFiniteValue);
2874 bool useOnlyFiniteValue) {
2877 return arith::ConstantOp::create(builder, loc, attr);
2886 case AtomicRMWKind::addf:
2887 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
2888 case AtomicRMWKind::addi:
2889 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
2890 case AtomicRMWKind::mulf:
2891 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
2892 case AtomicRMWKind::muli:
2893 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
2894 case AtomicRMWKind::maximumf:
2895 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
2896 case AtomicRMWKind::minimumf:
2897 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
2898 case AtomicRMWKind::maxnumf:
2899 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
2900 case AtomicRMWKind::minnumf:
2901 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
2902 case AtomicRMWKind::maxs:
2903 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
2904 case AtomicRMWKind::mins:
2905 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
2906 case AtomicRMWKind::maxu:
2907 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
2908 case AtomicRMWKind::minu:
2909 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
2910 case AtomicRMWKind::ori:
2911 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
2912 case AtomicRMWKind::andi:
2913 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
2914 case AtomicRMWKind::xori:
2915 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
2928#define GET_OP_CLASSES
2929#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2935#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.