26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.
getType(), value);
67static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
98 llvm_unreachable(
"unknown cmpi predicate kind");
107static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
121 llvm_unreachable(
"Unhandled rounding mode");
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
155 if (!shapedType.hasStaticShape())
165#include "ArithCanonicalization.inc"
174 auto i1Type = IntegerType::get(type.
getContext(), 1);
175 if (
auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
186void arith::ConstantOp::getAsmResultNames(
189 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName <<
'c' << intCst.getValue();
201 specialName <<
'_' << type;
202 setNameFn(getResult(), specialName.str());
204 setNameFn(getResult(),
"cst");
210LogicalResult arith::ConstantOp::verify() {
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError(
"integer return type must be signless");
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
219 "value must be an integer, float, or elements attribute");
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
239 if (!intType.isSignless())
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
258 arith::ConstantOp::build(builder,
result, type,
268 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
269 assert(
result &&
"builder didn't return the right type");
281 arith::ConstantOp::build(builder,
result, type,
290 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
291 assert(
result &&
"builder didn't return the right type");
302 arith::ConstantOp::build(builder,
result, type,
308 const APInt &
value) {
311 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
312 assert(
result &&
"builder didn't return the right type");
318 const APInt &
value) {
323 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
329 FloatType type,
const APFloat &
value) {
330 arith::ConstantOp::build(builder,
result, type,
337 const APFloat &
value) {
340 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
341 assert(
result &&
"builder didn't return the right type");
347 const APFloat &
value) {
352 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
368 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
369 assert(
result &&
"builder didn't return the right type");
379 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
388 "type doesn't have a zero representation");
390 assert(zeroAttr &&
"unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
404 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
409 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
414 adaptor.getOperands(),
415 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
420 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
450 results.push_back(getLhs());
451 results.push_back(falseValue);
460 adaptor.getOperands(),
461 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
464 results.push_back(sumAttr);
465 results.push_back(sumAttr);
469 ArrayRef({sumAttr, adaptor.getLhs()}),
475 results.push_back(sumAttr);
476 results.push_back(overflowAttr);
483void arith::AddUIExtendedOp::getCanonicalizationPatterns(
484 RewritePatternSet &patterns, MLIRContext *context) {
485 patterns.
add<AddUIExtendedToAddI>(context);
492OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
494 if (getOperand(0) == getOperand(1)) {
495 auto shapedType = dyn_cast<ShapedType>(
getType());
497 if (!shapedType || shapedType.hasStaticShape())
504 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
506 if (getRhs() ==
add.getRhs())
509 if (getRhs() ==
add.getLhs())
514 adaptor.getOperands(),
515 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
518void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
519 MLIRContext *context) {
520 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
521 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
522 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
529OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
540 adaptor.getOperands(),
541 [](
const APInt &a,
const APInt &
b) { return a * b; });
544void arith::MulIOp::getAsmResultNames(
546 if (!isa<IndexType>(
getType()))
551 auto isVscale = [](Operation *op) {
552 return op && op->getName().getStringRef() ==
"vector.vscale";
555 IntegerAttr baseValue;
556 auto isVscaleExpr = [&](Value a, Value
b) {
558 isVscale(
b.getDefiningOp());
561 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
565 SmallString<32> specialNameBuffer;
566 llvm::raw_svector_ostream specialName(specialNameBuffer);
567 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
568 setNameFn(getResult(), specialName.str());
571void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
572 MLIRContext *context) {
573 patterns.
add<MulIMulIConstant>(context);
580std::optional<SmallVector<int64_t, 4>>
581arith::MulSIExtendedOp::getShapeForUnroll() {
582 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
583 return llvm::to_vector<4>(vt.getShape());
588arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
589 SmallVectorImpl<OpFoldResult> &results) {
592 Attribute zero = adaptor.getRhs();
593 results.push_back(zero);
594 results.push_back(zero);
600 adaptor.getOperands(),
601 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
604 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
605 return llvm::APIntOps::mulhs(a, b);
607 assert(highAttr &&
"Unexpected constant-folding failure");
609 results.push_back(lowAttr);
610 results.push_back(highAttr);
617void arith::MulSIExtendedOp::getCanonicalizationPatterns(
618 RewritePatternSet &patterns, MLIRContext *context) {
619 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
626std::optional<SmallVector<int64_t, 4>>
627arith::MulUIExtendedOp::getShapeForUnroll() {
628 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
629 return llvm::to_vector<4>(vt.getShape());
634arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
635 SmallVectorImpl<OpFoldResult> &results) {
638 Attribute zero = adaptor.getRhs();
639 results.push_back(zero);
640 results.push_back(zero);
648 results.push_back(getLhs());
649 results.push_back(zero);
655 adaptor.getOperands(),
656 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
659 adaptor.getOperands(), [](
const APInt &a,
const APInt &
b) {
660 return llvm::APIntOps::mulhu(a, b);
662 assert(highAttr &&
"Unexpected constant-folding failure");
664 results.push_back(lowAttr);
665 results.push_back(highAttr);
672void arith::MulUIExtendedOp::getCanonicalizationPatterns(
673 RewritePatternSet &patterns, MLIRContext *context) {
674 patterns.
add<MulUIExtendedToMulI>(context);
683 arith::IntegerOverflowFlags ovfFlags) {
684 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
685 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
697OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
703 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
709 [&](APInt a,
const APInt &
b) {
717 return div0 ? Attribute() :
result;
737OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
743 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
747 bool overflowOrDiv0 =
false;
749 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
750 if (overflowOrDiv0 || !b) {
751 overflowOrDiv0 = true;
754 return a.sdiv_ov(
b, overflowOrDiv0);
757 return overflowOrDiv0 ? Attribute() :
result;
784 APInt one(a.getBitWidth(), 1,
true);
785 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
786 return val.sadd_ov(one, overflow);
793OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
798 bool overflowOrDiv0 =
false;
800 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
801 if (overflowOrDiv0 || !b) {
802 overflowOrDiv0 = true;
805 APInt quotient = a.udiv(
b);
808 APInt one(a.getBitWidth(), 1,
true);
809 return quotient.uadd_ov(one, overflowOrDiv0);
812 return overflowOrDiv0 ? Attribute() :
result;
823OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
831 bool overflowOrDiv0 =
false;
833 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
834 if (overflowOrDiv0 || !b) {
835 overflowOrDiv0 = true;
841 unsigned bits = a.getBitWidth();
842 APInt zero = APInt::getZero(bits);
843 bool aGtZero = a.sgt(zero);
844 bool bGtZero =
b.sgt(zero);
845 if (aGtZero && bGtZero) {
852 bool overflowNegA =
false;
853 bool overflowNegB =
false;
854 bool overflowDiv =
false;
855 bool overflowNegRes =
false;
856 if (!aGtZero && !bGtZero) {
858 APInt posA = zero.ssub_ov(a, overflowNegA);
859 APInt posB = zero.ssub_ov(
b, overflowNegB);
861 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
864 if (!aGtZero && bGtZero) {
866 APInt posA = zero.ssub_ov(a, overflowNegA);
867 APInt
div = posA.sdiv_ov(
b, overflowDiv);
868 APInt res = zero.ssub_ov(
div, overflowNegRes);
869 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
873 APInt posB = zero.ssub_ov(
b, overflowNegB);
874 APInt
div = a.sdiv_ov(posB, overflowDiv);
875 APInt res = zero.ssub_ov(
div, overflowNegRes);
877 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
881 return overflowOrDiv0 ? Attribute() :
result;
892OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
898 bool overflowOrDiv =
false;
900 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
902 overflowOrDiv = true;
905 return a.sfloordiv_ov(
b, overflowOrDiv);
908 return overflowOrDiv ? Attribute() :
result;
915OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
923 [&](APInt a,
const APInt &
b) {
924 if (div0 || b.isZero()) {
931 return div0 ? Attribute() :
result;
938OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
946 [&](APInt a,
const APInt &
b) {
947 if (div0 || b.isZero()) {
954 return div0 ? Attribute() :
result;
963 for (
bool reversePrev : {
false,
true}) {
964 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
965 .getDefiningOp<arith::AndIOp>();
969 Value other = (reversePrev ? op.getLhs() : op.getRhs());
970 if (other != prev.getLhs() && other != prev.getRhs())
973 return prev.getResult();
978OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
985 intValue.isAllOnes())
990 intValue.isAllOnes())
995 intValue.isAllOnes())
1003 adaptor.getOperands(),
1004 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1011OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1014 if (rhsVal.isZero())
1017 if (rhsVal.isAllOnes())
1018 return adaptor.getRhs();
1025 intValue.isAllOnes())
1026 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1030 intValue.isAllOnes())
1031 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1034 adaptor.getOperands(),
1035 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
1042OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1047 if (getLhs() == getRhs())
1051 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1052 if (prev.getRhs() == getRhs())
1053 return prev.getLhs();
1054 if (prev.getLhs() == getRhs())
1055 return prev.getRhs();
1059 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1060 if (prev.getRhs() == getLhs())
1061 return prev.getLhs();
1062 if (prev.getLhs() == getLhs())
1063 return prev.getRhs();
1067 adaptor.getOperands(),
1068 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
1071void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1072 MLIRContext *context) {
1073 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1080OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1082 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1083 return op.getOperand();
1085 [](
const APFloat &a) { return -a; });
1092OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1098 adaptor.getOperands(),
1099 [](
const APFloat &a,
const APFloat &
b) { return a + b; });
1106OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1112 adaptor.getOperands(),
1113 [](
const APFloat &a,
const APFloat &
b) { return a - b; });
1120OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1122 if (getLhs() == getRhs())
1130 adaptor.getOperands(),
1131 [](
const APFloat &a,
const APFloat &
b) { return llvm::maximum(a, b); });
1138OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1140 if (getLhs() == getRhs())
1154OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1156 if (getLhs() == getRhs())
1162 if (intValue.isMaxSignedValue())
1165 if (intValue.isMinSignedValue())
1170 [](
const APInt &a,
const APInt &
b) {
1171 return llvm::APIntOps::smax(a, b);
1179OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1181 if (getLhs() == getRhs())
1187 if (intValue.isMaxValue())
1190 if (intValue.isMinValue())
1195 [](
const APInt &a,
const APInt &
b) {
1196 return llvm::APIntOps::umax(a, b);
1204OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1206 if (getLhs() == getRhs())
1214 adaptor.getOperands(),
1215 [](
const APFloat &a,
const APFloat &
b) { return llvm::minimum(a, b); });
1222OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1224 if (getLhs() == getRhs())
1232 adaptor.getOperands(),
1233 [](
const APFloat &a,
const APFloat &
b) { return llvm::minnum(a, b); });
1240OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1242 if (getLhs() == getRhs())
1248 if (intValue.isMinSignedValue())
1251 if (intValue.isMaxSignedValue())
1256 [](
const APInt &a,
const APInt &
b) {
1257 return llvm::APIntOps::smin(a, b);
1265OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1267 if (getLhs() == getRhs())
1273 if (intValue.isMinValue())
1276 if (intValue.isMaxValue())
1281 [](
const APInt &a,
const APInt &
b) {
1282 return llvm::APIntOps::umin(a, b);
1290OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1295 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1296 arith::FastMathFlags::nsz)) {
1303 adaptor.getOperands(),
1304 [](
const APFloat &a,
const APFloat &
b) { return a * b; });
1307void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1308 MLIRContext *context) {
1309 patterns.
add<MulFOfNegF>(context);
1316OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1322 adaptor.getOperands(),
1323 [](
const APFloat &a,
const APFloat &
b) { return a / b; });
1326void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1327 MLIRContext *context) {
1328 patterns.
add<DivFOfNegF>(context);
1335OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1337 [](
const APFloat &a,
const APFloat &
b) {
1342 (void)result.mod(b);
1351template <
typename... Types>
1357template <
typename... ShapedTypes,
typename... ElementTypes>
1360 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1364 if (!llvm::isa<ElementTypes...>(underlyingType))
1367 return underlyingType;
1371template <
typename... ElementTypes>
1378template <
typename... ElementTypes>
1387 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1388 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1389 if (!rankedTensorA || !rankedTensorB)
1391 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1395 if (inputs.size() != 1 || outputs.size() != 1)
1407template <
typename ValType,
typename Op>
1412 if (llvm::cast<ValType>(srcType).getWidth() >=
1413 llvm::cast<ValType>(dstType).getWidth())
1415 << dstType <<
" must be wider than operand type " << srcType;
1421template <
typename ValType,
typename Op>
1426 if (llvm::cast<ValType>(srcType).getWidth() <=
1427 llvm::cast<ValType>(dstType).getWidth())
1429 << dstType <<
" must be shorter than operand type " << srcType;
1435template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1440 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1441 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1442 if (!srcType || !dstType)
1445 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1446 srcType.getIntOrFloatBitWidth());
1452 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1453 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1454 bool losesInfo =
false;
1455 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1456 if (losesInfo || status != APFloat::opOK)
1466OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1467 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1468 getInMutable().assign(
lhs.getIn());
1473 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1475 adaptor.getOperands(),
getType(),
1476 [bitWidth](
const APInt &a,
bool &castStatus) {
1477 return a.zext(bitWidth);
1485LogicalResult arith::ExtUIOp::verify() {
1493OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1494 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1495 getInMutable().assign(
lhs.getIn());
1500 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1502 adaptor.getOperands(),
getType(),
1503 [bitWidth](
const APInt &a,
bool &castStatus) {
1504 return a.sext(bitWidth);
1512void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1513 MLIRContext *context) {
1514 patterns.
add<ExtSIOfExtUI>(context);
1517LogicalResult arith::ExtSIOp::verify() {
1527OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1528 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1529 if (truncFOp.getOperand().getType() ==
getType()) {
1530 arith::FastMathFlags truncFMF =
1531 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1532 bool isTruncContract =
1533 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1534 arith::FastMathFlags extFMF =
1535 getFastmath().value_or(arith::FastMathFlags::none);
1536 bool isExtContract =
1537 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1538 if (isTruncContract && isExtContract) {
1539 return truncFOp.getOperand();
1545 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1547 adaptor.getOperands(),
getType(),
1548 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1568bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1573LogicalResult arith::ScalingExtFOp::verify() {
1581OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1584 Value src = getOperand().getDefiningOp()->getOperand(0);
1589 if (llvm::cast<IntegerType>(srcType).getWidth() >
1590 llvm::cast<IntegerType>(dstType).getWidth()) {
1597 if (srcType == dstType)
1603 setOperand(getOperand().getDefiningOp()->getOperand(0));
1608 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1610 adaptor.getOperands(),
getType(),
1611 [bitWidth](
const APInt &a,
bool &castStatus) {
1612 return a.trunc(bitWidth);
1620void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1621 MLIRContext *context) {
1623 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1627LogicalResult arith::TruncIOp::verify() {
1637OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1639 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1640 Value src = extOp.getIn();
1642 auto intermediateType =
1645 if (llvm::APFloatBase::isRepresentableBy(
1646 srcType.getFloatSemantics(),
1647 intermediateType.getFloatSemantics())) {
1649 if (srcType.getWidth() > resElemType.getWidth()) {
1655 if (srcType == resElemType)
1660 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1662 adaptor.getOperands(),
getType(),
1663 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1664 RoundingMode roundingMode =
1665 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1666 llvm::RoundingMode llvmRoundingMode =
1668 FailureOr<APFloat>
result =
1678void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1679 MLIRContext *context) {
1680 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1687LogicalResult arith::TruncFOp::verify() {
1695bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1700LogicalResult arith::ScalingTruncFOp::verify() {
1708void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1709 MLIRContext *context) {
1710 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1717void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1718 MLIRContext *context) {
1719 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1726template <
typename From,
typename To>
1734 return srcType && dstType;
1745OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1748 adaptor.getOperands(),
getType(),
1749 [&resEleType](
const APInt &a,
bool &castStatus) {
1750 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1751 APFloat apf(floatTy.getFloatSemantics(),
1752 APInt::getZero(floatTy.getWidth()));
1753 apf.convertFromAPInt(a,
false,
1754 APFloat::rmNearestTiesToEven);
1767OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1770 adaptor.getOperands(),
getType(),
1771 [&resEleType](
const APInt &a,
bool &castStatus) {
1772 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1773 APFloat apf(floatTy.getFloatSemantics(),
1774 APInt::getZero(floatTy.getWidth()));
1775 apf.convertFromAPInt(a,
true,
1776 APFloat::rmNearestTiesToEven);
1789OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1791 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1793 adaptor.getOperands(),
getType(),
1794 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1796 APSInt api(bitWidth,
true);
1797 castStatus = APFloat::opInvalidOp !=
1798 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1811OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1813 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1815 adaptor.getOperands(),
getType(),
1816 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1818 APSInt api(bitWidth,
false);
1819 castStatus = APFloat::opInvalidOp !=
1820 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1835 if (!srcType || !dstType)
1842bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1847OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1849 unsigned resultBitwidth = 64;
1851 resultBitwidth = intTy.getWidth();
1854 adaptor.getOperands(),
getType(),
1855 [resultBitwidth](
const APInt &a,
bool & ) {
1856 return a.sextOrTrunc(resultBitwidth);
1860void arith::IndexCastOp::getCanonicalizationPatterns(
1861 RewritePatternSet &patterns, MLIRContext *context) {
1862 patterns.
add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1869bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1874OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1876 unsigned resultBitwidth = 64;
1878 resultBitwidth = intTy.getWidth();
1881 adaptor.getOperands(),
getType(),
1882 [resultBitwidth](
const APInt &a,
bool & ) {
1883 return a.zextOrTrunc(resultBitwidth);
1887void arith::IndexCastUIOp::getCanonicalizationPatterns(
1888 RewritePatternSet &patterns, MLIRContext *context) {
1889 patterns.
add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1902 if (!srcType || !dstType)
1908OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1910 auto operand = adaptor.getIn();
1915 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1916 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1918 if (llvm::isa<ShapedType>(resType))
1926 APInt bits = llvm::isa<FloatAttr>(operand)
1927 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1928 : llvm::cast<IntegerAttr>(operand).getValue();
1930 "trying to fold on broken IR: operands have incompatible types");
1932 if (
auto resFloatType = dyn_cast<FloatType>(resType))
1933 return FloatAttr::get(resType,
1934 APFloat(resFloatType.getFloatSemantics(), bits));
1935 return IntegerAttr::get(resType, bits);
1938void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1939 MLIRContext *context) {
1940 patterns.
add<BitcastOfBitcast>(context);
1950 const APInt &
lhs,
const APInt &
rhs) {
1951 switch (predicate) {
1952 case arith::CmpIPredicate::eq:
1954 case arith::CmpIPredicate::ne:
1956 case arith::CmpIPredicate::slt:
1958 case arith::CmpIPredicate::sle:
1960 case arith::CmpIPredicate::sgt:
1962 case arith::CmpIPredicate::sge:
1964 case arith::CmpIPredicate::ult:
1966 case arith::CmpIPredicate::ule:
1968 case arith::CmpIPredicate::ugt:
1970 case arith::CmpIPredicate::uge:
1973 llvm_unreachable(
"unknown cmpi predicate kind");
1978 switch (predicate) {
1979 case arith::CmpIPredicate::eq:
1980 case arith::CmpIPredicate::sle:
1981 case arith::CmpIPredicate::sge:
1982 case arith::CmpIPredicate::ule:
1983 case arith::CmpIPredicate::uge:
1985 case arith::CmpIPredicate::ne:
1986 case arith::CmpIPredicate::slt:
1987 case arith::CmpIPredicate::sgt:
1988 case arith::CmpIPredicate::ult:
1989 case arith::CmpIPredicate::ugt:
1992 llvm_unreachable(
"unknown cmpi predicate kind");
1996 if (
auto intType = dyn_cast<IntegerType>(t)) {
1997 return intType.getWidth();
1999 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2000 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2002 return std::nullopt;
2005OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2007 if (getLhs() == getRhs()) {
2013 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2015 std::optional<int64_t> integerWidth =
2017 if (integerWidth && integerWidth.value() == 1 &&
2018 getPredicate() == arith::CmpIPredicate::ne)
2019 return extOp.getOperand();
2021 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2023 std::optional<int64_t> integerWidth =
2025 if (integerWidth && integerWidth.value() == 1 &&
2026 getPredicate() == arith::CmpIPredicate::ne)
2027 return extOp.getOperand();
2032 getPredicate() == arith::CmpIPredicate::ne)
2039 getPredicate() == arith::CmpIPredicate::eq)
2044 if (adaptor.getLhs() && !adaptor.getRhs()) {
2046 using Pred = CmpIPredicate;
2047 const std::pair<Pred, Pred> invPreds[] = {
2048 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2049 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2050 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2051 {Pred::ne, Pred::ne},
2053 Pred origPred = getPredicate();
2054 for (
auto pred : invPreds) {
2055 if (origPred == pred.first) {
2056 setPredicate(pred.second);
2057 Value
lhs = getLhs();
2058 Value
rhs = getRhs();
2059 getLhsMutable().assign(
rhs);
2060 getRhsMutable().assign(
lhs);
2064 llvm_unreachable(
"unknown cmpi predicate kind");
2069 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2072 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
2081void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2082 MLIRContext *context) {
2083 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2093 const APFloat &
lhs,
const APFloat &
rhs) {
2094 auto cmpResult =
lhs.compare(
rhs);
2095 switch (predicate) {
2096 case arith::CmpFPredicate::AlwaysFalse:
2098 case arith::CmpFPredicate::OEQ:
2099 return cmpResult == APFloat::cmpEqual;
2100 case arith::CmpFPredicate::OGT:
2101 return cmpResult == APFloat::cmpGreaterThan;
2102 case arith::CmpFPredicate::OGE:
2103 return cmpResult == APFloat::cmpGreaterThan ||
2104 cmpResult == APFloat::cmpEqual;
2105 case arith::CmpFPredicate::OLT:
2106 return cmpResult == APFloat::cmpLessThan;
2107 case arith::CmpFPredicate::OLE:
2108 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2109 case arith::CmpFPredicate::ONE:
2110 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2111 case arith::CmpFPredicate::ORD:
2112 return cmpResult != APFloat::cmpUnordered;
2113 case arith::CmpFPredicate::UEQ:
2114 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2115 case arith::CmpFPredicate::UGT:
2116 return cmpResult == APFloat::cmpUnordered ||
2117 cmpResult == APFloat::cmpGreaterThan;
2118 case arith::CmpFPredicate::UGE:
2119 return cmpResult == APFloat::cmpUnordered ||
2120 cmpResult == APFloat::cmpGreaterThan ||
2121 cmpResult == APFloat::cmpEqual;
2122 case arith::CmpFPredicate::ULT:
2123 return cmpResult == APFloat::cmpUnordered ||
2124 cmpResult == APFloat::cmpLessThan;
2125 case arith::CmpFPredicate::ULE:
2126 return cmpResult == APFloat::cmpUnordered ||
2127 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2128 case arith::CmpFPredicate::UNE:
2129 return cmpResult != APFloat::cmpEqual;
2130 case arith::CmpFPredicate::UNO:
2131 return cmpResult == APFloat::cmpUnordered;
2132 case arith::CmpFPredicate::AlwaysTrue:
2135 llvm_unreachable(
"unknown cmpf predicate kind");
2139 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2140 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2143 if (
lhs &&
lhs.getValue().isNaN())
2145 if (
rhs &&
rhs.getValue().isNaN())
2161 using namespace arith;
2163 case CmpFPredicate::UEQ:
2164 case CmpFPredicate::OEQ:
2165 return CmpIPredicate::eq;
2166 case CmpFPredicate::UGT:
2167 case CmpFPredicate::OGT:
2168 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2169 case CmpFPredicate::UGE:
2170 case CmpFPredicate::OGE:
2171 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2172 case CmpFPredicate::ULT:
2173 case CmpFPredicate::OLT:
2174 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2175 case CmpFPredicate::ULE:
2176 case CmpFPredicate::OLE:
2177 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2178 case CmpFPredicate::UNE:
2179 case CmpFPredicate::ONE:
2180 return CmpIPredicate::ne;
2182 llvm_unreachable(
"Unexpected predicate!");
2192 const APFloat &
rhs = flt.getValue();
2200 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2201 int mantissaWidth = floatTy.getFPMantissaWidth();
2202 if (mantissaWidth <= 0)
2208 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2210 intVal = si.getIn();
2211 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2213 intVal = ui.getIn();
2220 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2221 auto intWidth = intTy.getWidth();
2224 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2229 if ((
int)intWidth > mantissaWidth) {
2231 int exponent = ilogb(
rhs);
2232 if (exponent == APFloat::IEK_Inf) {
2233 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2234 if (maxExponent < (
int)valueBits) {
2241 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2250 switch (op.getPredicate()) {
2251 case CmpFPredicate::ORD:
2256 case CmpFPredicate::UNO:
2269 APFloat signedMax(
rhs.getSemantics());
2270 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2271 APFloat::rmNearestTiesToEven);
2272 if (signedMax <
rhs) {
2273 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2274 pred == CmpIPredicate::sle)
2285 APFloat unsignedMax(
rhs.getSemantics());
2286 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2287 APFloat::rmNearestTiesToEven);
2288 if (unsignedMax <
rhs) {
2289 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2290 pred == CmpIPredicate::ule)
2302 APFloat signedMin(
rhs.getSemantics());
2303 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2304 APFloat::rmNearestTiesToEven);
2305 if (signedMin >
rhs) {
2306 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2307 pred == CmpIPredicate::sge)
2317 APFloat unsignedMin(
rhs.getSemantics());
2318 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2319 APFloat::rmNearestTiesToEven);
2320 if (unsignedMin >
rhs) {
2321 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2322 pred == CmpIPredicate::uge)
2337 APSInt rhsInt(intWidth, isUnsigned);
2338 if (APFloat::opInvalidOp ==
2339 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2345 if (!
rhs.isZero()) {
2346 APFloat apf(floatTy.getFloatSemantics(),
2347 APInt::getZero(floatTy.getWidth()));
2348 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2350 bool equal = apf ==
rhs;
2356 case CmpIPredicate::ne:
2360 case CmpIPredicate::eq:
2364 case CmpIPredicate::ule:
2367 if (
rhs.isNegative()) {
2373 case CmpIPredicate::sle:
2376 if (
rhs.isNegative())
2377 pred = CmpIPredicate::slt;
2379 case CmpIPredicate::ult:
2382 if (
rhs.isNegative()) {
2387 pred = CmpIPredicate::ule;
2389 case CmpIPredicate::slt:
2392 if (!
rhs.isNegative())
2393 pred = CmpIPredicate::sle;
2395 case CmpIPredicate::ugt:
2398 if (
rhs.isNegative()) {
2404 case CmpIPredicate::sgt:
2407 if (
rhs.isNegative())
2408 pred = CmpIPredicate::sge;
2410 case CmpIPredicate::uge:
2413 if (
rhs.isNegative()) {
2418 pred = CmpIPredicate::ugt;
2420 case CmpIPredicate::sge:
2423 if (!
rhs.isNegative())
2424 pred = CmpIPredicate::sgt;
2434 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2440void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2441 MLIRContext *context) {
2442 patterns.
insert<CmpFIntToFPConst>(context);
2456 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2472 arith::XOrIOp::create(
2473 rewriter, op.getLoc(), op.getCondition(),
2475 op.getCondition().
getType(), 1)));
2483void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2484 MLIRContext *context) {
2485 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2486 SelectI1ToNot, SelectToExtUI>(context);
2489OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2490 Value trueVal = getTrueValue();
2491 Value falseVal = getFalseValue();
2492 if (trueVal == falseVal)
2495 Value condition = getCondition();
2513 if (
getType().isSignlessInteger(1) &&
2519 auto pred = cmp.getPredicate();
2520 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2521 auto cmpLhs = cmp.getLhs();
2522 auto cmpRhs = cmp.getRhs();
2530 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2531 (cmpRhs == trueVal && cmpLhs == falseVal))
2532 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2539 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2541 assert(cond.getType().hasStaticShape() &&
2542 "DenseElementsAttr must have static shape");
2544 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2546 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2547 SmallVector<Attribute> results;
2548 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2549 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2550 cond.value_end<BoolAttr>());
2551 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2552 lhs.value_end<Attribute>());
2553 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2554 rhs.value_end<Attribute>());
2556 for (
auto [condVal, lhsVal, rhsVal] :
2557 llvm::zip_equal(condVals, lhsVals, rhsVals))
2558 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2568ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2569 Type conditionType, resultType;
2570 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2578 conditionType = resultType;
2585 result.addTypes(resultType);
2587 {conditionType, resultType, resultType},
2591void arith::SelectOp::print(OpAsmPrinter &p) {
2592 p <<
" " << getOperands();
2595 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2596 p << condType <<
", ";
2600LogicalResult arith::SelectOp::verify() {
2601 Type conditionType = getCondition().getType();
2608 if (!llvm::isa<TensorType, VectorType>(resultType))
2609 return emitOpError() <<
"expected condition to be a signless i1, but got "
2612 if (conditionType != shapedConditionType) {
2613 return emitOpError() <<
"expected condition type to have the same shape "
2614 "as the result type, expected "
2615 << shapedConditionType <<
", but got "
2624OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2629 bool bounded =
false;
2631 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2632 bounded = b.ult(b.getBitWidth());
2635 return bounded ?
result : Attribute();
2642OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2647 bool bounded =
false;
2649 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2650 bounded = b.ult(b.getBitWidth());
2653 return bounded ?
result : Attribute();
2660OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2665 bool bounded =
false;
2667 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
2668 bounded = b.ult(b.getBitWidth());
2671 return bounded ?
result : Attribute();
2681 bool useOnlyFiniteValue) {
2683 case AtomicRMWKind::maximumf: {
2684 const llvm::fltSemantics &semantic =
2685 llvm::cast<FloatType>(resultType).getFloatSemantics();
2686 APFloat identity = useOnlyFiniteValue
2687 ? APFloat::getLargest(semantic,
true)
2688 : APFloat::getInf(semantic,
true);
2691 case AtomicRMWKind::maxnumf: {
2692 const llvm::fltSemantics &semantic =
2693 llvm::cast<FloatType>(resultType).getFloatSemantics();
2694 APFloat identity = APFloat::getNaN(semantic,
true);
2697 case AtomicRMWKind::addf:
2698 case AtomicRMWKind::addi:
2699 case AtomicRMWKind::maxu:
2700 case AtomicRMWKind::ori:
2701 case AtomicRMWKind::xori:
2703 case AtomicRMWKind::andi:
2706 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2707 case AtomicRMWKind::maxs:
2709 resultType, APInt::getSignedMinValue(
2710 llvm::cast<IntegerType>(resultType).getWidth()));
2711 case AtomicRMWKind::minimumf: {
2712 const llvm::fltSemantics &semantic =
2713 llvm::cast<FloatType>(resultType).getFloatSemantics();
2714 APFloat identity = useOnlyFiniteValue
2715 ? APFloat::getLargest(semantic,
false)
2716 : APFloat::getInf(semantic,
false);
2720 case AtomicRMWKind::minnumf: {
2721 const llvm::fltSemantics &semantic =
2722 llvm::cast<FloatType>(resultType).getFloatSemantics();
2723 APFloat identity = APFloat::getNaN(semantic,
false);
2726 case AtomicRMWKind::mins:
2728 resultType, APInt::getSignedMaxValue(
2729 llvm::cast<IntegerType>(resultType).getWidth()));
2730 case AtomicRMWKind::minu:
2733 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2734 case AtomicRMWKind::muli:
2736 case AtomicRMWKind::mulf:
2748 std::optional<AtomicRMWKind> maybeKind =
2751 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2752 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2753 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2754 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2755 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2756 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2758 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2759 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2760 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2761 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2762 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2763 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2764 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2765 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2766 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2767 .Default(std::nullopt);
2769 return std::nullopt;
2772 bool useOnlyFiniteValue =
false;
2773 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2774 if (fmfOpInterface) {
2775 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2776 useOnlyFiniteValue =
2777 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2785 useOnlyFiniteValue);
2791 bool useOnlyFiniteValue) {
2794 return arith::ConstantOp::create(builder, loc, attr);
2802 case AtomicRMWKind::addf:
2803 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
2804 case AtomicRMWKind::addi:
2805 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
2806 case AtomicRMWKind::mulf:
2807 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
2808 case AtomicRMWKind::muli:
2809 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
2810 case AtomicRMWKind::maximumf:
2811 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
2812 case AtomicRMWKind::minimumf:
2813 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
2814 case AtomicRMWKind::maxnumf:
2815 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
2816 case AtomicRMWKind::minnumf:
2817 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
2818 case AtomicRMWKind::maxs:
2819 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
2820 case AtomicRMWKind::mins:
2821 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
2822 case AtomicRMWKind::maxu:
2823 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
2824 case AtomicRMWKind::minu:
2825 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
2826 case AtomicRMWKind::ori:
2827 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
2828 case AtomicRMWKind::andi:
2829 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
2830 case AtomicRMWKind::xori:
2831 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
2844#define GET_OP_CLASSES
2845#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2851#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed shaped types containing one of the allowed element types.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode, without any information loss.
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does not contain minus one.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does not contain zero.
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does not contain zero.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.