26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
39 llvm::RoundingMode::NearestTiesToEven;
48 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.
getType(), value);
// Combines two overflow-flag attributes by bitwise-ANDing their values, so the
// result carries only the flags that are set on both inputs. The resulting
// attribute is created in `val1`'s context.
// NOTE(review): the line holding the function name and the `val1` parameter is
// not present in this listing.
71static IntegerOverflowFlagsAttr
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
// Predicate inversion table: every integer comparison predicate is mapped to
// its logical complement (eq <-> ne, slt <-> sge, sle <-> sgt, ult <-> uge,
// ule <-> ugt).
// NOTE(review): the enclosing function header and `switch` line are not
// present in this listing.
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
// Fallback for a predicate value that none of the cases above handled.
102 llvm_unreachable(
"unknown cmpi predicate kind");
// Translates the arith dialect rounding mode held in `roundingMode` into the
// corresponding llvm::RoundingMode enumerator:
//   downward -> TowardNegative, to_nearest_away -> NearestTiesToAway,
//   to_nearest_even -> NearestTiesToEven, toward_zero -> TowardZero,
//   upward -> TowardPositive.
// NOTE(review): the function name, its parameter list and anything executed
// before the switch (e.g. handling of an unset `roundingMode`, which is
// dereferenced below) are not present in this listing -- confirm there.
111static llvm::RoundingMode
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
// Fallback for a rounding mode that none of the cases above handled.
127 llvm_unreachable(
"Unhandled rounding mode");
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
161 if (!shapedType.hasStaticShape())
171#include "ArithCanonicalization.inc"
// Derives an i1-based type from `type`.
// NOTE(review): the enclosing function header and its final fallthrough
// return are not present in this listing.
// 1-bit integer type created in the same context as `type`.
180 auto i1Type = IntegerType::get(type.
getContext(), 1);
// Shaped types: clone with i1 as the element type. std::nullopt is passed for
// the shape argument -- presumably "keep the existing shape"; confirm against
// ShapedType::cloneWith.
181 if (
auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
// Unranked tensors: produce an unranked tensor of i1.
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
/// Suggests a name for the constant's result via `setNameFn`.
/// NOTE(review): the parameter list, the definition of `type` and some
/// guarding conditions / closing braces are not present in this listing.
192void arith::ConstantOp::getAsmResultNames(
// Integer-attribute constants get a value-derived name.
195 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
// 1-bit integers are named "true" or "false" depending on the stored value.
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
// Otherwise build the name in a small stack buffer: 'c' followed by the value.
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName <<
'c' << intCst.getValue();
// Appends '_' and the type. NOTE(review): a condition guarding this statement
// may sit on a line missing from this listing -- confirm.
207 specialName <<
'_' << type;
208 setNameFn(getResult(), specialName.str());
// Generic name; presumably the branch for non-integer constants (the `else`
// line is not visible here).
210 setNameFn(getResult(),
"cst");
/// Verifies the constant op. Visible checks:
///  - an integer result type must be signless;
///  - the value must be an IntegerAttr, FloatAttr or ElementsAttr;
///  - a scalable-vector result type requires a SplatElementsAttr value.
/// NOTE(review): the definition of `type`, the `return emitOpError(` lines of
/// the last two checks and the final return are not present in this listing.
216LogicalResult arith::ConstantOp::verify() {
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError(
"integer return type must be signless");
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
225 "value must be an integer, float, or elements attribute");
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
/// Reports whether a constant op can be built from `value` with result type
/// `type`. Visible conditions: `value` must be a TypedAttr whose type equals
/// `type`, an integer type (`intType`) must be signless, and `value` must be
/// an IntegerAttr, FloatAttr or ElementsAttr.
/// NOTE(review): the outcomes of the two early checks and the definition of
/// `intType` are on lines not present in this listing (presumably
/// `return false`) -- confirm.
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
245 if (!intType.isSignless())
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
/// Creates a constant op at `loc` holding `value` when
/// isBuildableWith(value, type) accepts the pair; `value` is then cast to
/// TypedAttr for the create call.
/// NOTE(review): the result for the non-buildable case is on a line not
/// present in this listing (presumably a null op) -- confirm.
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
/// Folds a constant op to the attribute it holds; `adaptor` is not used.
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
264 arith::ConstantOp::build(builder,
result, type,
274 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
275 assert(
result &&
"builder didn't return the right type");
287 arith::ConstantOp::build(builder,
result, type,
296 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
297 assert(
result &&
"builder didn't return the right type");
308 arith::ConstantOp::build(builder,
result, type,
314 const APInt &
value) {
317 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
318 assert(
result &&
"builder didn't return the right type");
324 const APInt &
value) {
329 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
335 FloatType type,
const APFloat &
value) {
336 arith::ConstantOp::build(builder,
result, type,
343 const APFloat &
value) {
346 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
347 assert(
result &&
"builder didn't return the right type");
353 const APFloat &
value) {
358 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
374 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
375 assert(
result &&
"builder didn't return the right type");
385 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
394 "type doesn't have a zero representation");
396 assert(zeroAttr &&
"unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
410 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
415 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
420 adaptor.getOperands(),
421 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
426 patterns.
add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
/// Returns the shape to unroll over: when the type of result 0 is a
/// VectorType, its shape is copied into a small vector.
/// NOTE(review): the return for non-vector results is on a line not present
/// in this listing (presumably std::nullopt) -- confirm.
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
456 results.push_back(getLhs());
457 results.push_back(falseValue);
466 adaptor.getOperands(),
467 [](APInt a,
const APInt &
b) { return std::move(a) + b; })) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
475 ArrayRef({sumAttr, adaptor.getLhs()}),
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
/// Registers the canonicalization pattern AddUIExtendedToAddI for this op
/// into `patterns`. The pattern itself is defined outside this listing.
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.
add<AddUIExtendedToAddI>(context);
498OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
500 if (getOperand(0) == getOperand(1)) {
501 auto shapedType = dyn_cast<ShapedType>(
getType());
503 if (!shapedType || shapedType.hasStaticShape())
510 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
512 if (getRhs() ==
add.getRhs())
515 if (getRhs() ==
add.getLhs())
520 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
521 if (getLhs() == sub.getLhs())
525 adaptor.getOperands(),
526 [](APInt a,
const APInt &
b) { return std::move(a) - b; });
/// Registers the SubIOp canonicalization patterns into `patterns`: the
/// constant-reassociation patterns named below plus SubISubILHSRHSLHS. The
/// patterns themselves are defined outside this listing.
529void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
530 MLIRContext *context) {
531 patterns.
add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
532 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
533 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
540OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
551 adaptor.getOperands(),
552 [](
const APInt &a,
const APInt &
b) { return a * b; });
555void arith::MulIOp::getAsmResultNames(
557 if (!isa<IndexType>(
getType()))
562 auto isVscale = [](Operation *op) {
563 return op && op->getName().getStringRef() ==
"vector.vscale";
566 IntegerAttr baseValue;
567 auto isVscaleExpr = [&](Value a, Value
b) {
569 isVscale(
b.getDefiningOp());
572 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
576 SmallString<32> specialNameBuffer;
577 llvm::raw_svector_ostream specialName(specialNameBuffer);
578 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
579 setNameFn(getResult(), specialName.str());
582void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
583 MLIRContext *context) {
584 patterns.
add<MulIMulIConstant>(context);
/// Returns the shape to unroll over: when the type of result 0 is a
/// VectorType, its shape is copied into a small vector.
/// NOTE(review): the return for non-vector results is on a line not present
/// in this listing (presumably std::nullopt) -- confirm.
591std::optional<SmallVector<int64_t, 4>>
592arith::MulSIExtendedOp::getShapeForUnroll() {
593 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
594 return llvm::to_vector<4>(vt.getShape());
599arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
600 SmallVectorImpl<OpFoldResult> &results) {
603 Attribute zero = adaptor.getRhs();
604 results.push_back(zero);
605 results.push_back(zero);
611 adaptor.getOperands(),
612 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
615 llvm::APIntOps::mulhs);
616 assert(highAttr &&
"Unexpected constant-folding failure");
618 results.push_back(lowAttr);
619 results.push_back(highAttr);
626void arith::MulSIExtendedOp::getCanonicalizationPatterns(
627 RewritePatternSet &patterns, MLIRContext *context) {
628 patterns.
add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
/// Returns the shape to unroll over: when the type of result 0 is a
/// VectorType, its shape is copied into a small vector.
/// NOTE(review): the return for non-vector results is on a line not present
/// in this listing (presumably std::nullopt) -- confirm.
635std::optional<SmallVector<int64_t, 4>>
636arith::MulUIExtendedOp::getShapeForUnroll() {
637 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
638 return llvm::to_vector<4>(vt.getShape());
643arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
644 SmallVectorImpl<OpFoldResult> &results) {
647 Attribute zero = adaptor.getRhs();
648 results.push_back(zero);
649 results.push_back(zero);
657 results.push_back(getLhs());
658 results.push_back(zero);
664 adaptor.getOperands(),
665 [](
const APInt &a,
const APInt &
b) { return a * b; })) {
668 llvm::APIntOps::mulhu);
669 assert(highAttr &&
"Unexpected constant-folding failure");
671 results.push_back(lowAttr);
672 results.push_back(highAttr);
679void arith::MulUIExtendedOp::getCanonicalizationPatterns(
680 RewritePatternSet &patterns, MLIRContext *context) {
681 patterns.
add<MulUIExtendedToMulI>(context);
690 arith::IntegerOverflowFlags ovfFlags) {
691 auto mul =
lhs.getDefiningOp<mlir::arith::MulIOp>();
692 if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
704OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
710 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
716 [&](APInt a,
const APInt &
b) {
724 return div0 ? Attribute() :
result;
744OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
750 if (Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
754 bool overflowOrDiv0 =
false;
756 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
757 if (overflowOrDiv0 || !b) {
758 overflowOrDiv0 = true;
761 return a.sdiv_ov(
b, overflowOrDiv0);
764 return overflowOrDiv0 ? Attribute() :
result;
// Computes ((a - 1) / b) + 1 using signed, overflow-reporting APInt
// operations; each step reports through the shared `overflow` variable.
// NOTE(review): the enclosing function header (and therefore the
// preconditions on `a` and `b`) is not present in this listing.
791 APInt one(a.getBitWidth(), 1,
true);
792 APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
793 return val.sadd_ov(one, overflow);
800OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
805 bool overflowOrDiv0 =
false;
807 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
808 if (overflowOrDiv0 || !b) {
809 overflowOrDiv0 = true;
812 APInt quotient = a.udiv(
b);
815 APInt one(a.getBitWidth(), 1,
true);
816 return quotient.uadd_ov(one, overflowOrDiv0);
819 return overflowOrDiv0 ? Attribute() :
result;
830OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
838 bool overflowOrDiv0 =
false;
840 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
841 if (overflowOrDiv0 || !b) {
842 overflowOrDiv0 = true;
848 unsigned bits = a.getBitWidth();
849 APInt zero = APInt::getZero(bits);
850 bool aGtZero = a.sgt(zero);
851 bool bGtZero =
b.sgt(zero);
852 if (aGtZero && bGtZero) {
859 bool overflowNegA =
false;
860 bool overflowNegB =
false;
861 bool overflowDiv =
false;
862 bool overflowNegRes =
false;
863 if (!aGtZero && !bGtZero) {
865 APInt posA = zero.ssub_ov(a, overflowNegA);
866 APInt posB = zero.ssub_ov(
b, overflowNegB);
868 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
871 if (!aGtZero && bGtZero) {
873 APInt posA = zero.ssub_ov(a, overflowNegA);
874 APInt
div = posA.sdiv_ov(
b, overflowDiv);
875 APInt res = zero.ssub_ov(
div, overflowNegRes);
876 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
880 APInt posB = zero.ssub_ov(
b, overflowNegB);
881 APInt
div = a.sdiv_ov(posB, overflowDiv);
882 APInt res = zero.ssub_ov(
div, overflowNegRes);
884 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
888 return overflowOrDiv0 ? Attribute() :
result;
899OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
905 bool overflowOrDiv =
false;
907 adaptor.getOperands(), [&](APInt a,
const APInt &
b) {
909 overflowOrDiv = true;
912 return a.sfloordiv_ov(
b, overflowOrDiv);
915 return overflowOrDiv ? Attribute() :
result;
922OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
930 [&](APInt a,
const APInt &
b) {
931 if (div0 || b.isZero()) {
938 return div0 ? Attribute() :
result;
949OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
957 [&](APInt a,
const APInt &
b) {
958 if (div0 || b.isZero()) {
965 return div0 ? Attribute() :
result;
983 for (
bool reversePrev : {
false,
true}) {
984 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
985 .getDefiningOp<arith::AndIOp>();
989 Value other = (reversePrev ? op.getLhs() : op.getRhs());
990 if (other != prev.getLhs() && other != prev.getRhs())
993 return prev.getResult();
998OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1005 intValue.isAllOnes())
1010 intValue.isAllOnes())
1015 intValue.isAllOnes())
1023 adaptor.getOperands(),
1024 [](APInt a,
const APInt &
b) { return std::move(a) & b; });
1031OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1034 if (rhsVal.isZero())
1037 if (rhsVal.isAllOnes())
1038 return adaptor.getRhs();
1045 intValue.isAllOnes())
1046 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1050 intValue.isAllOnes())
1051 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1054 adaptor.getOperands(),
1055 [](APInt a,
const APInt &
b) { return std::move(a) | b; });
/// Folds xor operations. Visible rules:
///  - identical operands (result on a line not present in this listing);
///  - xor(xor(a, b), b) -> a and xor(xor(a, b), a) -> b, with the inner xor
///    on either side;
///  - constant folding of both operands with APInt `^`.
/// NOTE(review): several lines (including any rule ahead of the first
/// visible check and the call wrapping the final lambda) are missing here.
1062OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1067 if (getLhs() == getRhs())
// Inner xor on the left: cancel whichever of its operands equals our rhs.
1071 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1072 if (prev.getRhs() == getRhs())
1073 return prev.getLhs();
1074 if (prev.getLhs() == getRhs())
1075 return prev.getRhs();
// Inner xor on the right: cancel whichever of its operands equals our lhs.
1079 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1080 if (prev.getRhs() == getLhs())
1081 return prev.getLhs();
1082 if (prev.getLhs() == getLhs())
1083 return prev.getRhs();
// Element-wise constant fold; `a` is taken by value so it can be moved into
// the xor expression.
1087 adaptor.getOperands(),
1088 [](APInt a,
const APInt &
b) { return std::move(a) ^ b; });
/// Registers the XOrIOp canonicalization patterns (XOrINotCmpI, XOrIOfExtUI,
/// XOrIOfExtSI) into `patterns`. The patterns themselves are defined outside
/// this listing.
1091void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1092 MLIRContext *context) {
1093 patterns.
add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1100OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1102 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1103 return op.getOperand();
1105 [](
const APFloat &a) { return -a; });
1112OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1118 if (
auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1119 return op.getResult();
1123 adaptor.getOperands(), [](
const APFloat &a) {
1125 return APFloat::getZero(a.getSemantics(), a.isNegative());
1134OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1139 auto rm = getRoundingmode();
1141 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1143 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1152OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1157 auto rm = getRoundingmode();
1159 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1161 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1166void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1167 MLIRContext *context) {
1168 patterns.
add<SubFOfNegZero>(context);
1175OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1177 if (getLhs() == getRhs())
1191OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1193 if (getLhs() == getRhs())
1207OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1209 if (getLhs() == getRhs())
1215 if (intValue.isMaxSignedValue())
1218 if (intValue.isMinSignedValue())
1223 llvm::APIntOps::smax);
1230OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1232 if (getLhs() == getRhs())
1238 if (intValue.isMaxValue())
1241 if (intValue.isMinValue())
1246 llvm::APIntOps::umax);
1253OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1255 if (getLhs() == getRhs())
1269OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1271 if (getLhs() == getRhs())
1285OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1287 if (getLhs() == getRhs())
1293 if (intValue.isMinSignedValue())
1296 if (intValue.isMaxSignedValue())
1301 llvm::APIntOps::smin);
1308OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1310 if (getLhs() == getRhs())
1316 if (intValue.isMinValue())
1319 if (intValue.isMaxValue())
1324 llvm::APIntOps::umin);
1331OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1336 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1337 arith::FastMathFlags::nsz)) {
1343 auto rm = getRoundingmode();
1345 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1347 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1352void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1353 MLIRContext *context) {
1354 patterns.
add<MulFOfNegF>(context);
1361OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1366 auto rm = getRoundingmode();
1368 adaptor.getOperands(), [rm](
const APFloat &a,
const APFloat &
b) {
1370 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1375void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1376 MLIRContext *context) {
1377 patterns.
add<DivFOfNegF>(context);
1384OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1386 [](
const APFloat &a,
const APFloat &
b) {
1391 (void)result.mod(b);
1400template <
typename... Types>
1406template <
typename... ShapedTypes,
typename... ElementTypes>
1409 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1413 if (!llvm::isa<ElementTypes...>(underlyingType))
1416 return underlyingType;
1420template <
typename... ElementTypes>
1427template <
typename... ElementTypes>
1436 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1437 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1438 if (!rankedTensorA || !rankedTensorB)
1440 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1444 if (inputs.size() != 1 || outputs.size() != 1)
1456template <
typename ValType,
typename Op>
1461 if (llvm::cast<ValType>(srcType).getWidth() >=
1462 llvm::cast<ValType>(dstType).getWidth())
1464 << dstType <<
" must be wider than operand type " << srcType;
1470template <
typename ValType,
typename Op>
1475 if (llvm::cast<ValType>(srcType).getWidth() <=
1476 llvm::cast<ValType>(dstType).getWidth())
1478 << dstType <<
" must be shorter than operand type " << srcType;
1484template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1489 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1490 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1491 if (!srcType || !dstType)
1494 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1495 srcType.getIntOrFloatBitWidth());
// Converts `sourceValue` to `targetSemantics` using `roundingMode`, returning
// FailureOr<APFloat>.
// NOTE(review): the function name, the remaining parameters and the
// statements executed by each `if` below are not present in this listing;
// presumably each rejecting branch returns failure() -- confirm.
1500static FailureOr<APFloat>
1502 const llvm::fltSemantics &targetSemantics,
1506 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
// Infinity has no counterpart when the target's non-finite behavior is
// NanOnly or FiniteOnly.
1507 if (sourceValue.isInfinity() &&
1508 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1509 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
// NaN has no counterpart when the target's non-finite behavior is FiniteOnly.
1511 if (sourceValue.isNaN() &&
1512 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
// Perform the conversion and check both the status and whether information
// was lost.
1515 bool losesInfo =
false;
1516 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1517 if (losesInfo || status != APFloat::opOK)
1527OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1528 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1529 getInMutable().assign(
lhs.getIn());
1534 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1536 adaptor.getOperands(),
getType(),
1537 [bitWidth](
const APInt &a,
bool &castStatus) {
1538 return a.zext(bitWidth);
1546LogicalResult arith::ExtUIOp::verify() {
1554OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1555 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1556 getInMutable().assign(
lhs.getIn());
1561 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1563 adaptor.getOperands(),
getType(),
1564 [bitWidth](
const APInt &a,
bool &castStatus) {
1565 return a.sext(bitWidth);
1573void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1574 MLIRContext *context) {
1575 patterns.
add<ExtSIOfExtUI>(context);
1578LogicalResult arith::ExtSIOp::verify() {
/// Folds extf. Visible rules:
///  - extf(truncf(x)) -> x when x already has this op's result type and both
///    the truncf and the extf carry the `contract` fast-math flag;
///  - constant folding towards the result element type's float semantics.
/// NOTE(review): the definition of `resElemType`, the call wrapping the
/// lambda and the lambda body are not present in this listing.
1588OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1589 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
// Only a round trip back to the original type is a candidate.
1590 if (truncFOp.getOperand().getType() ==
getType()) {
// An absent fast-math attribute is treated as `none` on both ops.
1591 arith::FastMathFlags truncFMF =
1592 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1593 bool isTruncContract =
1594 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1595 arith::FastMathFlags extFMF =
1596 getFastmath().value_or(arith::FastMathFlags::none);
1597 bool isExtContract =
1598 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
// Both ops must opt in before the truncation is dropped.
1599 if (isTruncContract && isExtContract) {
1600 return truncFOp.getOperand();
1606 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1608 adaptor.getOperands(),
getType(),
1609 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1629bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1634LogicalResult arith::ScalingExtFOp::verify() {
1642OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1645 Value src = getOperand().getDefiningOp()->getOperand(0);
1650 if (llvm::cast<IntegerType>(srcType).getWidth() >
1651 llvm::cast<IntegerType>(dstType).getWidth()) {
1658 if (srcType == dstType)
1664 setOperand(getOperand().getDefiningOp()->getOperand(0));
1669 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1671 adaptor.getOperands(),
getType(),
1672 [bitWidth](
const APInt &a,
bool &castStatus) {
1673 return a.trunc(bitWidth);
1681void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1682 MLIRContext *context) {
1684 .
add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1688LogicalResult arith::TruncIOp::verify() {
1698OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1700 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1701 Value src = extOp.getIn();
1703 auto intermediateType =
1706 if (llvm::APFloatBase::isRepresentableBy(
1707 srcType.getFloatSemantics(),
1708 intermediateType.getFloatSemantics())) {
1710 if (srcType.getWidth() > resElemType.getWidth()) {
1716 if (srcType == resElemType)
1721 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1723 adaptor.getOperands(),
getType(),
1724 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1725 llvm::RoundingMode llvmRoundingMode =
1727 FailureOr<APFloat>
result =
1737void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1738 MLIRContext *context) {
1739 patterns.
add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1746LogicalResult arith::TruncFOp::verify() {
1754OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1756 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1758 adaptor.getOperands(),
getType(),
1759 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1760 llvm::RoundingMode llvmRoundingMode =
1762 FailureOr<APFloat>
result =
1777 if (!srcType || !dstType)
1779 return srcType != dstType &&
1783LogicalResult arith::ConvertFOp::verify() {
1786 if (srcType == dstType)
1787 return emitError(
"result element type ")
1788 << dstType <<
" must be different from operand element type "
1790 if (srcType.getWidth() != dstType.getWidth())
1791 return emitError(
"result element type ")
1792 << dstType <<
" must have the same bitwidth as operand element type "
1801bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1806LogicalResult arith::ScalingTruncFOp::verify() {
1814void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1815 MLIRContext *context) {
1816 patterns.
add<AndOfExtUI, AndOfExtSI>(context);
1823void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1824 MLIRContext *context) {
1825 patterns.
add<OrOfExtUI, OrOfExtSI>(context);
1832template <
typename From,
typename To>
1840 return srcType && dstType;
1851OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1854 adaptor.getOperands(),
getType(),
1855 [&resEleType](
const APInt &a,
bool &castStatus) {
1856 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1857 APFloat apf(floatTy.getFloatSemantics(),
1858 APInt::getZero(floatTy.getWidth()));
1859 apf.convertFromAPInt(a,
false,
1860 APFloat::rmNearestTiesToEven);
1865void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1866 MLIRContext *context) {
1867 patterns.
add<UIToFPOfExtUI>(context);
1878OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1881 adaptor.getOperands(),
getType(),
1882 [&resEleType](
const APInt &a,
bool &castStatus) {
1883 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1884 APFloat apf(floatTy.getFloatSemantics(),
1885 APInt::getZero(floatTy.getWidth()));
1886 apf.convertFromAPInt(a,
true,
1887 APFloat::rmNearestTiesToEven);
1892void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1893 MLIRContext *context) {
1894 patterns.
add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1905OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1907 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1909 adaptor.getOperands(),
getType(),
1910 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1912 APSInt api(bitWidth,
true);
1913 castStatus = APFloat::opInvalidOp !=
1914 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1927OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1929 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1931 adaptor.getOperands(),
getType(),
1932 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1934 APSInt api(bitWidth,
false);
1935 castStatus = APFloat::opInvalidOp !=
1936 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1950 return intTy.getWidth();
1951 return IndexType::kInternalStorageBitWidth;
1960 if (!srcType || !dstType)
1964 (srcType.isSignlessInteger() && dstType.
isIndex());
1967bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1972OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1974 unsigned resultBitwidth = 64;
1976 resultBitwidth = intTy.getWidth();
1979 adaptor.getOperands(),
getType(),
1980 [resultBitwidth](
const APInt &a,
bool & ) {
1981 return a.sextOrTrunc(resultBitwidth);
1988 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
1989 Value x = inner.getOperand();
1998void arith::IndexCastOp::getCanonicalizationPatterns(
1999 RewritePatternSet &patterns, MLIRContext *context) {
2000 patterns.
add<IndexCastOfExtSI>(context);
2007bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
2012OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2014 unsigned resultBitwidth = 64;
2016 resultBitwidth = intTy.getWidth();
2019 adaptor.getOperands(),
getType(),
2020 [resultBitwidth](
const APInt &a,
bool & ) {
2021 return a.zextOrTrunc(resultBitwidth);
2028 if (
auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2029 Value x = inner.getOperand();
2038void arith::IndexCastUIOp::getCanonicalizationPatterns(
2039 RewritePatternSet &patterns, MLIRContext *context) {
2040 patterns.
add<IndexCastUIOfExtUI>(context);
2053 if (!srcType || !dstType)
/// Folds bitcast of a constant operand.
/// NOTE(review): the definition of `resType`, early-exit checks and the
/// statement the string literal below belongs to (presumably an assert) are
/// not present in this listing.
2059OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2061 auto operand = adaptor.getIn();
// Dense elements: delegate to DenseElementsAttr::bitcast with the result's
// element type.
2066 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2067 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
// Shaped result with a non-dense operand: handled on a line not shown here.
2069 if (llvm::isa<ShapedType>(resType))
// Scalar case: take the raw bits from the float or integer attribute.
2077 APInt bits = llvm::isa<FloatAttr>(operand)
2078 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2079 : llvm::cast<IntegerAttr>(operand).getValue();
2081 "trying to fold on broken IR: operands have incompatible types");
// Reinterpret the bits as the result type: a float built with the result's
// semantics, otherwise an integer attribute.
2083 if (
auto resFloatType = dyn_cast<FloatType>(resType))
2084 return FloatAttr::get(resType,
2085 APFloat(resFloatType.getFloatSemantics(), bits));
2086 return IntegerAttr::get(resType, bits);
2089void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2090 MLIRContext *context) {
2091 patterns.
add<BitcastOfBitcast>(context);
2101 const APInt &
lhs,
const APInt &
rhs) {
2102 switch (predicate) {
2103 case arith::CmpIPredicate::eq:
2105 case arith::CmpIPredicate::ne:
2107 case arith::CmpIPredicate::slt:
2109 case arith::CmpIPredicate::sle:
2111 case arith::CmpIPredicate::sgt:
2113 case arith::CmpIPredicate::sge:
2115 case arith::CmpIPredicate::ult:
2117 case arith::CmpIPredicate::ule:
2119 case arith::CmpIPredicate::ugt:
2121 case arith::CmpIPredicate::uge:
2124 llvm_unreachable(
"unknown cmpi predicate kind");
2129 switch (predicate) {
2130 case arith::CmpIPredicate::eq:
2131 case arith::CmpIPredicate::sle:
2132 case arith::CmpIPredicate::sge:
2133 case arith::CmpIPredicate::ule:
2134 case arith::CmpIPredicate::uge:
2136 case arith::CmpIPredicate::ne:
2137 case arith::CmpIPredicate::slt:
2138 case arith::CmpIPredicate::sgt:
2139 case arith::CmpIPredicate::ult:
2140 case arith::CmpIPredicate::ugt:
2143 llvm_unreachable(
"unknown cmpi predicate kind");
// Returns the integer bit width of `t`: the width of an IntegerType, or the
// width of the element type of a VectorType (the element type is cast to
// IntegerType unconditionally); std::nullopt for any other type.
// NOTE(review): the enclosing function header is not present in this listing.
2147 if (
auto intType = dyn_cast<IntegerType>(t)) {
2148 return intType.getWidth();
2150 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
2151 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2153 return std::nullopt;
/// Folds integer comparisons. Visible rules:
///  - identical operands (result on a line not present in this listing);
///  - a `ne` comparison whose lhs is an extsi/extui and whose `integerWidth`
///    is 1 folds to the extension's operand;
///  - when only the lhs has a constant attribute, the operands are swapped in
///    place and the predicate is replaced by its mirrored form;
///  - constant folding of the operands with the op's predicate.
/// NOTE(review): many lines are missing here, including the initializers of
/// `integerWidth`, any conditions on the rhs, and the statements guarded by
/// the two dangling predicate checks below.
2156OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2158 if (getLhs() == getRhs()) {
// lhs produced by a sign extension.
2164 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2166 std::optional<int64_t> integerWidth =
2168 if (integerWidth && integerWidth.value() == 1 &&
2169 getPredicate() == arith::CmpIPredicate::ne)
2170 return extOp.getOperand();
// lhs produced by a zero extension.
2172 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2174 std::optional<int64_t> integerWidth =
2176 if (integerWidth && integerWidth.value() == 1 &&
2177 getPredicate() == arith::CmpIPredicate::ne)
2178 return extOp.getOperand();
2183 getPredicate() == arith::CmpIPredicate::ne)
2190 getPredicate() == arith::CmpIPredicate::eq)
// Only the lhs has a constant attribute: move it to the rhs.
2195 if (adaptor.getLhs() && !adaptor.getRhs()) {
2197 using Pred = CmpIPredicate;
// (original predicate, predicate to use after swapping the operands);
// eq and ne map to themselves.
2198 const std::pair<Pred, Pred> invPreds[] = {
2199 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2200 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2201 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2202 {Pred::ne, Pred::ne},
2204 Pred origPred = getPredicate();
2205 for (
auto pred : invPreds) {
2206 if (origPred == pred.first) {
// Update the predicate, then exchange the two operands in place.
2207 setPredicate(pred.second);
2208 Value
lhs = getLhs();
2209 Value
rhs = getRhs();
2210 getLhsMutable().assign(
rhs);
2211 getRhsMutable().assign(
lhs);
// Fallback for a predicate that is not in the table.
2215 llvm_unreachable(
"unknown cmpi predicate kind");
// Constant folding path; the predicate is captured by value in the lambda.
2220 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2223 [pred = getPredicate()](
const APInt &
lhs,
const APInt &
rhs) {
/// Adds the canonicalization patterns for arith::CmpIOp to `patterns`:
/// one instance each of CmpIExtSI and CmpIExtUI, constructed with `context`.
2232void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2233 MLIRContext *context) {
2234 patterns.
insert<CmpIExtSI, CmpIExtUI>(context);
2244 const APFloat &
lhs,
const APFloat &
rhs) {
// Evaluates a floating-point comparison predicate on two APFloat values (the
// start of the signature is not part of this excerpt). `cmpResult` is compared
// below against cmpEqual, cmpLessThan, cmpGreaterThan and cmpUnordered.
2245 auto cmpResult =
lhs.compare(
rhs);
// The O* predicates hold only for the listed equal/less/greater outcomes; the
// U* predicates additionally hold when the outcome is cmpUnordered.
2246 switch (predicate) {
2247 case arith::CmpFPredicate::AlwaysFalse:
2249 case arith::CmpFPredicate::OEQ:
2250 return cmpResult == APFloat::cmpEqual;
2251 case arith::CmpFPredicate::OGT:
2252 return cmpResult == APFloat::cmpGreaterThan;
2253 case arith::CmpFPredicate::OGE:
2254 return cmpResult == APFloat::cmpGreaterThan ||
2255 cmpResult == APFloat::cmpEqual;
2256 case arith::CmpFPredicate::OLT:
2257 return cmpResult == APFloat::cmpLessThan;
2258 case arith::CmpFPredicate::OLE:
2259 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
// ONE: neither unordered nor equal, i.e. strictly less or strictly greater.
2260 case arith::CmpFPredicate::ONE:
2261 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
// ORD: any outcome other than unordered.
2262 case arith::CmpFPredicate::ORD:
2263 return cmpResult != APFloat::cmpUnordered;
2264 case arith::CmpFPredicate::UEQ:
2265 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2266 case arith::CmpFPredicate::UGT:
2267 return cmpResult == APFloat::cmpUnordered ||
2268 cmpResult == APFloat::cmpGreaterThan;
2269 case arith::CmpFPredicate::UGE:
2270 return cmpResult == APFloat::cmpUnordered ||
2271 cmpResult == APFloat::cmpGreaterThan ||
2272 cmpResult == APFloat::cmpEqual;
2273 case arith::CmpFPredicate::ULT:
2274 return cmpResult == APFloat::cmpUnordered ||
2275 cmpResult == APFloat::cmpLessThan;
2276 case arith::CmpFPredicate::ULE:
2277 return cmpResult == APFloat::cmpUnordered ||
2278 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
// UNE: everything except equal, which includes the unordered outcome.
2279 case arith::CmpFPredicate::UNE:
2280 return cmpResult != APFloat::cmpEqual;
// UNO: only the unordered outcome.
2281 case arith::CmpFPredicate::UNO:
2282 return cmpResult == APFloat::cmpUnordered;
2283 case arith::CmpFPredicate::AlwaysTrue:
2286 llvm_unreachable(
"unknown cmpf predicate kind");
2290 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2291 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2294 if (
lhs &&
lhs.getValue().isNaN())
2296 if (
rhs &&
rhs.getValue().isNaN())
// Maps a float comparison predicate onto an integer comparison predicate (the
// signature and the switch head are not part of this excerpt). The ordered and
// unordered variant of each relation share one case group and so map to the
// same integer predicate.
2312 using namespace arith;
// Equality has no signed/unsigned distinction.
2314 case CmpFPredicate::UEQ:
2315 case CmpFPredicate::OEQ:
2316 return CmpIPredicate::eq;
// For the ordering relations, `isUnsigned` selects the unsigned (u*) form and
// otherwise the signed (s*) form.
2317 case CmpFPredicate::UGT:
2318 case CmpFPredicate::OGT:
2319 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2320 case CmpFPredicate::UGE:
2321 case CmpFPredicate::OGE:
2322 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2323 case CmpFPredicate::ULT:
2324 case CmpFPredicate::OLT:
2325 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2326 case CmpFPredicate::ULE:
2327 case CmpFPredicate::OLE:
2328 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
// Inequality has no signed/unsigned distinction either.
2329 case CmpFPredicate::UNE:
2330 case CmpFPredicate::ONE:
2331 return CmpIPredicate::ne;
// Reaching this point is treated as a programming error.
2333 llvm_unreachable(
"Unexpected predicate!");
// Partial body of a rewrite on a float comparison `op` whose lhs comes from an
// integer-to-float conversion and whose rhs is the float constant `rhs`
// (presumably CmpFIntToFPConst::matchAndRewrite; the signature and many
// statements are not part of this excerpt). Comments describe visible
// statements only.
2343 const APFloat &
rhs = flt.getValue();
2351 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2352 int mantissaWidth = floatTy.getFPMantissaWidth();
2353 if (mantissaWidth <= 0)
// The integer operand is the input of the SIToFPOp or UIToFPOp defining lhs.
2359 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2361 intVal = si.getIn();
2362 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2364 intVal = ui.getIn();
2371 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2372 auto intWidth = intTy.getWidth();
// Bits available for the magnitude: the full width when unsigned, one less
// when signed.
2375 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
// The integer is wider than the float mantissa, so the exponent of rhs is
// examined against the mantissa width and the integer's value bits.
2380 if ((
int)intWidth > mantissaWidth) {
2382 int exponent = ilogb(
rhs);
// rhs is infinite: compare the exponent of the largest finite value of the
// same semantics against the integer's value bits instead.
2383 if (exponent == APFloat::IEK_Inf) {
2384 int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
2385 if (maxExponent < (
int)valueBits) {
2392 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2401 switch (op.getPredicate()) {
2402 case CmpFPredicate::ORD:
2407 case CmpFPredicate::UNO:
// Range checks: each integer bound is converted to rhs's float semantics
// (round to nearest, ties to even) and compared with rhs. Within each check
// three predicates are singled out; the values produced are not visible here.
// rhs above the signed maximum: ne / slt / sle are singled out.
2420 APFloat signedMax(
rhs.getSemantics());
2421 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2422 APFloat::rmNearestTiesToEven);
2423 if (signedMax <
rhs) {
2424 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2425 pred == CmpIPredicate::sle)
// rhs above the unsigned maximum: ne / ult / ule are singled out.
2436 APFloat unsignedMax(
rhs.getSemantics());
2437 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2438 APFloat::rmNearestTiesToEven);
2439 if (unsignedMax <
rhs) {
2440 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2441 pred == CmpIPredicate::ule)
// rhs below the signed minimum: ne / sgt / sge are singled out.
2453 APFloat signedMin(
rhs.getSemantics());
2454 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2455 APFloat::rmNearestTiesToEven);
2456 if (signedMin >
rhs) {
2457 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2458 pred == CmpIPredicate::sge)
// rhs below the unsigned minimum: ne / ugt / uge are singled out.
2468 APFloat unsignedMin(
rhs.getSemantics());
2469 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2470 APFloat::rmNearestTiesToEven);
2471 if (unsignedMin >
rhs) {
2472 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2473 pred == CmpIPredicate::uge)
// Convert rhs to an integer of the operand's width and signedness, rounding
// toward zero; an opInvalidOp status gets its own branch.
2488 APSInt rhsInt(intWidth, isUnsigned);
2489 if (APFloat::opInvalidOp ==
2490 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
// For a non-zero rhs, convert the integer back to float and record in `equal`
// whether the round trip reproduces rhs.
2496 if (!
rhs.isZero()) {
2497 APFloat apf(floatTy.getFloatSemantics(),
2498 APInt::getZero(floatTy.getWidth()));
2499 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2501 bool equal = apf ==
rhs;
// Per-predicate adjustment that depends on the sign of rhs (the switch head
// and the guarding condition are not visible here).
2507 case CmpIPredicate::ne:
2511 case CmpIPredicate::eq:
2515 case CmpIPredicate::ule:
2518 if (
rhs.isNegative()) {
2524 case CmpIPredicate::sle:
2527 if (
rhs.isNegative())
2528 pred = CmpIPredicate::slt;
2530 case CmpIPredicate::ult:
2533 if (
rhs.isNegative()) {
2538 pred = CmpIPredicate::ule;
2540 case CmpIPredicate::slt:
2543 if (!
rhs.isNegative())
2544 pred = CmpIPredicate::sle;
2546 case CmpIPredicate::ugt:
2549 if (
rhs.isNegative()) {
2555 case CmpIPredicate::sgt:
2558 if (
rhs.isNegative())
2559 pred = CmpIPredicate::sge;
2561 case CmpIPredicate::uge:
2564 if (
rhs.isNegative()) {
2569 pred = CmpIPredicate::ugt;
2571 case CmpIPredicate::sge:
2574 if (!
rhs.isNegative())
2575 pred = CmpIPredicate::sgt;
// A constant of the integer operand's type is created at the op's location.
2585 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
/// Adds the canonicalization pattern for arith::CmpFOp to `patterns`:
/// one instance of CmpFIntToFPConst, constructed with `context`.
2591void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2592 MLIRContext *context) {
2593 patterns.
insert<CmpFIntToFPConst>(context);
2607 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2623 arith::XOrIOp::create(
2624 rewriter, op.getLoc(), op.getCondition(),
2626 op.getCondition().
getType(), 1)));
/// Adds the canonicalization patterns for arith::SelectOp to `results`:
/// RedundantSelectFalse, RedundantSelectTrue, SelectNotCond, SelectI1ToNot and
/// SelectToExtUI, each constructed with `context`.
2634void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2635 MLIRContext *context) {
2636 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2637 SelectI1ToNot, SelectToExtUI>(context);
/// Folder for arith::SelectOp. Only part of the body is present in this
/// excerpt; the comments below describe the visible statements only.
2640OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2641 Value trueVal = getTrueValue();
2642 Value falseVal = getFalseValue();
// Both branches are the same value (the result for this case is not visible
// here).
2643 if (trueVal == falseVal)
2646 Value condition = getCondition();
2664 if (
getType().isSignlessInteger(1) &&
// `cmp` is an integer comparison (its definition is not visible here). Only
// the eq and ne predicates are considered.
2670 auto pred = cmp.getPredicate();
2671 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2672 auto cmpLhs = cmp.getLhs();
2673 auto cmpRhs = cmp.getRhs();
// The compared values are exactly the two select operands, in either order.
// With `ne` the fold yields trueVal; with `eq` it yields falseVal.
2681 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2682 (cmpRhs == trueVal && cmpLhs == falseVal))
2683 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
// Element-wise fold when the condition, true value and false value are all
// DenseElementsAttr.
2690 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2692 assert(cond.getType().hasStaticShape() &&
2693 "DenseElementsAttr must have static shape");
2695 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2697 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
// One result attribute per condition element; reserve up front.
2698 SmallVector<Attribute> results;
2699 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2700 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2701 cond.value_end<BoolAttr>());
2702 auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
2703 lhs.value_end<Attribute>());
2704 auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
2705 rhs.value_end<Attribute>());
// Walk the three ranges together: take the true-side element where the
// condition element is set, the false-side element otherwise.
2707 for (
auto [condVal, lhsVal, rhsVal] :
2708 llvm::zip_equal(condVals, lhsVals, rhsVals))
2709 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
/// Custom parser for SelectOp. Only part of the body is present in this
/// excerpt; the comments below describe the visible statements only.
2719ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
2720 Type conditionType, resultType;
2721 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
// conditionType is taken from resultType here (the enclosing branch is not
// visible in this excerpt).
2729 conditionType = resultType;
// The op has a single result of type resultType.
2736 result.addTypes(resultType);
// Operand types, in order: the condition, then two operands of resultType.
2738 {conditionType, resultType, resultType},
/// Custom printer for arith::SelectOp. Only part of the body is present in
/// this excerpt.
2742void arith::SelectOp::print(OpAsmPrinter &p) {
// A space followed by the operand list.
2743 p <<
" " << getOperands();
// The condition type is printed, followed by ", ", only when it is a
// ShapedType.
2746 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2747 p << condType <<
", ";
/// Verifier for arith::SelectOp. Only part of the body is present in this
/// excerpt; the comments below describe the visible statements only.
2751LogicalResult arith::SelectOp::verify() {
2752 Type conditionType = getCondition().getType();
// The result is neither a tensor nor a vector: report that a signless i1
// condition was expected (the rest of the message is not visible here).
2759 if (!llvm::isa<TensorType, VectorType>(resultType))
2760 return emitOpError() <<
"expected condition to be a signless i1, but got "
// Otherwise the condition type must equal `shapedConditionType` (computed on a
// line not visible here); the error reports the expected and actual types.
2763 if (conditionType != shapedConditionType) {
2764 return emitOpError() <<
"expected condition type to have the same shape "
2765 "as the result type, expected "
2766 << shapedConditionType <<
", but got "
/// Constant folder for arith::ShLIOp. Only part of the body is present in
/// this excerpt.
2775OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda below; starts out false.
2780 bool bounded =
false;
2782 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
// Records whether the shift amount `b` is (unsigned) less than its bit width.
2783 bounded = b.ult(b.getBitWidth());
// If the shift amount was not bounded, a null Attribute is returned in place
// of `result`.
2786 return bounded ?
result : Attribute();
/// Constant folder for arith::ShRUIOp. Only part of the body is present in
/// this excerpt.
2793OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda below; starts out false.
2798 bool bounded =
false;
2800 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
// Records whether the shift amount `b` is (unsigned) less than its bit width.
2801 bounded = b.ult(b.getBitWidth());
// If the shift amount was not bounded, a null Attribute is returned in place
// of `result`.
2804 return bounded ?
result : Attribute();
/// Constant folder for arith::ShRSIOp. Only part of the body is present in
/// this excerpt.
2811OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// Set by the lambda below; starts out false.
2816 bool bounded =
false;
2818 adaptor.getOperands(), [&](
const APInt &a,
const APInt &
b) {
// Records whether the shift amount `b` is (unsigned) less than its bit width.
2819 bounded = b.ult(b.getBitWidth());
// If the shift amount was not bounded, a null Attribute is returned in place
// of `result`.
2822 return bounded ?
result : Attribute();
2832 bool useOnlyFiniteValue) {
// Per-kind identity values for `resultType` (the start of the signature, the
// switch head and most return statements are not part of this excerpt).
// NOTE(review): the trailing bool passed to getLargest / getInf / getNaN is
// presumably APFloat's "negative" flag; confirm against the APFloat reference.
// maximumf: largest-magnitude finite value when only finite values are
// allowed, otherwise infinity; both built with the flag set to true.
2834 case AtomicRMWKind::maximumf: {
2835 const llvm::fltSemantics &semantic =
2836 llvm::cast<FloatType>(resultType).getFloatSemantics();
2837 APFloat identity = useOnlyFiniteValue
2838 ? APFloat::getLargest(semantic,
true)
2839 : APFloat::getInf(semantic,
true);
// maxnumf: a NaN built with the flag set to true.
2842 case AtomicRMWKind::maxnumf: {
2843 const llvm::fltSemantics &semantic =
2844 llvm::cast<FloatType>(resultType).getFloatSemantics();
2845 APFloat identity = APFloat::getNaN(semantic,
true);
// These kinds share one result (not visible in this excerpt).
2848 case AtomicRMWKind::addf:
2849 case AtomicRMWKind::addi:
2850 case AtomicRMWKind::maxu:
2851 case AtomicRMWKind::ori:
2852 case AtomicRMWKind::xori:
// andi: all bits set, at the width of the integer result type.
2854 case AtomicRMWKind::andi:
2857 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
// maxs: the signed minimum at the width of the integer result type.
2858 case AtomicRMWKind::maxs:
2860 resultType, APInt::getSignedMinValue(
2861 llvm::cast<IntegerType>(resultType).getWidth()));
// minimumf: same shape as maximumf, with the flag set to false.
2862 case AtomicRMWKind::minimumf: {
2863 const llvm::fltSemantics &semantic =
2864 llvm::cast<FloatType>(resultType).getFloatSemantics();
2865 APFloat identity = useOnlyFiniteValue
2866 ? APFloat::getLargest(semantic,
false)
2867 : APFloat::getInf(semantic,
false);
// minnumf: a NaN built with the flag set to false.
2871 case AtomicRMWKind::minnumf: {
2872 const llvm::fltSemantics &semantic =
2873 llvm::cast<FloatType>(resultType).getFloatSemantics();
2874 APFloat identity = APFloat::getNaN(semantic,
false);
// mins: the signed maximum at the width of the integer result type.
2877 case AtomicRMWKind::mins:
2879 resultType, APInt::getSignedMaxValue(
2880 llvm::cast<IntegerType>(resultType).getWidth()));
// minu: the unsigned maximum at the width of the integer result type.
2881 case AtomicRMWKind::minu:
2884 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
// muli / mulf: results are not visible in this excerpt.
2885 case AtomicRMWKind::muli:
2887 case AtomicRMWKind::mulf:
// Classify `op` by its concrete arith op class into the matching
// AtomicRMWKind (the signature and the head of the type switch are not part of
// this excerpt). Op classes without a case yield std::nullopt.
2899 std::optional<AtomicRMWKind> maybeKind =
// Floating-point operations.
2902 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2903 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2904 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2905 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2906 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2907 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
// Integer operations.
2909 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2910 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2911 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2912 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2913 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2914 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2915 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2916 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2917 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2918 .Default(std::nullopt);
// Early exit with std::nullopt (the guarding condition is not visible here;
// presumably taken when no kind matched).
2920 return std::nullopt;
// Restrict the identity to finite values only when the op implements
// ArithFastMathInterface and its fast-math flags contain `ninf`.
2923 bool useOnlyFiniteValue =
false;
2924 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2925 if (fmfOpInterface) {
2926 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2927 useOnlyFiniteValue =
2928 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2936 useOnlyFiniteValue);
2942 bool useOnlyFiniteValue) {
2944 useOnlyFiniteValue))
2945 return arith::ConstantOp::create(builder, loc, attr);
// Dispatch from an AtomicRMWKind to the binary arith op of the same name (the
// signature and the switch head are not part of this excerpt; presumably the
// body of getReductionOp). Every case returns the result of the matching
// arith::*Op::create call on (lhs, rhs) at `loc`.
2954 case AtomicRMWKind::addf:
2955 return arith::AddFOp::create(builder, loc,
lhs,
rhs);
2956 case AtomicRMWKind::addi:
2957 return arith::AddIOp::create(builder, loc,
lhs,
rhs);
2958 case AtomicRMWKind::mulf:
2959 return arith::MulFOp::create(builder, loc,
lhs,
rhs);
2960 case AtomicRMWKind::muli:
2961 return arith::MulIOp::create(builder, loc,
lhs,
rhs);
2962 case AtomicRMWKind::maximumf:
2963 return arith::MaximumFOp::create(builder, loc,
lhs,
rhs);
2964 case AtomicRMWKind::minimumf:
2965 return arith::MinimumFOp::create(builder, loc,
lhs,
rhs);
2966 case AtomicRMWKind::maxnumf:
2967 return arith::MaxNumFOp::create(builder, loc,
lhs,
rhs);
2968 case AtomicRMWKind::minnumf:
2969 return arith::MinNumFOp::create(builder, loc,
lhs,
rhs);
// Signed and unsigned integer min/max map to distinct ops.
2970 case AtomicRMWKind::maxs:
2971 return arith::MaxSIOp::create(builder, loc,
lhs,
rhs);
2972 case AtomicRMWKind::mins:
2973 return arith::MinSIOp::create(builder, loc,
lhs,
rhs);
2974 case AtomicRMWKind::maxu:
2975 return arith::MaxUIOp::create(builder, loc,
lhs,
rhs);
2976 case AtomicRMWKind::minu:
2977 return arith::MinUIOp::create(builder, loc,
lhs,
rhs);
// Bitwise kinds.
2978 case AtomicRMWKind::ori:
2979 return arith::OrIOp::create(builder, loc,
lhs,
rhs);
2980 case AtomicRMWKind::andi:
2981 return arith::AndIOp::create(builder, loc,
lhs,
rhs);
2982 case AtomicRMWKind::xori:
2983 return arith::XOrIOp::create(builder, loc,
lhs,
rhs);
2996#define GET_OP_CLASSES
2997#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3003#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
static Type getElementType(Type type)
Determine the element type of type.
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
llvm::function_ref< Fn > function_ref
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.