26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
44 function_ref<APInt(
const APInt &,
const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
67 static IntegerOverflowFlagsAttr
69 IntegerOverflowFlagsAttr val2) {
71 val1.getValue() & val2.getValue());
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
98 llvm_unreachable(
"unknown cmpi predicate kind");
107 static llvm::RoundingMode
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
121 llvm_unreachable(
"Unhandled rounding mode");
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
162 #include "ArithCanonicalization.inc"
172 if (
auto shapedType = dyn_cast<ShapedType>(type))
173 return shapedType.cloneWith(std::nullopt, i1Type);
174 if (llvm::isa<UnrankedTensorType>(type))
183 void arith::ConstantOp::getAsmResultNames(
186 if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
187 auto intType = dyn_cast<IntegerType>(type);
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName <<
'c' << intCst.getValue();
198 specialName <<
'_' << type;
199 setNameFn(getResult(), specialName.str());
201 setNameFn(getResult(),
"cst");
210 if (llvm::isa<IntegerType>(type) &&
211 !llvm::cast<IntegerType>(type).isSignless())
212 return emitOpError(
"integer return type must be signless");
214 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216 "value must be an integer, float, or elements attribute");
222 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224 "intializing scalable vectors with elements attribute is not supported"
225 " unless it's a vector splat");
229 bool arith::ConstantOp::isBuildableWith(
Attribute value,
Type type) {
231 auto typedAttr = dyn_cast<TypedAttr>(value);
232 if (!typedAttr || typedAttr.getType() != type)
235 if (llvm::isa<IntegerType>(type) &&
236 !llvm::cast<IntegerType>(type).isSignless())
239 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
244 if (isBuildableWith(value, type))
245 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
/// Folding hook for `arith.constant`: the op always folds to its own value
/// attribute, so the result is returned unchanged and `adaptor` is not
/// consulted. For IR that passed this op's verifier, that attribute is an
/// IntegerAttr, FloatAttr, or ElementsAttr.
249 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
252 int64_t value,
unsigned width) {
254 arith::ConstantOp::build(builder, result, type,
263 build(builder, state, value, width);
264 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
265 assert(result &&
"builder didn't return the right type");
272 return create(builder, builder.
getLoc(), value, width);
276 Type type, int64_t value) {
277 arith::ConstantOp::build(builder, result, type,
285 build(builder, state, type, value);
286 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
287 assert(result &&
"builder didn't return the right type");
292 Type type, int64_t value) {
293 return create(builder, builder.
getLoc(), type, value);
297 Type type,
const APInt &value) {
298 arith::ConstantOp::build(builder, result, type,
304 const APInt &value) {
306 build(builder, state, type, value);
307 auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
308 assert(result &&
"builder didn't return the right type");
314 const APInt &value) {
315 return create(builder, builder.
getLoc(), type, value);
319 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
320 return constOp.getType().isSignlessInteger();
325 FloatType type,
const APFloat &value) {
326 arith::ConstantOp::build(builder, result, type,
333 const APFloat &value) {
335 build(builder, state, type, value);
336 auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
337 assert(result &&
"builder didn't return the right type");
343 const APFloat &value) {
344 return create(builder, builder.
getLoc(), type, value);
348 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
349 return llvm::isa<FloatType>(constOp.getType());
355 arith::ConstantOp::build(builder, result, builder.
getIndexType(),
363 build(builder, state, value);
364 auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
365 assert(result &&
"builder didn't return the right type");
371 return create(builder, builder.
getLoc(), value);
375 if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
376 return constOp.getType().isIndex();
384 "type doesn't have a zero representation");
386 assert(zeroAttr &&
"unsupported type for zero attribute");
387 return arith::ConstantOp::create(builder, loc, zeroAttr);
400 if (
auto sub = getLhs().getDefiningOp<SubIOp>())
401 if (getRhs() == sub.getRhs())
405 if (
auto sub = getRhs().getDefiningOp<SubIOp>())
406 if (getLhs() == sub.getRhs())
409 return constFoldBinaryOp<IntegerAttr>(
410 adaptor.getOperands(),
411 [](APInt a,
const APInt &b) { return std::move(a) + b; });
416 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
417 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
424 std::optional<SmallVector<int64_t, 4>>
425 arith::AddUIExtendedOp::getShapeForUnroll() {
426 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
427 return llvm::to_vector<4>(vt.getShape());
434 return sum.ult(operand) ? APInt::getAllOnes(1) :
APInt::getZero(1);
438 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
440 Type overflowTy = getOverflow().getType();
446 results.push_back(getLhs());
447 results.push_back(falseValue);
455 if (
Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
456 adaptor.getOperands(),
457 [](APInt a,
const APInt &b) { return std::move(a) + b; })) {
458 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
459 ArrayRef({sumAttr, adaptor.getLhs()}),
465 results.push_back(sumAttr);
466 results.push_back(overflowAttr);
473 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
475 patterns.add<AddUIExtendedToAddI>(context);
484 if (getOperand(0) == getOperand(1)) {
485 auto shapedType = dyn_cast<ShapedType>(
getType());
487 if (!shapedType || shapedType.hasStaticShape())
494 if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
496 if (getRhs() ==
add.getRhs())
499 if (getRhs() ==
add.getLhs())
503 return constFoldBinaryOp<IntegerAttr>(
504 adaptor.getOperands(),
505 [](APInt a,
const APInt &b) { return std::move(a) - b; });
510 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
511 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
512 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
529 return constFoldBinaryOp<IntegerAttr>(
530 adaptor.getOperands(),
531 [](
const APInt &a,
const APInt &b) { return a * b; });
534 void arith::MulIOp::getAsmResultNames(
536 if (!isa<IndexType>(
getType()))
542 return op && op->getName().getStringRef() ==
"vector.vscale";
545 IntegerAttr baseValue;
548 isVscale(b.getDefiningOp());
551 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
556 llvm::raw_svector_ostream specialName(specialNameBuffer);
557 specialName <<
'c' << baseValue.getInt() <<
"_vscale";
558 setNameFn(getResult(), specialName.str());
563 patterns.add<MulIMulIConstant>(context);
570 std::optional<SmallVector<int64_t, 4>>
571 arith::MulSIExtendedOp::getShapeForUnroll() {
572 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
573 return llvm::to_vector<4>(vt.getShape());
578 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
583 results.push_back(zero);
584 results.push_back(zero);
589 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
590 adaptor.getOperands(),
591 [](
const APInt &a,
const APInt &b) { return a * b; })) {
593 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
594 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
595 return llvm::APIntOps::mulhs(a, b);
597 assert(highAttr &&
"Unexpected constant-folding failure");
599 results.push_back(lowAttr);
600 results.push_back(highAttr);
607 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
609 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
616 std::optional<SmallVector<int64_t, 4>>
617 arith::MulUIExtendedOp::getShapeForUnroll() {
618 if (
auto vt = dyn_cast<VectorType>(
getType(0)))
619 return llvm::to_vector<4>(vt.getShape());
624 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
629 results.push_back(zero);
630 results.push_back(zero);
638 results.push_back(getLhs());
639 results.push_back(zero);
644 if (
Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
645 adaptor.getOperands(),
646 [](
const APInt &a,
const APInt &b) { return a * b; })) {
648 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
649 adaptor.getOperands(), [](
const APInt &a,
const APInt &b) {
650 return llvm::APIntOps::mulhu(a, b);
652 assert(highAttr &&
"Unexpected constant-folding failure");
654 results.push_back(lowAttr);
655 results.push_back(highAttr);
662 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
664 patterns.add<MulUIExtendedToMulI>(context);
673 arith::IntegerOverflowFlags ovfFlags) {
675 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
678 if (mul.getLhs() == rhs)
681 if (mul.getRhs() == rhs)
687 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
693 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
698 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
699 [&](APInt a,
const APInt &b) {
727 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
733 if (
Value val =
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
737 bool overflowOrDiv0 =
false;
738 auto result = constFoldBinaryOp<IntegerAttr>(
739 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
740 if (overflowOrDiv0 || !b) {
741 overflowOrDiv0 = true;
744 return a.sdiv_ov(b, overflowOrDiv0);
747 return overflowOrDiv0 ?
Attribute() : result;
774 APInt one(a.getBitWidth(), 1,
true);
775 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
776 return val.sadd_ov(one, overflow);
783 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
788 bool overflowOrDiv0 =
false;
789 auto result = constFoldBinaryOp<IntegerAttr>(
790 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
791 if (overflowOrDiv0 || !b) {
792 overflowOrDiv0 = true;
795 APInt quotient = a.udiv(b);
798 APInt one(a.getBitWidth(), 1,
true);
799 return quotient.uadd_ov(one, overflowOrDiv0);
802 return overflowOrDiv0 ?
Attribute() : result;
813 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
821 bool overflowOrDiv0 =
false;
822 auto result = constFoldBinaryOp<IntegerAttr>(
823 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
824 if (overflowOrDiv0 || !b) {
825 overflowOrDiv0 = true;
831 unsigned bits = a.getBitWidth();
833 bool aGtZero = a.sgt(zero);
834 bool bGtZero = b.sgt(zero);
835 if (aGtZero && bGtZero) {
842 bool overflowNegA =
false;
843 bool overflowNegB =
false;
844 bool overflowDiv =
false;
845 bool overflowNegRes =
false;
846 if (!aGtZero && !bGtZero) {
848 APInt posA = zero.ssub_ov(a, overflowNegA);
849 APInt posB = zero.ssub_ov(b, overflowNegB);
851 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
854 if (!aGtZero && bGtZero) {
856 APInt posA = zero.ssub_ov(a, overflowNegA);
857 APInt div = posA.sdiv_ov(b, overflowDiv);
858 APInt res = zero.ssub_ov(div, overflowNegRes);
859 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
863 APInt posB = zero.ssub_ov(b, overflowNegB);
864 APInt div = a.sdiv_ov(posB, overflowDiv);
865 APInt res = zero.ssub_ov(div, overflowNegRes);
867 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
871 return overflowOrDiv0 ?
Attribute() : result;
882 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
888 bool overflowOrDiv =
false;
889 auto result = constFoldBinaryOp<IntegerAttr>(
890 adaptor.getOperands(), [&](APInt a,
const APInt &b) {
892 overflowOrDiv = true;
895 return a.sfloordiv_ov(b, overflowOrDiv);
898 return overflowOrDiv ?
Attribute() : result;
905 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
912 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
913 [&](APInt a,
const APInt &b) {
914 if (div0 || b.isZero()) {
928 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
935 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
936 [&](APInt a,
const APInt &b) {
937 if (div0 || b.isZero()) {
953 for (
bool reversePrev : {
false,
true}) {
954 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
955 .getDefiningOp<arith::AndIOp>();
959 Value other = (reversePrev ? op.getLhs() : op.getRhs());
960 if (other != prev.getLhs() && other != prev.getRhs())
963 return prev.getResult();
975 intValue.isAllOnes())
980 intValue.isAllOnes())
985 intValue.isAllOnes())
992 return constFoldBinaryOp<IntegerAttr>(
993 adaptor.getOperands(),
994 [](APInt a,
const APInt &b) { return std::move(a) & b; });
1004 if (rhsVal.isZero())
1007 if (rhsVal.isAllOnes())
1008 return adaptor.getRhs();
1015 intValue.isAllOnes())
1016 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1020 intValue.isAllOnes())
1021 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1023 return constFoldBinaryOp<IntegerAttr>(
1024 adaptor.getOperands(),
1025 [](APInt a,
const APInt &b) { return std::move(a) | b; });
1032 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1037 if (getLhs() == getRhs())
1041 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1042 if (prev.getRhs() == getRhs())
1043 return prev.getLhs();
1044 if (prev.getLhs() == getRhs())
1045 return prev.getRhs();
1049 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1050 if (prev.getRhs() == getLhs())
1051 return prev.getLhs();
1052 if (prev.getLhs() == getLhs())
1053 return prev.getRhs();
1056 return constFoldBinaryOp<IntegerAttr>(
1057 adaptor.getOperands(),
1058 [](APInt a,
const APInt &b) { return std::move(a) ^ b; });
1063 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1070 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1072 if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1073 return op.getOperand();
1074 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1075 [](
const APFloat &a) { return -a; });
1082 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1087 return constFoldBinaryOp<FloatAttr>(
1088 adaptor.getOperands(),
1089 [](
const APFloat &a,
const APFloat &b) { return a + b; });
1096 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1101 return constFoldBinaryOp<FloatAttr>(
1102 adaptor.getOperands(),
1103 [](
const APFloat &a,
const APFloat &b) { return a - b; });
1110 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1112 if (getLhs() == getRhs())
1119 return constFoldBinaryOp<FloatAttr>(
1120 adaptor.getOperands(),
1121 [](
const APFloat &a,
const APFloat &b) { return llvm::maximum(a, b); });
1128 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1130 if (getLhs() == getRhs())
1137 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1146 if (getLhs() == getRhs())
1152 if (intValue.isMaxSignedValue())
1155 if (intValue.isMinSignedValue())
1159 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1160 [](
const APInt &a,
const APInt &b) {
1161 return llvm::APIntOps::smax(a, b);
1171 if (getLhs() == getRhs())
1177 if (intValue.isMaxValue())
1180 if (intValue.isMinValue())
1184 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1185 [](
const APInt &a,
const APInt &b) {
1186 return llvm::APIntOps::umax(a, b);
1194 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1196 if (getLhs() == getRhs())
1203 return constFoldBinaryOp<FloatAttr>(
1204 adaptor.getOperands(),
1205 [](
const APFloat &a,
const APFloat &b) { return llvm::minimum(a, b); });
1212 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1214 if (getLhs() == getRhs())
1221 return constFoldBinaryOp<FloatAttr>(
1222 adaptor.getOperands(),
1223 [](
const APFloat &a,
const APFloat &b) { return llvm::minnum(a, b); });
1232 if (getLhs() == getRhs())
1238 if (intValue.isMinSignedValue())
1241 if (intValue.isMaxSignedValue())
1245 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246 [](
const APInt &a,
const APInt &b) {
1247 return llvm::APIntOps::smin(a, b);
1257 if (getLhs() == getRhs())
1263 if (intValue.isMinValue())
1266 if (intValue.isMaxValue())
1270 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1271 [](
const APInt &a,
const APInt &b) {
1272 return llvm::APIntOps::umin(a, b);
1280 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1285 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1286 arith::FastMathFlags::nsz)) {
1292 return constFoldBinaryOp<FloatAttr>(
1293 adaptor.getOperands(),
1294 [](
const APFloat &a,
const APFloat &b) { return a * b; });
1306 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1311 return constFoldBinaryOp<FloatAttr>(
1312 adaptor.getOperands(),
1313 [](
const APFloat &a,
const APFloat &b) { return a / b; });
1325 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1326 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1327 [](
const APFloat &a,
const APFloat &b) {
1332 (void)result.mod(b);
1341 template <
typename... Types>
1347 template <
typename... ShapedTypes,
typename... ElementTypes>
1350 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1354 if (!llvm::isa<ElementTypes...>(underlyingType))
1357 return underlyingType;
1361 template <
typename... ElementTypes>
1368 template <
typename... ElementTypes>
1377 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1378 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1379 if (!rankedTensorA || !rankedTensorB)
1381 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1385 if (inputs.size() != 1 || outputs.size() != 1)
1397 template <
typename ValType,
typename Op>
1402 if (llvm::cast<ValType>(srcType).getWidth() >=
1403 llvm::cast<ValType>(dstType).getWidth())
1405 << dstType <<
" must be wider than operand type " << srcType;
1411 template <
typename ValType,
typename Op>
1416 if (llvm::cast<ValType>(srcType).getWidth() <=
1417 llvm::cast<ValType>(dstType).getWidth())
1419 << dstType <<
" must be shorter than operand type " << srcType;
1425 template <
template <
typename>
class WidthComparator,
typename... ElementTypes>
1430 auto srcType =
getTypeIfLike<ElementTypes...>(inputs.front());
1431 auto dstType =
getTypeIfLike<ElementTypes...>(outputs.front());
1432 if (!srcType || !dstType)
1435 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1436 srcType.getIntOrFloatBitWidth());
1442 APFloat sourceValue,
const llvm::fltSemantics &targetSemantics,
1443 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1444 bool losesInfo =
false;
1445 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1446 if (losesInfo || status != APFloat::opOK)
1456 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1457 if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1458 getInMutable().assign(lhs.getIn());
1463 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1464 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1465 adaptor.getOperands(),
getType(),
1466 [bitWidth](
const APInt &a,
bool &castStatus) {
1467 return a.zext(bitWidth);
1472 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1476 return verifyExtOp<IntegerType>(*
this);
1483 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1484 if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1485 getInMutable().assign(lhs.getIn());
1490 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1491 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1492 adaptor.getOperands(),
getType(),
1493 [bitWidth](
const APInt &a,
bool &castStatus) {
1494 return a.sext(bitWidth);
1499 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1504 patterns.add<ExtSIOfExtUI>(context);
1508 return verifyExtOp<IntegerType>(*
this);
1517 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1518 if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1519 if (truncFOp.getOperand().getType() ==
getType()) {
1520 arith::FastMathFlags truncFMF =
1521 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1522 bool isTruncContract =
1524 arith::FastMathFlags extFMF =
1525 getFastmath().value_or(arith::FastMathFlags::none);
1526 bool isExtContract =
1528 if (isTruncContract && isExtContract) {
1529 return truncFOp.getOperand();
1535 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1536 return constFoldCastOp<FloatAttr, FloatAttr>(
1537 adaptor.getOperands(),
getType(),
1538 [&targetSemantics](
const APFloat &a,
bool &castStatus) {
1549 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1558 bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
1560 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1564 return verifyExtOp<FloatType>(*
this);
1571 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1572 if (
matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1579 if (llvm::cast<IntegerType>(srcType).getWidth() >
1580 llvm::cast<IntegerType>(dstType).getWidth()) {
1587 if (srcType == dstType)
1592 if (
matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1593 setOperand(getOperand().getDefiningOp()->getOperand(0));
1598 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1599 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1600 adaptor.getOperands(),
getType(),
1601 [bitWidth](
const APInt &a,
bool &castStatus) {
1602 return a.trunc(bitWidth);
1607 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1613 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1618 return verifyTruncateOp<IntegerType>(*
this);
1627 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1629 if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1630 Value src = extOp.getIn();
1632 auto intermediateType =
1635 if (llvm::APFloatBase::isRepresentableBy(
1636 srcType.getFloatSemantics(),
1637 intermediateType.getFloatSemantics())) {
1639 if (srcType.getWidth() > resElemType.getWidth()) {
1645 if (srcType == resElemType)
1650 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1651 return constFoldCastOp<FloatAttr, FloatAttr>(
1652 adaptor.getOperands(),
getType(),
1653 [
this, &targetSemantics](
const APFloat &a,
bool &castStatus) {
1654 RoundingMode roundingMode =
1655 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1656 llvm::RoundingMode llvmRoundingMode =
1658 FailureOr<APFloat> result =
1670 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1674 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1678 return verifyTruncateOp<FloatType>(*
this);
1685 bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
1687 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1691 return verifyTruncateOp<FloatType>(*
this);
1700 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1709 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1716 template <
typename From,
typename To>
1721 auto srcType = getTypeIfLike<From>(inputs.front());
1722 auto dstType = getTypeIfLike<To>(outputs.back());
1724 return srcType && dstType;
1732 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1735 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1737 return constFoldCastOp<IntegerAttr, FloatAttr>(
1738 adaptor.getOperands(),
getType(),
1739 [&resEleType](
const APInt &a,
bool &castStatus) {
1740 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1741 APFloat apf(floatTy.getFloatSemantics(),
1743 apf.convertFromAPInt(a,
false,
1744 APFloat::rmNearestTiesToEven);
1754 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1757 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1759 return constFoldCastOp<IntegerAttr, FloatAttr>(
1760 adaptor.getOperands(),
getType(),
1761 [&resEleType](
const APInt &a,
bool &castStatus) {
1762 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1763 APFloat apf(floatTy.getFloatSemantics(),
1765 apf.convertFromAPInt(a,
true,
1766 APFloat::rmNearestTiesToEven);
1776 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1779 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1781 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1782 return constFoldCastOp<FloatAttr, IntegerAttr>(
1783 adaptor.getOperands(),
getType(),
1784 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1786 APSInt api(bitWidth,
true);
1787 castStatus = APFloat::opInvalidOp !=
1788 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1798 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1801 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1803 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1804 return constFoldCastOp<FloatAttr, IntegerAttr>(
1805 adaptor.getOperands(),
getType(),
1806 [&bitWidth](
const APFloat &a,
bool &castStatus) {
1808 APSInt api(bitWidth,
false);
1809 castStatus = APFloat::opInvalidOp !=
1810 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1823 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1824 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1825 if (!srcType || !dstType)
1832 bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
1837 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1839 unsigned resultBitwidth = 64;
1841 resultBitwidth = intTy.getWidth();
1843 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1844 adaptor.getOperands(),
getType(),
1845 [resultBitwidth](
const APInt &a,
bool & ) {
1846 return a.sextOrTrunc(resultBitwidth);
1850 void arith::IndexCastOp::getCanonicalizationPatterns(
1852 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1859 bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
1864 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1866 unsigned resultBitwidth = 64;
1868 resultBitwidth = intTy.getWidth();
1870 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1871 adaptor.getOperands(),
getType(),
1872 [resultBitwidth](
const APInt &a,
bool & ) {
1873 return a.zextOrTrunc(resultBitwidth);
1877 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1879 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1890 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1891 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1892 if (!srcType || !dstType)
1898 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1900 auto operand = adaptor.getIn();
1905 if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1906 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
1908 if (llvm::isa<ShapedType>(resType))
1912 if (llvm::isa<ub::PoisonAttr>(operand))
1916 APInt bits = llvm::isa<FloatAttr>(operand)
1917 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1918 : llvm::cast<IntegerAttr>(operand).getValue();
1920 "trying to fold on broken IR: operands have incompatible types");
1922 if (
auto resFloatType = dyn_cast<FloatType>(resType))
1924 APFloat(resFloatType.getFloatSemantics(), bits));
1930 patterns.add<BitcastOfBitcast>(context);
1940 const APInt &lhs,
const APInt &rhs) {
1941 switch (predicate) {
1942 case arith::CmpIPredicate::eq:
1944 case arith::CmpIPredicate::ne:
1946 case arith::CmpIPredicate::slt:
1947 return lhs.slt(rhs);
1948 case arith::CmpIPredicate::sle:
1949 return lhs.sle(rhs);
1950 case arith::CmpIPredicate::sgt:
1951 return lhs.sgt(rhs);
1952 case arith::CmpIPredicate::sge:
1953 return lhs.sge(rhs);
1954 case arith::CmpIPredicate::ult:
1955 return lhs.ult(rhs);
1956 case arith::CmpIPredicate::ule:
1957 return lhs.ule(rhs);
1958 case arith::CmpIPredicate::ugt:
1959 return lhs.ugt(rhs);
1960 case arith::CmpIPredicate::uge:
1961 return lhs.uge(rhs);
1963 llvm_unreachable(
"unknown cmpi predicate kind");
1968 switch (predicate) {
1969 case arith::CmpIPredicate::eq:
1970 case arith::CmpIPredicate::sle:
1971 case arith::CmpIPredicate::sge:
1972 case arith::CmpIPredicate::ule:
1973 case arith::CmpIPredicate::uge:
1975 case arith::CmpIPredicate::ne:
1976 case arith::CmpIPredicate::slt:
1977 case arith::CmpIPredicate::sgt:
1978 case arith::CmpIPredicate::ult:
1979 case arith::CmpIPredicate::ugt:
1982 llvm_unreachable(
"unknown cmpi predicate kind");
1986 if (
auto intType = dyn_cast<IntegerType>(t)) {
1987 return intType.getWidth();
1989 if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
1990 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1992 return std::nullopt;
1995 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1997 if (getLhs() == getRhs()) {
2003 if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2005 std::optional<int64_t> integerWidth =
2007 if (integerWidth && integerWidth.value() == 1 &&
2008 getPredicate() == arith::CmpIPredicate::ne)
2009 return extOp.getOperand();
2011 if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2013 std::optional<int64_t> integerWidth =
2015 if (integerWidth && integerWidth.value() == 1 &&
2016 getPredicate() == arith::CmpIPredicate::ne)
2017 return extOp.getOperand();
2022 getPredicate() == arith::CmpIPredicate::ne)
2029 getPredicate() == arith::CmpIPredicate::eq)
2034 if (adaptor.getLhs() && !adaptor.getRhs()) {
2036 using Pred = CmpIPredicate;
2037 const std::pair<Pred, Pred> invPreds[] = {
2038 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2039 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2040 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2041 {Pred::ne, Pred::ne},
2043 Pred origPred = getPredicate();
2044 for (
auto pred : invPreds) {
2045 if (origPred == pred.first) {
2046 setPredicate(pred.second);
2047 Value lhs = getLhs();
2048 Value rhs = getRhs();
2049 getLhsMutable().assign(rhs);
2050 getRhsMutable().assign(lhs);
2054 llvm_unreachable(
"unknown cmpi predicate kind");
2059 if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2060 return constFoldBinaryOp<IntegerAttr>(
2062 [pred = getPredicate()](
const APInt &lhs,
const APInt &rhs) {
2073 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2083 const APFloat &lhs,
const APFloat &rhs) {
2084 auto cmpResult = lhs.compare(rhs);
2085 switch (predicate) {
2086 case arith::CmpFPredicate::AlwaysFalse:
2088 case arith::CmpFPredicate::OEQ:
2089 return cmpResult == APFloat::cmpEqual;
2090 case arith::CmpFPredicate::OGT:
2091 return cmpResult == APFloat::cmpGreaterThan;
2092 case arith::CmpFPredicate::OGE:
2093 return cmpResult == APFloat::cmpGreaterThan ||
2094 cmpResult == APFloat::cmpEqual;
2095 case arith::CmpFPredicate::OLT:
2096 return cmpResult == APFloat::cmpLessThan;
2097 case arith::CmpFPredicate::OLE:
2098 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2099 case arith::CmpFPredicate::ONE:
2100 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2101 case arith::CmpFPredicate::ORD:
2102 return cmpResult != APFloat::cmpUnordered;
2103 case arith::CmpFPredicate::UEQ:
2104 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2105 case arith::CmpFPredicate::UGT:
2106 return cmpResult == APFloat::cmpUnordered ||
2107 cmpResult == APFloat::cmpGreaterThan;
2108 case arith::CmpFPredicate::UGE:
2109 return cmpResult == APFloat::cmpUnordered ||
2110 cmpResult == APFloat::cmpGreaterThan ||
2111 cmpResult == APFloat::cmpEqual;
2112 case arith::CmpFPredicate::ULT:
2113 return cmpResult == APFloat::cmpUnordered ||
2114 cmpResult == APFloat::cmpLessThan;
2115 case arith::CmpFPredicate::ULE:
2116 return cmpResult == APFloat::cmpUnordered ||
2117 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2118 case arith::CmpFPredicate::UNE:
2119 return cmpResult != APFloat::cmpEqual;
2120 case arith::CmpFPredicate::UNO:
2121 return cmpResult == APFloat::cmpUnordered;
2122 case arith::CmpFPredicate::AlwaysTrue:
2125 llvm_unreachable(
"unknown cmpf predicate kind");
2128 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2129 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2130 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2133 if (lhs && lhs.getValue().isNaN())
2135 if (rhs && rhs.getValue().isNaN())
2151 using namespace arith;
2153 case CmpFPredicate::UEQ:
2154 case CmpFPredicate::OEQ:
2155 return CmpIPredicate::eq;
2156 case CmpFPredicate::UGT:
2157 case CmpFPredicate::OGT:
2158 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2159 case CmpFPredicate::UGE:
2160 case CmpFPredicate::OGE:
2161 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2162 case CmpFPredicate::ULT:
2163 case CmpFPredicate::OLT:
2164 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2165 case CmpFPredicate::ULE:
2166 case CmpFPredicate::OLE:
2167 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2168 case CmpFPredicate::UNE:
2169 case CmpFPredicate::ONE:
2170 return CmpIPredicate::ne;
2172 llvm_unreachable(
"Unexpected predicate!");
2182 const APFloat &rhs = flt.getValue();
2190 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2191 int mantissaWidth = floatTy.getFPMantissaWidth();
2192 if (mantissaWidth <= 0)
2198 if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2200 intVal = si.getIn();
2201 }
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2203 intVal = ui.getIn();
2210 auto intTy = llvm::cast<IntegerType>(intVal.
getType());
2211 auto intWidth = intTy.getWidth();
2214 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2219 if ((
int)intWidth > mantissaWidth) {
2221 int exponent = ilogb(rhs);
2222 if (exponent == APFloat::IEK_Inf) {
2223 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2224 if (maxExponent < (
int)valueBits) {
2231 if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
2240 switch (op.getPredicate()) {
2241 case CmpFPredicate::ORD:
2246 case CmpFPredicate::UNO:
2259 APFloat signedMax(rhs.getSemantics());
2260 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth),
true,
2261 APFloat::rmNearestTiesToEven);
2262 if (signedMax < rhs) {
2263 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2264 pred == CmpIPredicate::sle)
2275 APFloat unsignedMax(rhs.getSemantics());
2276 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth),
false,
2277 APFloat::rmNearestTiesToEven);
2278 if (unsignedMax < rhs) {
2279 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2280 pred == CmpIPredicate::ule)
2292 APFloat signedMin(rhs.getSemantics());
2293 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth),
true,
2294 APFloat::rmNearestTiesToEven);
2295 if (signedMin > rhs) {
2296 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2297 pred == CmpIPredicate::sge)
2307 APFloat unsignedMin(rhs.getSemantics());
2308 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth),
false,
2309 APFloat::rmNearestTiesToEven);
2310 if (unsignedMin > rhs) {
2311 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2312 pred == CmpIPredicate::uge)
2327 APSInt rhsInt(intWidth, isUnsigned);
2328 if (APFloat::opInvalidOp ==
2329 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2335 if (!rhs.isZero()) {
2336 APFloat apf(floatTy.getFloatSemantics(),
2338 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2340 bool equal = apf == rhs;
2346 case CmpIPredicate::ne:
2350 case CmpIPredicate::eq:
2354 case CmpIPredicate::ule:
2357 if (rhs.isNegative()) {
2363 case CmpIPredicate::sle:
2366 if (rhs.isNegative())
2367 pred = CmpIPredicate::slt;
2369 case CmpIPredicate::ult:
2372 if (rhs.isNegative()) {
2377 pred = CmpIPredicate::ule;
2379 case CmpIPredicate::slt:
2382 if (!rhs.isNegative())
2383 pred = CmpIPredicate::sle;
2385 case CmpIPredicate::ugt:
2388 if (rhs.isNegative()) {
2394 case CmpIPredicate::sgt:
2397 if (rhs.isNegative())
2398 pred = CmpIPredicate::sge;
2400 case CmpIPredicate::uge:
2403 if (rhs.isNegative()) {
2408 pred = CmpIPredicate::ugt;
2410 case CmpIPredicate::sge:
2413 if (!rhs.isNegative())
2414 pred = CmpIPredicate::sgt;
2424 ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
2446 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2462 arith::XOrIOp::create(
2463 rewriter, op.getLoc(), op.getCondition(),
2464 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2465 op.getCondition().
getType(), 1)));
2475 results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2479 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2480 Value trueVal = getTrueValue();
2481 Value falseVal = getFalseValue();
2482 if (trueVal == falseVal)
2485 Value condition = getCondition();
2496 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2499 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2503 if (
getType().isSignlessInteger(1) &&
2509 auto pred = cmp.getPredicate();
2510 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2511 auto cmpLhs = cmp.getLhs();
2512 auto cmpRhs = cmp.getRhs();
2520 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2521 (cmpRhs == trueVal && cmpLhs == falseVal))
2522 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2529 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2531 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2533 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2535 results.reserve(
static_cast<size_t>(cond.getNumElements()));
2536 auto condVals = llvm::make_range(cond.value_begin<
BoolAttr>(),
2538 auto lhsVals = llvm::make_range(lhs.value_begin<
Attribute>(),
2540 auto rhsVals = llvm::make_range(rhs.value_begin<
Attribute>(),
2543 for (
auto [condVal, lhsVal, rhsVal] :
2544 llvm::zip_equal(condVals, lhsVals, rhsVals))
2545 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2556 Type conditionType, resultType;
2565 conditionType = resultType;
2574 {conditionType, resultType, resultType},
2579 p <<
" " << getOperands();
2582 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
2583 p << condType <<
", ";
2588 Type conditionType = getCondition().getType();
2595 if (!llvm::isa<TensorType, VectorType>(resultType))
2596 return emitOpError() <<
"expected condition to be a signless i1, but got "
2599 if (conditionType != shapedConditionType) {
2600 return emitOpError() <<
"expected condition type to have the same shape "
2601 "as the result type, expected "
2602 << shapedConditionType <<
", but got "
2611 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2616 bool bounded =
false;
2617 auto result = constFoldBinaryOp<IntegerAttr>(
2618 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2619 bounded = b.ult(b.getBitWidth());
2629 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2634 bool bounded =
false;
2635 auto result = constFoldBinaryOp<IntegerAttr>(
2636 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2637 bounded = b.ult(b.getBitWidth());
2647 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2652 bool bounded =
false;
2653 auto result = constFoldBinaryOp<IntegerAttr>(
2654 adaptor.getOperands(), [&](
const APInt &a,
const APInt &b) {
2655 bounded = b.ult(b.getBitWidth());
2668 bool useOnlyFiniteValue) {
2670 case AtomicRMWKind::maximumf: {
2671 const llvm::fltSemantics &semantic =
2672 llvm::cast<FloatType>(resultType).getFloatSemantics();
2673 APFloat identity = useOnlyFiniteValue
2674 ? APFloat::getLargest(semantic,
true)
2675 : APFloat::getInf(semantic,
true);
2678 case AtomicRMWKind::maxnumf: {
2679 const llvm::fltSemantics &semantic =
2680 llvm::cast<FloatType>(resultType).getFloatSemantics();
2681 APFloat identity = APFloat::getNaN(semantic,
true);
2684 case AtomicRMWKind::addf:
2685 case AtomicRMWKind::addi:
2686 case AtomicRMWKind::maxu:
2687 case AtomicRMWKind::ori:
2688 case AtomicRMWKind::xori:
2690 case AtomicRMWKind::andi:
2693 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2694 case AtomicRMWKind::maxs:
2696 resultType, APInt::getSignedMinValue(
2697 llvm::cast<IntegerType>(resultType).getWidth()));
2698 case AtomicRMWKind::minimumf: {
2699 const llvm::fltSemantics &semantic =
2700 llvm::cast<FloatType>(resultType).getFloatSemantics();
2701 APFloat identity = useOnlyFiniteValue
2702 ? APFloat::getLargest(semantic,
false)
2703 : APFloat::getInf(semantic,
false);
2707 case AtomicRMWKind::minnumf: {
2708 const llvm::fltSemantics &semantic =
2709 llvm::cast<FloatType>(resultType).getFloatSemantics();
2710 APFloat identity = APFloat::getNaN(semantic,
false);
2713 case AtomicRMWKind::mins:
2715 resultType, APInt::getSignedMaxValue(
2716 llvm::cast<IntegerType>(resultType).getWidth()));
2717 case AtomicRMWKind::minu:
2720 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2721 case AtomicRMWKind::muli:
2723 case AtomicRMWKind::mulf:
2735 std::optional<AtomicRMWKind> maybeKind =
2738 .Case([](arith::AddFOp op) {
return AtomicRMWKind::addf; })
2739 .Case([](arith::MulFOp op) {
return AtomicRMWKind::mulf; })
2740 .Case([](arith::MaximumFOp op) {
return AtomicRMWKind::maximumf; })
2741 .Case([](arith::MinimumFOp op) {
return AtomicRMWKind::minimumf; })
2742 .Case([](arith::MaxNumFOp op) {
return AtomicRMWKind::maxnumf; })
2743 .Case([](arith::MinNumFOp op) {
return AtomicRMWKind::minnumf; })
2745 .Case([](arith::AddIOp op) {
return AtomicRMWKind::addi; })
2746 .Case([](arith::OrIOp op) {
return AtomicRMWKind::ori; })
2747 .Case([](arith::XOrIOp op) {
return AtomicRMWKind::xori; })
2748 .Case([](arith::AndIOp op) {
return AtomicRMWKind::andi; })
2749 .Case([](arith::MaxUIOp op) {
return AtomicRMWKind::maxu; })
2750 .Case([](arith::MinUIOp op) {
return AtomicRMWKind::minu; })
2751 .Case([](arith::MaxSIOp op) {
return AtomicRMWKind::maxs; })
2752 .Case([](arith::MinSIOp op) {
return AtomicRMWKind::mins; })
2753 .Case([](arith::MulIOp op) {
return AtomicRMWKind::muli; })
2754 .Default([](
Operation *op) {
return std::nullopt; });
2756 return std::nullopt;
2759 bool useOnlyFiniteValue =
false;
2760 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2761 if (fmfOpInterface) {
2762 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2763 useOnlyFiniteValue =
2764 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2772 useOnlyFiniteValue);
2778 bool useOnlyFiniteValue) {
2781 return arith::ConstantOp::create(builder, loc, attr);
2789 case AtomicRMWKind::addf:
2790 return arith::AddFOp::create(builder, loc, lhs, rhs);
2791 case AtomicRMWKind::addi:
2792 return arith::AddIOp::create(builder, loc, lhs, rhs);
2793 case AtomicRMWKind::mulf:
2794 return arith::MulFOp::create(builder, loc, lhs, rhs);
2795 case AtomicRMWKind::muli:
2796 return arith::MulIOp::create(builder, loc, lhs, rhs);
2797 case AtomicRMWKind::maximumf:
2798 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2799 case AtomicRMWKind::minimumf:
2800 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2801 case AtomicRMWKind::maxnumf:
2802 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2803 case AtomicRMWKind::minnumf:
2804 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2805 case AtomicRMWKind::maxs:
2806 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2807 case AtomicRMWKind::mins:
2808 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2809 case AtomicRMWKind::maxu:
2810 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2811 case AtomicRMWKind::minu:
2812 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2813 case AtomicRMWKind::ori:
2814 return arith::OrIOp::create(builder, loc, lhs, rhs);
2815 case AtomicRMWKind::andi:
2816 return arith::AndIOp::create(builder, loc, lhs, rhs);
2817 case AtomicRMWKind::xori:
2818 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2831 #define GET_OP_CLASSES
2832 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2838 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1252::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Specialization of arith.constant op that returns an integer of index type.
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Specialization of arith.constant op that returns an integer value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)